Commit 4bac105

basioli-k authored and Google-ML-Automation committed
1 parent 4b4165b, commit 4bac105

[XLA:CPU] Use distinct element values in tiled kernel tests

Using random values is not best practice, and inputs made of identical elements are not a good test for ops like transpose.

PiperOrigin-RevId: 837188191
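
Why this matters, as a minimal NumPy sketch (illustrative only, not part of the commit): with constant inputs, a transpose is indistinguishable from an identity copy, so a broken transpose kernel can still match the reference output.

import numpy as np

# An all-ones array is invariant under transposition, so a kernel that
# returns its input unchanged would still pass an all-ones test.
ones = np.full((4, 4), 1.0, dtype=np.float32)
assert np.array_equal(ones, ones.transpose())

# With distinct elements, an identity copy no longer passes for a transpose.
distinct = np.arange(16, dtype=np.float32).reshape(4, 4)
assert not np.array_equal(distinct, distinct.transpose())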


xla/backends/cpu/codegen/tiled/tiled_kernel_test.py

Lines changed: 24 additions & 20 deletions
@@ -28,9 +28,13 @@

 class InputSpec:

-  def __init__(self, shape: tuple[int, ...], input_value):
+  def __init__(self, shape: tuple[int, ...]):
+    """Initializes the InputSpec.
+
+    Args:
+      shape: The shape of the input array.
+    """
     self.shape = shape
-    self.input_value = input_value


 def get_random_array(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray:
@@ -58,12 +62,8 @@ def compare_kernel(
       cpu_testlib.JitCompiler(base_testlib.HloModuleConfig()),
   )

-  # Simply use a all-ones arrays as inputs to make it easy to debug the kernel
-  # unless random inputs are requested.
   def get_input(spec: InputSpec):
-    if spec.input_value is None:
-      return get_random_array(spec.shape, dtype)
-    return np.full(shape=spec.shape, fill_value=spec.input_value, dtype=dtype)
+    return np.arange(np.prod(spec.shape), dtype=dtype).reshape(spec.shape)

   inputs = [get_input(spec) for spec in input_specs]
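
For reference, a small sketch (not from the diff) of what the new get_input produces; dtype and spec_shape below are stand-ins for compare_kernel's dtype argument and a hypothetical InputSpec shape:

import numpy as np

dtype = np.float32   # stand-in for compare_kernel's dtype argument
spec_shape = (2, 3)  # hypothetical InputSpec shape
arr = np.arange(np.prod(spec_shape), dtype=dtype).reshape(spec_shape)
print(arr)
# [[0. 1. 2.]
#  [3. 4. 5.]]
# The values are deterministic (unlike random inputs) and pairwise distinct,
# so any permutation of the data is visible in the kernel's output.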

@@ -105,7 +105,7 @@ def test_slice(self):
         ir,
         "tiled_slice",
         1,
-        [InputSpec((5, 5), 1)],
+        [InputSpec((5, 5))],
         (5, 5),
         np.float32,
         lambda arg: arg.transpose(),
@@ -129,7 +129,7 @@ def test_strided(self):
         ir,
         "tiled_slice",
         1,
-        [InputSpec((64, 64), 1)],
+        [InputSpec((64, 64))],
         (4, 32),
         np.float32,
         lambda arg: arg[::21, ::2],
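
Aside (a worked check, not part of the diff): the (4, 32) output shape follows from the strides on the (64, 64) input — a step of 21 over 64 rows selects indices 0, 21, 42, 63, and a step of 2 keeps 32 of the 64 columns.

import numpy as np

assert np.arange(64)[::21].tolist() == [0, 21, 42, 63]  # 4 rows survive
assert np.arange(64)[::2].size == 32                    # 32 columns survive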
@@ -156,7 +156,7 @@ def test_transpose(self):
         ir,
         "tiled_transpose",
         8,
-        [InputSpec((4096, 4096), 1)],
+        [InputSpec((4096, 4096))],
         (4096, 4096),
         np.float32,
         lambda arg: arg.transpose(),
@@ -185,7 +185,7 @@ def test_add_tranpose(self):
         ir,
         "add_tranpose",
         8,
-        [InputSpec((4096, 4096), 1)],
+        [InputSpec((4096, 4096))],
         (4096, 4096),
         np.float32,
         lambda arg: arg + arg.transpose(),
@@ -213,7 +213,7 @@ def test_dot_single_tile(self):
         ir,
         "dot_single_tile",
         1,
-        [InputSpec((8, 16), 1), InputSpec((16, 8), 1)],
+        [InputSpec((8, 16)), InputSpec((16, 8))],
         (8, 8),
         np.float32,
         lambda lhs, rhs: lhs @ rhs,
@@ -242,7 +242,7 @@ def test_dot_scalar_output(self):
         ir,
         "test_dot_scalar_output",
         1,
-        [InputSpec((8, 16), 1), InputSpec((16, 8), 1)],
+        [InputSpec((8, 16)), InputSpec((16, 8))],
         (),
         np.float32,
         lambda lhs, rhs: np.tensordot(lhs, rhs, axes=[[1, 0], [0, 1]]),
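
Aside (illustrative, not from the diff): the double contraction in the reference lambda sums lhs[i, j] * rhs[j, i] over both indices, i.e. it equals np.trace(lhs @ rhs), which is why the expected output shape is ().

import numpy as np

lhs = np.arange(8 * 16, dtype=np.float32).reshape(8, 16)
rhs = np.arange(16 * 8, dtype=np.float32).reshape(16, 8)
contracted = np.tensordot(lhs, rhs, axes=[[1, 0], [0, 1]])
assert contracted.shape == ()  # scalar output
assert np.allclose(contracted, np.trace(lhs @ rhs))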
@@ -275,7 +275,11 @@ def test_dot_fusion_single_tile(self):
         ir,
         "dot_fusion_single_tile",
         1,
-        [InputSpec((8, 16), 1), InputSpec((8, 16), 1), InputSpec((16, 1), 1)],
+        [
+            InputSpec((8, 16)),
+            InputSpec((8, 16)),
+            InputSpec((16, 1)),
+        ],
         (8, 1),
         np.float32,
         lambda lhs_0, lhs_1, rhs: np.tanh((lhs_0 + lhs_1) @ rhs),
@@ -312,7 +316,7 @@ def test_reduction_add_inner(self):
         ir,
         "reduction_add_inner",
         4,
-        [InputSpec((1024, 32), 1), InputSpec((1,), 0)],
+        [InputSpec((1024, 32)), InputSpec((1,))],
         (1024,),
         np.int32,
         lambda input, init: np.sum(input, axis=1) + init,
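
A note on the (1,) init operand (my reading, not stated in the diff): np.arange(1) yields [0], so the reduction still receives the same additive identity that the old InputSpec((1,), 0) supplied; the same holds for the reduction tests below.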
@@ -348,7 +352,7 @@ def test_reduction_add_outer(self):
         ir,
         "reduction_add_outer",
         4,
-        [InputSpec((1024, 32), 1), InputSpec((1,), 0)],
+        [InputSpec((1024, 32)), InputSpec((1,))],
         (32,),
         np.float32,
         lambda input, init: np.sum(input, axis=0),
@@ -381,7 +385,7 @@ def test_reduction_middle(self):
         ir,
         "reduction_add_middle",
         1,
-        [InputSpec((8, 4, 2), 1), InputSpec((1,), 0)],
+        [InputSpec((8, 4, 2)), InputSpec((1,))],
         (8, 2),
         np.float32,
         lambda input, init: np.sum(input, axis=1),
@@ -414,7 +418,7 @@ def test_reduction_outer_inner(self):
         ir,
         "reduction_add_outer_inner",
         1,
-        [InputSpec((8, 4, 2), 1), InputSpec((1,), 0)],
+        [InputSpec((8, 4, 2)), InputSpec((1,))],
         (4,),
         np.float32,
         lambda input, init: np.sum(input, axis=(0, 2)),
@@ -439,7 +443,7 @@ def test_broadcast_in_dim_inner(self):
         ir,
         "broadcast_in_dim_inner",
         1,
-        [InputSpec((4,), None)],
+        [InputSpec((4,))],
         (32, 4),
         np.float32,
         lambda input: np.broadcast_to(input, (32, 4)),
@@ -464,7 +468,7 @@ def test_broadcast_in_dim_outer(self):
         ir,
         "broadcast_in_dim_outer",
         1,
-        [InputSpec((4,), None)],
+        [InputSpec((4,))],
         (4, 32),
         np.float32,
         lambda input: np.transpose(np.broadcast_to(input, (32, 4))),
