Skip to content

Commit 3ba7d6d

Browse files
atalmanJokeren
andauthored
[Cherry-Pick][TEST] Consolidate input generation for reduce tests (triton-lang#7522)
Cherry-Pick of triton-lang#7477 to fix integration tests Co-authored-by: Keren Zhou <[email protected]>
1 parent 7c2ca84 commit 3ba7d6d

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

python/test/unit/language/test_core.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2428,6 +2428,13 @@ def get_reduced_dtype(dtype_str, op):
24282428
return dtype_str
24292429

24302430

2431+
def get_reduce_input(dtype_str, shape):
2432+
# limit the range of integers so that reduce ops do not overflow
2433+
low = 0 if dtype_str in uint_dtypes else -10 if dtype_str in integral_dtypes else None
2434+
high = 10 if dtype_str in integral_dtypes else None
2435+
return numpy_random(shape, dtype_str=dtype_str, low=low, high=high)
2436+
2437+
24312438
@pytest.mark.interpreter
24322439
@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [
24332440
'min',
@@ -2458,9 +2465,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
24582465
patch = f'z = tl.{op}(x, axis=0)'
24592466
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch})
24602467
# input
2461-
rs = RandomState(17)
2462-
# limit the range of integers so that the sum does not overflow
2463-
x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs)
2468+
x = get_reduce_input(dtype_str, (shape, ))
24642469
numpy_op = {
24652470
'sum': np.sum,
24662471
'max': np.max,
@@ -2485,7 +2490,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
24852490
else:
24862491
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
24872492
# triton result
2488-
z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str)
2493+
z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str)
24892494
kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas)
24902495
z_tri = to_numpy(z_tri)
24912496
# compare
@@ -2582,9 +2587,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
25822587

25832588
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'})
25842589
# input
2585-
rs = RandomState(17)
2586-
# limit the range of integers so that the sum does not overflow
2587-
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
2590+
x = get_reduce_input(dtype_str, shape)
25882591
x_tri = to_triton(x, device=device)
25892592
numpy_op = {
25902593
'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax, 'xor_sum':
@@ -2609,7 +2612,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
26092612

26102613
# triton result
26112614
z_shape = z_ref.shape
2612-
z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str)
2615+
z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str)
26132616
BLOCK_K = 1 if len(shape) == 2 else shape[2]
26142617
IS_3D = bool(len(shape) == 3)
26152618
USE_I1 = dtype_str == 'bool'
@@ -3211,8 +3214,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
32113214
temp_file.write_text(ir)
32123215
kernel = triton.compile(str(temp_file))
32133216

3214-
rs = RandomState(17)
3215-
x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
3217+
x = get_reduce_input(dtype_str, (M, N))
32163218
reduce2d = 'reduce2d' in epilogue_kind
32173219
z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1)
32183220
z = np.zeros(z_shape).astype(dtype_str)

0 commit comments

Comments
 (0)