2828
2929class InputSpec :
3030
31- def __init__ (self , shape : tuple [int , ...], input_value ):
31+ def __init__ (self , shape : tuple [int , ...]):
32+ """Initializes the InputSpec.
33+
34+ Args:
35+ shape: The shape of the input array.
36+ """
3237 self .shape = shape
33- self .input_value = input_value
3438
3539
3640def get_random_array (shape : tuple [int , ...], dtype : np .dtype ) -> np .ndarray :
@@ -58,12 +62,8 @@ def compare_kernel(
5862 cpu_testlib .JitCompiler (base_testlib .HloModuleConfig ()),
5963 )
6064
61- # Simply use a all-ones arrays as inputs to make it easy to debug the kernel
62- # unless random inputs are requested.
6365 def get_input (spec : InputSpec ):
64- if spec .input_value is None :
65- return get_random_array (spec .shape , dtype )
66- return np .full (shape = spec .shape , fill_value = spec .input_value , dtype = dtype )
66+ return np .arange (np .prod (spec .shape ), dtype = dtype ).reshape (spec .shape )
6767
6868 inputs = [get_input (spec ) for spec in input_specs ]
6969
@@ -105,7 +105,7 @@ def test_slice(self):
105105 ir ,
106106 "tiled_slice" ,
107107 1 ,
108- [InputSpec ((5 , 5 ), 1 )],
108+ [InputSpec ((5 , 5 ))],
109109 (5 , 5 ),
110110 np .float32 ,
111111 lambda arg : arg .transpose (),
@@ -129,7 +129,7 @@ def test_strided(self):
129129 ir ,
130130 "tiled_slice" ,
131131 1 ,
132- [InputSpec ((64 , 64 ), 1 )],
132+ [InputSpec ((64 , 64 ))],
133133 (4 , 32 ),
134134 np .float32 ,
135135 lambda arg : arg [::21 , ::2 ],
@@ -156,7 +156,7 @@ def test_transpose(self):
156156 ir ,
157157 "tiled_transpose" ,
158158 8 ,
159- [InputSpec ((4096 , 4096 ), 1 )],
159+ [InputSpec ((4096 , 4096 ))],
160160 (4096 , 4096 ),
161161 np .float32 ,
162162 lambda arg : arg .transpose (),
@@ -185,7 +185,7 @@ def test_add_tranpose(self):
185185 ir ,
186186 "add_tranpose" ,
187187 8 ,
188- [InputSpec ((4096 , 4096 ), 1 )],
188+ [InputSpec ((4096 , 4096 ))],
189189 (4096 , 4096 ),
190190 np .float32 ,
191191 lambda arg : arg + arg .transpose (),
@@ -213,7 +213,7 @@ def test_dot_single_tile(self):
213213 ir ,
214214 "dot_single_tile" ,
215215 1 ,
216- [InputSpec ((8 , 16 ), 1 ), InputSpec ((16 , 8 ), 1 )],
216+ [InputSpec ((8 , 16 )), InputSpec ((16 , 8 ))],
217217 (8 , 8 ),
218218 np .float32 ,
219219 lambda lhs , rhs : lhs @ rhs ,
@@ -242,7 +242,7 @@ def test_dot_scalar_output(self):
242242 ir ,
243243 "test_dot_scalar_output" ,
244244 1 ,
245- [InputSpec ((8 , 16 ), 1 ), InputSpec ((16 , 8 ), 1 )],
245+ [InputSpec ((8 , 16 )), InputSpec ((16 , 8 ))],
246246 (),
247247 np .float32 ,
248248 lambda lhs , rhs : np .tensordot (lhs , rhs , axes = [[1 , 0 ], [0 , 1 ]]),
@@ -275,7 +275,11 @@ def test_dot_fusion_single_tile(self):
275275 ir ,
276276 "dot_fusion_single_tile" ,
277277 1 ,
278- [InputSpec ((8 , 16 ), 1 ), InputSpec ((8 , 16 ), 1 ), InputSpec ((16 , 1 ), 1 )],
278+ [
279+ InputSpec ((8 , 16 )),
280+ InputSpec ((8 , 16 )),
281+ InputSpec ((16 , 1 )),
282+ ],
279283 (8 , 1 ),
280284 np .float32 ,
281285 lambda lhs_0 , lhs_1 , rhs : np .tanh ((lhs_0 + lhs_1 ) @ rhs ),
@@ -312,7 +316,7 @@ def test_reduction_add_inner(self):
312316 ir ,
313317 "reduction_add_inner" ,
314318 4 ,
315- [InputSpec ((1024 , 32 ), 1 ), InputSpec ((1 ,), 0 )],
319+ [InputSpec ((1024 , 32 )), InputSpec ((1 ,))],
316320 (1024 ,),
317321 np .int32 ,
318322 lambda input , init : np .sum (input , axis = 1 ) + init ,
@@ -348,7 +352,7 @@ def test_reduction_add_outer(self):
348352 ir ,
349353 "reduction_add_outer" ,
350354 4 ,
351- [InputSpec ((1024 , 32 ), 1 ), InputSpec ((1 ,), 0 )],
355+ [InputSpec ((1024 , 32 )), InputSpec ((1 ,))],
352356 (32 ,),
353357 np .float32 ,
354358 lambda input , init : np .sum (input , axis = 0 ),
@@ -381,7 +385,7 @@ def test_reduction_middle(self):
381385 ir ,
382386 "reduction_add_middle" ,
383387 1 ,
384- [InputSpec ((8 , 4 , 2 ), 1 ), InputSpec ((1 ,), 0 )],
388+ [InputSpec ((8 , 4 , 2 )), InputSpec ((1 ,))],
385389 (8 , 2 ),
386390 np .float32 ,
387391 lambda input , init : np .sum (input , axis = 1 ),
@@ -414,7 +418,7 @@ def test_reduction_outer_inner(self):
414418 ir ,
415419 "reduction_add_outer_inner" ,
416420 1 ,
417- [InputSpec ((8 , 4 , 2 ), 1 ), InputSpec ((1 ,), 0 )],
421+ [InputSpec ((8 , 4 , 2 )), InputSpec ((1 ,))],
418422 (4 ,),
419423 np .float32 ,
420424 lambda input , init : np .sum (input , axis = (0 , 2 )),
@@ -439,7 +443,7 @@ def test_broadcast_in_dim_inner(self):
439443 ir ,
440444 "broadcast_in_dim_inner" ,
441445 1 ,
442- [InputSpec ((4 ,), None )],
446+ [InputSpec ((4 ,))],
443447 (32 , 4 ),
444448 np .float32 ,
445449 lambda input : np .broadcast_to (input , (32 , 4 )),
@@ -464,7 +468,7 @@ def test_broadcast_in_dim_outer(self):
464468 ir ,
465469 "broadcast_in_dim_outer" ,
466470 1 ,
467- [InputSpec ((4 ,), None )],
471+ [InputSpec ((4 ,))],
468472 (4 , 32 ),
469473 np .float32 ,
470474 lambda input : np .transpose (np .broadcast_to (input , (32 , 4 ))),
0 commit comments