 )
 hp.settings.load_profile("deterministic")
 
-
 def seed_strategy() -> hps.SearchStrategy[int]:
   return hps.integers(min_value=0, max_value=4)
 
-
 @hps.composite
 def group_strategy(
     draw: hps.DrawFn,
@@ -73,7 +71,6 @@ def group_strategy(
   )
   return num_groups, group_stride
 
-
 @hps.composite
 def group_sizes_strategy(
     draw: hps.DrawFn, m: int, num_groups: int
@@ -97,19 +94,12 @@ def group_sizes_strategy(
   starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
   return jnp.array(ends - starts, dtype=jnp.int32)
 
-
 GROUPED_MATMUL_TESTS = (
-    (128, 128, 128),
-    (256, 128, 128),
-    (128, 256, 128),
-    (128, 128, 256),
-    (256, 128, 512),
-    (512, 128, 128),
-    (512, 2048, 128),
+    (128, 128, 128),  # Small
+    (512, 2048, 256),  # Big
     (128, 8, 16),  # Test partial tiles.
 )
 
-
 def random_dense(
     shape: tuple[int, ...],
     key: jax.Array,
@@ -121,7 +111,6 @@ def random_dense(
   x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)  # pylint: disable=invalid-unary-operand-type
   return x.astype(jnp.bfloat16).astype(dtype)
 
-
 def dot(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
@@ -133,7 +122,6 @@ def dot(
   rhs = jnp.transpose(rhs) if transpose_rhs else rhs
   return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type)
 
-
 def reference_gmm(
     lhs: jnp.ndarray,
     rhs: jnp.ndarray,
@@ -154,7 +142,6 @@ def reference_gmm(
     start += group_sizes[i]
   return jnp.concatenate(out, axis=0)
 
-
 def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]:
   dtypes = [jnp.float32, jnp.bfloat16]
 
@@ -164,7 +151,6 @@ def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]:
     result.append(x + dtypes_tuple)
   return tuple(result)
 
-
 def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]:
   flags = [False, True]
   result = []
@@ -173,7 +159,6 @@ def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]:
     result.append(x + (flag,))
   return tuple(result)
 
-
 def tolerances(
     lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype
 ) -> tuple[float, float]:
@@ -185,7 +170,6 @@ def tolerances(
     return 1e-3, 1e-2  # atol, rtol
   return 1e-3, 1e-5  # atol, rtol
 
-
 # TODO(tgale): Fix errors with strict dtype promotion.
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class GroupedMatmulTest(jtu.JaxTestCase):
@@ -218,15 +202,16 @@ def gmm_test(
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
-      transpose_rhs: bool,
       data: hps.SearchStrategy[hps.DataObject],
       interpret: bool = False,
   ):
     seed = data.draw(seed_strategy())
     num_groups, _ = data.draw(group_strategy(max_stride=1))
+    lhs_dtype, rhs_dtype, out_dtype = [
+        data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
+        for _ in range(3)
+    ]
+    transpose_rhs = data.draw(hps.booleans())
 
     key = jax.random.key(seed)
     k1, k2 = jax.random.split(key, 2)
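
(Aside: the hunk above replaces the statically parameterized lhs_dtype/rhs_dtype/out_dtype/transpose_rhs arguments with values drawn inside the test body via Hypothesis's interactive `data` strategy. A minimal self-contained sketch of that pattern follows, separate from the diff itself; `add` is a hypothetical stand-in for the kernel under test, and the deterministic profile mirrors the one loaded at the top of this file.)

# Sketch only: absl parameterized cases combined with Hypothesis draws.
import hypothesis as hp
import hypothesis.strategies as hps
from absl.testing import absltest, parameterized

hp.settings.register_profile("deterministic", derandomize=True)
hp.settings.load_profile("deterministic")


def add(x: int, y: int) -> int:
  # Hypothetical stand-in for the kernel under test.
  return x + y


class AddTest(parameterized.TestCase):

  # Shapes stay as explicit parameterized cases; arguments that used to be
  # part of a decorator cross-product (dtypes, flags) are drawn from `data`,
  # which keeps the number of generated test methods small.
  @parameterized.parameters((1, 2), (3, 4))
  @hp.given(hps.data())
  def test_add(self, x: int, y: int, data: hps.DataObject):
    scale = data.draw(hps.sampled_from([1, 10]))  # drawn fresh per example
    self.assertEqual(add(x * scale, y * scale), (x + y) * scale)


if __name__ == "__main__":
  absltest.main()
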
@@ -270,64 +255,52 @@ def reference_fn(lhs, rhs, group_sizes, preferred_element_type):
     self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol)
     self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol)
 
-  @parameterized.parameters(
-      *with_transpose_argument(with_dtype_arguments(GROUPED_MATMUL_TESTS))
-  )
+  @parameterized.parameters(*GROUPED_MATMUL_TESTS)
   @hp.given(hps.data())
   def test_gmm(
       self,
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
-      transpose_rhs: bool,
       data: hps.SearchStrategy[hps.DataObject],
   ):
-    self.gmm_test(m, k, n, lhs_dtype, rhs_dtype, out_dtype, transpose_rhs, data)
+    self.gmm_test(m, k, n, data)
 
   # NOTE: Run fewer tests with interpret mode. We just want to sanity check that
   # changes do not break running these kernels with interpret=True.
-  @parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS[0:1]))
+  @parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1])
   @hp.given(hps.data())
   def test_gmm_interpret(
       self,
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
       data: hps.SearchStrategy[hps.DataObject],
   ):
     self.skipTest("interpret mode with dynamic grids is unsupported")
     self.gmm_test(
         m,
         k,
         n,
-        lhs_dtype,
-        rhs_dtype,
-        out_dtype,
-        transpose_rhs=False,
         data=data,
         interpret=True,
     )
 
-  @parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS))
+  @parameterized.parameters(*GROUPED_MATMUL_TESTS)
   @hp.given(hps.data())
   def test_gmm_sharded_groups(
       self,
       m: int,
       k: int,
       n: int,
-      lhs_dtype: jnp.dtype,
-      rhs_dtype: jnp.dtype,
-      out_dtype: jnp.dtype,
       data: hps.SearchStrategy[hps.DataObject],
   ):
     seed = data.draw(seed_strategy())
     num_groups, group_stride = data.draw(group_strategy())
+    lhs_dtype, rhs_dtype, out_dtype = [
+        data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
+        for _ in range(3)
+    ]
 
     key = jax.random.key(seed)
     k1, k2 = jax.random.split(key, 2)