Skip to content

Commit 1df4b5f

Browse files
superbobry and Google-ML-Automation
authored and committed
[pallas] Do not skip vmap tests on GPU when x64 is enabled
PiperOrigin-RevId: 698351984
1 parent 04e4c69 commit 1df4b5f

File tree

1 file changed

+17
-20
lines changed

1 file changed

+17
-20
lines changed

tests/pallas/pallas_vmap_test.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax
2323
from jax import random
2424
from jax._src import config
25+
from jax._src import dtypes
2526
from jax._src import test_util as jtu
2627
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
2728
from jax.experimental import pallas as pl
@@ -35,15 +36,17 @@
3536
config.parse_flags_with_absl()
3637

3738

39+
intx = dtypes.canonicalize_dtype(jnp.int64)
40+
floatx = dtypes.canonicalize_dtype(jnp.float64)
41+
42+
3843
@jtu.with_config(jax_traceback_filtering="off")
3944
class PallasBaseTest(jtu.JaxTestCase):
4045
INTERPRET = False
4146

4247
def setUp(self):
4348
if jtu.test_device_matches(["cpu"]) and not self.INTERPRET:
4449
self.skipTest("On CPU the test works only in interpret mode")
45-
if jtu.test_device_matches(["gpu"]) and jax.config.x64_enabled:
46-
self.skipTest("On GPU the test works only in 32-bit")
4750
if (jtu.test_device_matches(["cuda"]) and
4851
not jtu.is_cuda_compute_capability_at_least("8.0")):
4952
self.skipTest("Only works on GPU with capability >= sm80")
@@ -67,7 +70,7 @@ def setUp(self):
6770

6871
def test_vmap_of_simple_kernel(self):
6972
@functools.partial(
70-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
73+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx),
7174
)
7275
def add_one(x_ref, o_ref):
7376
o_ref[()] = x_ref[()] + 1
@@ -77,7 +80,7 @@ def add_one(x_ref, o_ref):
7780

7881
def test_vmap_of_simple_kernel_with_in_axes_None(self):
7982
@functools.partial(
80-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
83+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx),
8184
)
8285
def add(x_ref, y_ref, o_ref):
8386
o_ref[()] = x_ref[()] + y_ref[()]
@@ -87,7 +90,7 @@ def add(x_ref, y_ref, o_ref):
8790

8891
def test_double_vmap_of_simple_kernel(self):
8992
@functools.partial(
90-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
93+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx),
9194
)
9295
def add_one(x_ref, o_ref):
9396
o_ref[()] = x_ref[()] + 1
@@ -97,7 +100,7 @@ def add_one(x_ref, o_ref):
97100

98101
def test_quadruple_vmap_of_simple_kernel(self):
99102
@functools.partial(
100-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
103+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx),
101104
)
102105
def add_one(x_ref, o_ref):
103106
o_ref[()] = x_ref[()] + 1
@@ -108,7 +111,7 @@ def add_one(x_ref, o_ref):
108111

109112
def test_quadruple_vmap_of_batched_kernel(self):
110113
@functools.partial(
111-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), jnp.int32),
114+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((7,), intx),
112115
grid=(7,))
113116
def add_one(x_ref, o_ref):
114117
i = pl.program_id(0)
@@ -120,7 +123,7 @@ def add_one(x_ref, o_ref):
120123

121124
def test_vmap_of_slicing_kernel(self):
122125
@functools.partial(
123-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
126+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx),
124127
grid=(2,))
125128
def add_one(x_ref, o_ref):
126129
i = pl.program_id(0)
@@ -151,7 +154,7 @@ def kernel(src, dst):
151154

152155
def test_vmap_of_kernel_with_input_output_aliases(self):
153156
@functools.partial(
154-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),
157+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), intx),
155158
input_output_aliases={1:0},
156159
grid=())
157160
def add(x_ref, _, o_ref):
@@ -163,7 +166,7 @@ def add(x_ref, _, o_ref):
163166
def test_vmap_of_kernel_with_input_output_aliases_different_axes(self):
164167
@functools.partial(
165168
self.pallas_call,
166-
out_shape=jax.ShapeDtypeStruct((4,), jnp.int32),
169+
out_shape=jax.ShapeDtypeStruct((4,), intx),
167170
input_output_aliases={0: 0},
168171
grid=(),
169172
)
@@ -176,7 +179,7 @@ def add(x_ref, o_ref):
176179

177180
def test_vmap_of_slicing_kernel_different_axes(self):
178181
@functools.partial(
179-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
182+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx),
180183
grid=(2,))
181184
def add_one(x_ref, o_ref):
182185
i = pl.program_id(0)
@@ -194,7 +197,7 @@ def add_one(x_ref, o_ref):
194197

195198
def test_double_vmap_of_slicing_kernel_different_axes(self):
196199
@functools.partial(
197-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32),
200+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), floatx),
198201
grid=(4,))
199202
def sin(x_ref, o_ref):
200203
i = pl.program_id(0)
@@ -211,7 +214,7 @@ def sin(x_ref, o_ref):
211214
def test_small_large_vmap(self):
212215
# Catches https://github.com/jax-ml/jax/issues/18361
213216
@functools.partial(
214-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
217+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx),
215218
grid=(2,))
216219
def add_one(x_ref, o_ref):
217220
o_ref[()] = x_ref[()] + 1
@@ -230,7 +233,7 @@ def add_one(x_ref, o_ref):
230233
def test_small_small_large_vmap(self):
231234

232235
@functools.partial(
233-
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
236+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), intx),
234237
grid=(2,))
235238
def add_one(x_ref, o_ref):
236239
o_ref[()] = x_ref[()] + 1
@@ -249,12 +252,6 @@ def add_one(x_ref, o_ref):
249252
class PallasCallVmapInterpretTest(PallasCallVmapTest):
250253
INTERPRET = True
251254

252-
def setUp(self):
253-
super().setUp()
254-
if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
255-
# TODO: assertion failures on CPU in 64-bit mode
256-
self.skipTest("On CPU the test works only in 32-bit mode")
257-
258255

259256
if __name__ == "__main__":
260257
absltest.main()

0 commit comments

Comments
 (0)