@@ -2285,18 +2285,16 @@ def kernel(X, Z, BLOCK: tl.constexpr):
22852285 'min' : np .min ,
22862286 'max-with-indices' : np .max ,
22872287 'min-with-indices' : np .min ,
2288- 'argmin-tie-break-fast' : np .argmin ,
22892288 'argmin-tie-break-left' : np .argmin ,
2290- 'argmax-tie-break-fast' : np .argmax ,
22912289 'argmax-tie-break-left' : np .argmax ,
22922290 }[op ]
22932291 if 'tie-break-left' in op :
22942292 x [3 :10 ] = x [numpy_op (x )]
22952293 x_tri = to_triton (x , device = device )
22962294 # numpy result
2297- z_dtype_str = 'int32' if op in ( 'argmin' , 'argmax' ) else dtype_str
2295+ z_dtype_str = 'int32' if 'tie-break-left' in op else dtype_str
22982296 z_tri_dtype_str = z_dtype_str
2299- if op not in [ 'argmin' , 'argmax' ] and dtype_str == 'bfloat16' :
2297+ if 'tie-break-left' not in op and dtype_str == 'bfloat16' :
23002298 z_dtype_str = 'float32'
23012299 z_ref = numpy_op (x ).astype (getattr (np , z_dtype_str ))
23022300 # trunc mantissa for a fair comparison of accuracy
@@ -2316,7 +2314,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
23162314 if op == 'sum' :
23172315 np .testing .assert_allclose (z_ref , z_tri , rtol = 0.01 )
23182316 else :
2319- if op in ( 'argmin' , 'argmax' ) :
2317+ if 'tie-break-left' in op :
23202318 # argmin and argmax can have multiple valid indices.
23212319 # so instead we compare the values pointed by indices
23222320 np .testing .assert_equal (x [z_ref ], x [z_tri ])
0 commit comments