Skip to content

Commit b1301d6

Browse files
authored
[TEST] Remove dead code branches from test_reduce1d (#5701)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 8427f69 commit b1301d6

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

python/test/unit/language/test_core.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2240,18 +2240,16 @@ def kernel(X, Z, BLOCK: tl.constexpr):
22402240
'min': np.min,
22412241
'max-with-indices': np.max,
22422242
'min-with-indices': np.min,
2243-
'argmin-tie-break-fast': np.argmin,
22442243
'argmin-tie-break-left': np.argmin,
2245-
'argmax-tie-break-fast': np.argmax,
22462244
'argmax-tie-break-left': np.argmax,
22472245
}[op]
22482246
if 'tie-break-left' in op:
22492247
x[3:10] = x[numpy_op(x)]
22502248
x_tri = to_triton(x, device=device)
22512249
# numpy result
2252-
z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str
2250+
z_dtype_str = 'int32' if 'tie-break-left' in op else dtype_str
22532251
z_tri_dtype_str = z_dtype_str
2254-
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
2252+
if 'tie-break-left' not in op and dtype_str == 'bfloat16':
22552253
z_dtype_str = 'float32'
22562254
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
22572255
# trunc mantissa for a fair comparison of accuracy
@@ -2267,7 +2265,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
22672265
if op == 'sum':
22682266
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
22692267
else:
2270-
if op in ('argmin', 'argmax'):
2268+
if 'tie-break-left' in op:
22712269
# argmin and argmax can have multiple valid indices.
22722270
# so instead we compare the values pointed by indices
22732271
np.testing.assert_equal(x[z_ref], x[z_tri])

0 commit comments

Comments
 (0)