Skip to content

Commit 2b5505c

Browse files
authored
[TEST] Consolidate input generation for reduce tests (#7477)
1 parent c0175fa commit 2b5505c

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

python/test/unit/language/test_core.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2485,6 +2485,13 @@ def get_reduced_dtype(dtype_str, op):
24852485
return dtype_str
24862486

24872487

2488+
def get_reduce_input(dtype_str, shape):
2489+
# limit the range of integers so that reduce ops do not overflow
2490+
low = 0 if dtype_str in uint_dtypes else -10 if dtype_str in integral_dtypes else None
2491+
high = 10 if dtype_str in integral_dtypes else None
2492+
return numpy_random(shape, dtype_str=dtype_str, low=low, high=high)
2493+
2494+
24882495
@pytest.mark.interpreter
24892496
@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [
24902497
'min',
@@ -2515,14 +2522,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
25152522
patch = f'z = tl.{op}(x, axis=0)'
25162523
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch})
25172524
# input
2518-
rs = RandomState(17)
2519-
# limit the range of integers so that the sum does not overflow
2520-
if dtype_str in integral_dtypes:
2521-
low = 0 if dtype_str in uint_dtypes else -100
2522-
high = 100
2523-
x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs, low=low, high=high)
2524-
else:
2525-
x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs)
2525+
x = get_reduce_input(dtype_str, (shape, ))
25262526
numpy_op = {
25272527
'sum': np.sum,
25282528
'max': np.max,
@@ -2547,7 +2547,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
25472547
else:
25482548
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
25492549
# triton result
2550-
z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str)
2550+
z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str)
25512551
kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas)
25522552
z_tri = to_numpy(z_tri)
25532553
# compare
@@ -2644,9 +2644,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
26442644

26452645
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'})
26462646
# input
2647-
rs = RandomState(17)
2648-
# limit the range of integers so that the sum does not overflow
2649-
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
2647+
x = get_reduce_input(dtype_str, shape)
26502648
x_tri = to_triton(x, device=device)
26512649
numpy_op = {
26522650
'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax, 'xor_sum':
@@ -2671,7 +2669,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
26712669

26722670
# triton result
26732671
z_shape = z_ref.shape
2674-
z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str)
2672+
z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str)
26752673
BLOCK_K = 1 if len(shape) == 2 else shape[2]
26762674
IS_3D = bool(len(shape) == 3)
26772675
USE_I1 = dtype_str == 'bool'
@@ -3319,8 +3317,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
33193317
temp_file.write_text(ir)
33203318
kernel = triton.compile(str(temp_file))
33213319

3322-
rs = RandomState(17)
3323-
x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
3320+
x = get_reduce_input(dtype_str, (M, N))
33243321
reduce2d = 'reduce2d' in epilogue_kind
33253322
z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1)
33263323
z = np.zeros(z_shape).astype(dtype_str)

0 commit comments

Comments
 (0)