Skip to content

Commit ec278de

Browse files
authored
Remove dead code branches from test_reduce1d (#3265)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 8c8a722 commit ec278de

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

python/test/unit/language/test_core.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2285,18 +2285,16 @@ def kernel(X, Z, BLOCK: tl.constexpr):
22852285
'min': np.min,
22862286
'max-with-indices': np.max,
22872287
'min-with-indices': np.min,
2288-
'argmin-tie-break-fast': np.argmin,
22892288
'argmin-tie-break-left': np.argmin,
2290-
'argmax-tie-break-fast': np.argmax,
22912289
'argmax-tie-break-left': np.argmax,
22922290
}[op]
22932291
if 'tie-break-left' in op:
22942292
x[3:10] = x[numpy_op(x)]
22952293
x_tri = to_triton(x, device=device)
22962294
# numpy result
2297-
z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str
2295+
z_dtype_str = 'int32' if 'tie-break-left' in op else dtype_str
22982296
z_tri_dtype_str = z_dtype_str
2299-
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
2297+
if 'tie-break-left' not in op and dtype_str == 'bfloat16':
23002298
z_dtype_str = 'float32'
23012299
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
23022300
# trunc mantissa for a fair comparison of accuracy
@@ -2316,7 +2314,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
23162314
if op == 'sum':
23172315
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
23182316
else:
2319-
if op in ('argmin', 'argmax'):
2317+
if 'tie-break-left' in op:
23202318
# argmin and argmax can have multiple valid indices.
23212319
# so instead we compare the values pointed by indices
23222320
np.testing.assert_equal(x[z_ref], x[z_tri])

0 commit comments

Comments
 (0)