@@ -2240,18 +2240,16 @@ def kernel(X, Z, BLOCK: tl.constexpr):
22402240 'min' : np .min ,
22412241 'max-with-indices' : np .max ,
22422242 'min-with-indices' : np .min ,
2243- 'argmin-tie-break-fast' : np .argmin ,
22442243 'argmin-tie-break-left' : np .argmin ,
2245- 'argmax-tie-break-fast' : np .argmax ,
22462244 'argmax-tie-break-left' : np .argmax ,
22472245 }[op ]
22482246 if 'tie-break-left' in op :
22492247 x [3 :10 ] = x [numpy_op (x )]
22502248 x_tri = to_triton (x , device = device )
22512249 # numpy result
2252- z_dtype_str = 'int32' if op in ( 'argmin' , 'argmax' ) else dtype_str
2250+ z_dtype_str = 'int32' if 'tie-break-left' in op else dtype_str
22532251 z_tri_dtype_str = z_dtype_str
2254- if op not in [ 'argmin' , 'argmax' ] and dtype_str == 'bfloat16' :
2252+ if 'tie-break-left' not in op and dtype_str == 'bfloat16' :
22552253 z_dtype_str = 'float32'
22562254 z_ref = numpy_op (x ).astype (getattr (np , z_dtype_str ))
22572255 # trunc mantissa for a fair comparison of accuracy
@@ -2267,7 +2265,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
22672265 if op == 'sum' :
22682266 np .testing .assert_allclose (z_ref , z_tri , rtol = 0.01 )
22692267 else :
2270- if op in ( 'argmin' , 'argmax' ) :
2268+ if 'tie-break-left' in op :
22712269 # argmin and argmax can have multiple valid indices.
22722270 # so instead we compare the values pointed by indices
22732271 np .testing .assert_equal (x [z_ref ], x [z_tri ])
0 commit comments