2222import jax
2323from jax import random
2424from jax ._src import config
25+ from jax ._src import dtypes
2526from jax ._src import test_util as jtu
2627from jax ._src .pallas .pallas_call import _trace_kernel_to_jaxpr
2728from jax .experimental import pallas as pl
3536config .parse_flags_with_absl ()
3637
3738
39+ intx = dtypes .canonicalize_dtype (jnp .int64 )
40+ floatx = dtypes .canonicalize_dtype (jnp .float64 )
41+
42+
3843@jtu .with_config (jax_traceback_filtering = "off" )
3944class PallasBaseTest (jtu .JaxTestCase ):
4045 INTERPRET = False
4146
4247 def setUp (self ):
4348 if jtu .test_device_matches (["cpu" ]) and not self .INTERPRET :
4449 self .skipTest ("On CPU the test works only in interpret mode" )
45- if jtu .test_device_matches (["gpu" ]) and jax .config .x64_enabled :
46- self .skipTest ("On GPU the test works only in 32-bit" )
4750 if (jtu .test_device_matches (["cuda" ]) and
4851 not jtu .is_cuda_compute_capability_at_least ("8.0" )):
4952 self .skipTest ("Only works on GPU with capability >= sm80" )
@@ -67,7 +70,7 @@ def setUp(self):
6770
6871 def test_vmap_of_simple_kernel (self ):
6972 @functools .partial (
70- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), jnp . int32 ),
73+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), intx ),
7174 )
7275 def add_one (x_ref , o_ref ):
7376 o_ref [()] = x_ref [()] + 1
@@ -77,7 +80,7 @@ def add_one(x_ref, o_ref):
7780
7881 def test_vmap_of_simple_kernel_with_in_axes_None (self ):
7982 @functools .partial (
80- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), jnp . int32 ),
83+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), intx ),
8184 )
8285 def add (x_ref , y_ref , o_ref ):
8386 o_ref [()] = x_ref [()] + y_ref [()]
@@ -87,7 +90,7 @@ def add(x_ref, y_ref, o_ref):
8790
8891 def test_double_vmap_of_simple_kernel (self ):
8992 @functools .partial (
90- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), jnp . int32 ),
93+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), intx ),
9194 )
9295 def add_one (x_ref , o_ref ):
9396 o_ref [()] = x_ref [()] + 1
@@ -97,7 +100,7 @@ def add_one(x_ref, o_ref):
97100
98101 def test_quadruple_vmap_of_simple_kernel (self ):
99102 @functools .partial (
100- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), jnp . int32 ),
103+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), intx ),
101104 )
102105 def add_one (x_ref , o_ref ):
103106 o_ref [()] = x_ref [()] + 1
@@ -108,7 +111,7 @@ def add_one(x_ref, o_ref):
108111
109112 def test_quadruple_vmap_of_batched_kernel (self ):
110113 @functools .partial (
111- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((7 ,), jnp . int32 ),
114+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((7 ,), intx ),
112115 grid = (7 ,))
113116 def add_one (x_ref , o_ref ):
114117 i = pl .program_id (0 )
@@ -120,7 +123,7 @@ def add_one(x_ref, o_ref):
120123
121124 def test_vmap_of_slicing_kernel (self ):
122125 @functools .partial (
123- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((2 ,), jnp . int32 ),
126+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((2 ,), intx ),
124127 grid = (2 ,))
125128 def add_one (x_ref , o_ref ):
126129 i = pl .program_id (0 )
@@ -151,7 +154,7 @@ def kernel(src, dst):
151154
152155 def test_vmap_of_kernel_with_input_output_aliases (self ):
153156 @functools .partial (
154- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), jnp . int32 ),
157+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((), intx ),
155158 input_output_aliases = {1 :0 },
156159 grid = ())
157160 def add (x_ref , _ , o_ref ):
@@ -163,7 +166,7 @@ def add(x_ref, _, o_ref):
163166 def test_vmap_of_kernel_with_input_output_aliases_different_axes (self ):
164167 @functools .partial (
165168 self .pallas_call ,
166- out_shape = jax .ShapeDtypeStruct ((4 ,), jnp . int32 ),
169+ out_shape = jax .ShapeDtypeStruct ((4 ,), intx ),
167170 input_output_aliases = {0 : 0 },
168171 grid = (),
169172 )
@@ -176,7 +179,7 @@ def add(x_ref, o_ref):
176179
177180 def test_vmap_of_slicing_kernel_different_axes (self ):
178181 @functools .partial (
179- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((2 ,), jnp . int32 ),
182+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((2 ,), intx ),
180183 grid = (2 ,))
181184 def add_one (x_ref , o_ref ):
182185 i = pl .program_id (0 )
@@ -194,7 +197,7 @@ def add_one(x_ref, o_ref):
194197
195198 def test_double_vmap_of_slicing_kernel_different_axes (self ):
196199 @functools .partial (
197- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((4 ,), jnp . float32 ),
200+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((4 ,), floatx ),
198201 grid = (4 ,))
199202 def sin (x_ref , o_ref ):
200203 i = pl .program_id (0 )
@@ -211,7 +214,7 @@ def sin(x_ref, o_ref):
211214 def test_small_large_vmap (self ):
212215 # Catches https://github.com/jax-ml/jax/issues/18361
213216 @functools .partial (
214- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((2 ,), jnp . int32 ),
217+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((2 ,), intx ),
215218 grid = (2 ,))
216219 def add_one (x_ref , o_ref ):
217220 o_ref [()] = x_ref [()] + 1
@@ -230,7 +233,7 @@ def add_one(x_ref, o_ref):
230233 def test_small_small_large_vmap (self ):
231234
232235 @functools .partial (
233- self .pallas_call , out_shape = jax .ShapeDtypeStruct ((2 ,), jnp . int32 ),
236+ self .pallas_call , out_shape = jax .ShapeDtypeStruct ((2 ,), intx ),
234237 grid = (2 ,))
235238 def add_one (x_ref , o_ref ):
236239 o_ref [()] = x_ref [()] + 1
@@ -249,12 +252,6 @@ def add_one(x_ref, o_ref):
249252class PallasCallVmapInterpretTest (PallasCallVmapTest ):
250253 INTERPRET = True
251254
252- def setUp (self ):
253- super ().setUp ()
254- if jtu .test_device_matches (["cpu" ]) and jax .config .x64_enabled :
255- # TODO: assertion failures on CPU in 64-bit mode
256- self .skipTest ("On CPU the test works only in 32-bit mode" )
257-
258255
259256if __name__ == "__main__" :
260257 absltest .main ()
0 commit comments