@@ -2485,6 +2485,13 @@ def get_reduced_dtype(dtype_str, op):
2485
2485
return dtype_str
2486
2486
2487
2487
2488
def get_reduce_input(dtype_str, shape):
    """Build a random numpy input for reduction tests.

    Integer dtypes are confined to a small range so that reduce ops
    (e.g. sum) cannot overflow; floating-point dtypes are unbounded.
    NOTE(review): relies on module-level `uint_dtypes` / `integral_dtypes`
    dtype-name collections and the `numpy_random` helper defined elsewhere
    in this file.
    """
    # Default: no clamping (non-integral dtypes).
    low = None
    high = None
    if dtype_str in uint_dtypes:
        # Unsigned integers must stay non-negative.
        low = 0
    elif dtype_str in integral_dtypes:
        low = -10
    if dtype_str in integral_dtypes:
        high = 10
    return numpy_random(shape, dtype_str=dtype_str, low=low, high=high)
2488
2495
@pytest .mark .interpreter
2489
2496
@pytest .mark .parametrize ("op, dtype_str, shape" , [(op , dtype , shape ) for op in [
2490
2497
'min' ,
@@ -2515,14 +2522,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
2515
2522
patch = f'z = tl.{ op } (x, axis=0)'
2516
2523
kernel = patch_kernel (kernel , {'GENERATE_TEST_HERE' : patch })
2517
2524
# input
2518
- rs = RandomState (17 )
2519
- # limit the range of integers so that the sum does not overflow
2520
- if dtype_str in integral_dtypes :
2521
- low = 0 if dtype_str in uint_dtypes else - 100
2522
- high = 100
2523
- x = numpy_random ((shape , ), dtype_str = dtype_str , rs = rs , low = low , high = high )
2524
- else :
2525
- x = numpy_random ((shape , ), dtype_str = dtype_str , rs = rs )
2525
+ x = get_reduce_input (dtype_str , (shape , ))
2526
2526
numpy_op = {
2527
2527
'sum' : np .sum ,
2528
2528
'max' : np .max ,
@@ -2547,7 +2547,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
2547
2547
else :
2548
2548
z_ref = numpy_op (x ).astype (getattr (np , z_dtype_str ))
2549
2549
# triton result
2550
- z_tri = to_triton (numpy_random ((1 , ), dtype_str = z_dtype_str , rs = rs ), device = device , dst_type = z_tri_dtype_str )
2550
+ z_tri = to_triton (numpy_random ((1 , ), dtype_str = z_dtype_str ), device = device , dst_type = z_tri_dtype_str )
2551
2551
kernel [(1 , )](x_tri , z_tri , BLOCK = shape , num_ctas = num_ctas )
2552
2552
z_tri = to_numpy (z_tri )
2553
2553
# compare
@@ -2644,9 +2644,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
2644
2644
2645
2645
kernel = patch_kernel (kernel , {'GENERATE_TEST_HERE' : f'tl.{ op } (x, axis=AXIS, keep_dims=KEEP_DIMS)' })
2646
2646
# input
2647
- rs = RandomState (17 )
2648
- # limit the range of integers so that the sum does not overflow
2649
- x = numpy_random (shape , dtype_str = dtype_str , rs = rs )
2647
+ x = get_reduce_input (dtype_str , shape )
2650
2648
x_tri = to_triton (x , device = device )
2651
2649
numpy_op = {
2652
2650
'sum' : np .sum , 'max' : np .max , 'min' : np .min , 'argmin' : np .argmin , 'argmax' : np .argmax , 'xor_sum' :
@@ -2671,7 +2669,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const
2671
2669
2672
2670
# triton result
2673
2671
z_shape = z_ref .shape
2674
- z_tri = to_triton (numpy_random (z_shape , dtype_str = z_dtype_str , rs = rs ), device = device , dst_type = z_tri_dtype_str )
2672
+ z_tri = to_triton (numpy_random (z_shape , dtype_str = z_dtype_str ), device = device , dst_type = z_tri_dtype_str )
2675
2673
BLOCK_K = 1 if len (shape ) == 2 else shape [2 ]
2676
2674
IS_3D = bool (len (shape ) == 3 )
2677
2675
USE_I1 = dtype_str == 'bool'
@@ -3319,8 +3317,7 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
3319
3317
temp_file .write_text (ir )
3320
3318
kernel = triton .compile (str (temp_file ))
3321
3319
3322
- rs = RandomState (17 )
3323
- x = numpy_random ((M , N ), dtype_str = dtype_str , rs = rs , low = 0 , high = 10 )
3320
+ x = get_reduce_input (dtype_str , (M , N ))
3324
3321
reduce2d = 'reduce2d' in epilogue_kind
3325
3322
z_shape = (1 , 1 ) if reduce2d else (1 , N ) if axis == 0 else (M , 1 )
3326
3323
z = np .zeros (z_shape ).astype (dtype_str )
0 commit comments