@@ -2428,6 +2428,13 @@ def get_reduced_dtype(dtype_str, op):
     return dtype_str
 
 
+def get_reduce_input(dtype_str, shape):
+    # limit the range of integers so that reduce ops do not overflow
+    low = 0 if dtype_str in uint_dtypes else -10 if dtype_str in integral_dtypes else None
+    high = 10 if dtype_str in integral_dtypes else None
+    return numpy_random(shape, dtype_str=dtype_str, low=low, high=high)
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [
     'min',
@@ -2458,9 +2465,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
     patch = f'z = tl.{op}(x, axis=0)'
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': patch})
     # input
-    rs = RandomState(17)
-    # limit the range of integers so that the sum does not overflow
-    x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs)
+    x = get_reduce_input(dtype_str, (shape, ))
     numpy_op = {
         'sum': np.sum,
         'max': np.max,
@@ -2485,7 +2490,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
     else:
         z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
     # triton result
-    z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str)
+    z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str)
     kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas)
     z_tri = to_numpy(z_tri)
     # compare
@@ -2582,9 +2587,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
 
     kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis=AXIS, keep_dims=KEEP_DIMS)'})
     # input
-    rs = RandomState(17)
-    # limit the range of integers so that the sum does not overflow
-    x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
+    x = get_reduce_input(dtype_str, shape)
     x_tri = to_triton(x, device=device)
     numpy_op = {
         'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax, 'xor_sum':
@@ -2609,7 +2612,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
 
     # triton result
     z_shape = z_ref.shape
-    z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str)
+    z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str), device=device, dst_type=z_tri_dtype_str)
     BLOCK_K = 1 if len(shape) == 2 else shape[2]
     IS_3D = bool(len(shape) == 3)
     USE_I1 = dtype_str == 'bool'
@@ -3211,8 +3214,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
     temp_file.write_text(ir)
     kernel = triton.compile(str(temp_file))
 
-    rs = RandomState(17)
-    x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
+    x = get_reduce_input(dtype_str, (M, N))
     reduce2d = 'reduce2d' in epilogue_kind
     z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1)
     z = np.zeros(z_shape).astype(dtype_str)
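
For context, below is a minimal standalone sketch of the idea behind the new get_reduce_input helper: bounding integer inputs to a small range keeps reference reductions (e.g. sum) well away from overflow in narrow integer dtypes. The numpy_random, uint_dtypes, and integral_dtypes stand-ins here are simplified assumptions, not the actual helpers from the Triton test suite.

import numpy as np

# Assumed, simplified stand-ins for the test suite's dtype lists.
uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
int_dtypes = ['int8', 'int16', 'int32', 'int64']
integral_dtypes = uint_dtypes + int_dtypes


def numpy_random(shape, dtype_str, low=None, high=None):
    # Simplified stand-in for the test suite's numpy_random helper:
    # integers are drawn from [low, high), floats from a standard normal.
    rs = np.random.RandomState(17)
    if dtype_str in integral_dtypes:
        low = 0 if low is None else low
        high = 100 if high is None else high
        return rs.randint(low, high, size=shape, dtype=dtype_str)
    return rs.standard_normal(shape).astype(dtype_str)


def get_reduce_input(dtype_str, shape):
    # limit the range of integers so that reduce ops do not overflow
    low = 0 if dtype_str in uint_dtypes else -10 if dtype_str in integral_dtypes else None
    high = 10 if dtype_str in integral_dtypes else None
    return numpy_random(shape, dtype_str=dtype_str, low=low, high=high)


x = get_reduce_input('int8', (128, ))
assert x.min() >= -10 and x.max() < 10          # values stay in the bounded range
print(x.astype(np.int32).sum())                  # reference sum stays far from int32 limits

Centralizing the bounds in one helper also means every reduce test (1D, N-D, and the layout tests) generates its input the same way instead of repeating the RandomState/range boilerplate.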