@@ -879,7 +879,10 @@ def _compare(self, other, *, f_pred, si_pred, ui_pred):
879879
880880 def max (self , other ):
881881 if ir .FloatType .isinstance (self .mlir_dtype ):
882- return self ._pointwise (arith .maximumf , other )
882+ maximumf = arith .maximumf
883+ if ir .F32Type .isinstance (self .mlir_dtype ):
884+ maximumf = self ._lift_fast_instr ("max.NaN.f32" )
885+ return self ._pointwise (maximumf , other )
883886 elif ir .IntegerType .isinstance (self .mlir_dtype ):
884887 return self ._pointwise (
885888 arith .maxsi if self .is_signed else arith .maxui , other
@@ -907,8 +910,8 @@ def exp(self, *, approx: bool = False):
907910 log2e = arith .constant (f32 , ir .FloatAttr .get (f32 , 1.4426950408889634 ))
908911 def fast_exp (x ):
909912 scaled = arith .mulf (x , log2e )
910- return llvm .inline_asm (f32 , [scaled ], "ex2.approx.f32 $0, $1;" , "=f,f" )
911- return self ._pointwise (self ._lift_fast_unary (fast_exp ))
913+ return llvm .inline_asm (f32 , [scaled ], "ex2.approx.ftz. f32 $0, $1;" , "=f,f" )
914+ return self ._pointwise (self ._lift_fast_instr (fast_exp ))
912915 return self ._pointwise (mlir_math .exp )
913916
914917 def sin (self , * , approx : bool = False ):
@@ -917,7 +920,7 @@ def sin(self, *, approx: bool = False):
917920 if approx and self .mlir_dtype != ir .F32Type .get ():
918921 raise NotImplementedError
919922 return self ._pointwise (
920- self ._lift_fast_unary ("sin.approx.f32" ) if approx else mlir_math .sin
923+ self ._lift_fast_instr ("sin.approx.f32" ) if approx else mlir_math .sin
921924 )
922925
923926 def cos (self , * , approx : bool = False ):
@@ -926,7 +929,7 @@ def cos(self, *, approx: bool = False):
926929 if approx and self .mlir_dtype != ir .F32Type .get ():
927930 raise NotImplementedError
928931 return self ._pointwise (
929- self ._lift_fast_unary ("cos.approx.f32" ) if approx else mlir_math .cos
932+ self ._lift_fast_instr ("cos.approx.f32" ) if approx else mlir_math .cos
930933 )
931934
932935 def tanh (self , * , approx : bool = False ):
@@ -935,7 +938,7 @@ def tanh(self, *, approx: bool = False):
935938 if approx and self .mlir_dtype != ir .F32Type .get ():
936939 raise NotImplementedError
937940 return self ._pointwise (
938- self ._lift_fast_unary ("tanh.approx.f32" ) if approx else mlir_math .tanh
941+ self ._lift_fast_instr ("tanh.approx.f32" ) if approx else mlir_math .tanh
939942 )
940943
941944 def rsqrt (self , * , approx : bool = False ):
@@ -944,31 +947,36 @@ def rsqrt(self, *, approx: bool = False):
944947 if approx and self .mlir_dtype != ir .F32Type .get ():
945948 raise NotImplementedError
946949 return self ._pointwise (
947- self ._lift_fast_unary ("rsqrt.approx.f32" ) if approx else mlir_math .rsqrt
950+ self ._lift_fast_instr ("rsqrt.approx.f32" ) if approx else mlir_math .rsqrt
948951 )
949952
950953 @staticmethod
951- def _lift_fast_unary (
954+ def _lift_fast_instr (
952955 instr : str | Callable [[ir .Value ], ir .Value ],
953956 ) -> Callable [[ir .Value ], ir .Value ]:
954- def fast_instr (x ):
957+ def fast_instr (* args ):
955958 f32 = ir .F32Type .get ()
956- if x .type == f32 :
959+ arg_ty = args [0 ].type
960+ assert all (a .type == arg_ty for a in args )
961+ if arg_ty == f32 :
957962 if isinstance (instr , str ):
958- return llvm .inline_asm (f32 , [x ], instr + " $0, $1;" , "=f,f" )
963+ args_ptx = ", " .join (f"${ i } " for i in range (len (args ) + 1 ))
964+ return llvm .inline_asm (
965+ f32 , args , f"{ instr } { args_ptx } ;" , "=f" + ",f" * len (args )
966+ )
959967 else :
960- return instr (x )
961- elif ir .VectorType .isinstance (x . type ):
968+ return instr (* args )
969+ elif ir .VectorType .isinstance (arg_ty ):
962970 index = ir .IndexType .get ()
963- result = llvm .mlir_undef (x . type )
964- [vec_len ] = ir .VectorType (x . type ).shape
971+ result = llvm .mlir_undef (arg_ty )
972+ [vec_len ] = ir .VectorType (arg_ty ).shape
965973 for i in range (vec_len ):
966- v = vector .extractelement (x , position = c (i , index ))
967- vr = fast_instr (v )
974+ vs = [ vector .extractelement (a , position = c (i , index )) for a in args ]
975+ vr = fast_instr (* vs )
968976 result = vector .insertelement (vr , result , position = c (i , index ))
969977 return result
970978 else :
971- raise NotImplementedError (x . type )
979+ raise NotImplementedError (arg_ty )
972980 return fast_instr
973981
974982 def bitcast (self , elt : ir .Type , * , output_is_signed : bool | None = None ):
@@ -1156,7 +1164,20 @@ def reduce_sum(self, scratch) -> ir.Value:
11561164 utils .warpgroup_barrier () # Make sure everyone is done using scratch.
11571165 return result
11581166
1159- def reduce (self , op , axis ):
1167+ def reduce (self , op : str | Callable [[ir .Value , ir .Value ], ir .Value ], axis ):
1168+ if isinstance (op , str ):
1169+ match op :
1170+ case "max" :
1171+ if ir .F32Type .isinstance (self .mlir_dtype ):
1172+ op = self ._lift_fast_instr ("max.NaN.f32" )
1173+ elif ir .FloatType .isinstance (self .mlir_dtype ):
1174+ op = arith .maximumf
1175+ elif ir .IntegerType .isinstance (self .mlir_dtype ):
1176+ op = arith .maxsi if self .is_signed else arith .maxui
1177+ else :
1178+ raise NotImplementedError (self .mlir_dtype )
1179+ case _:
1180+ raise ValueError (f"Unrecognized reduction operator: { op } " )
11601181 if self .layout != WGMMA_LAYOUT :
11611182 raise NotImplementedError (self .layout )
11621183 if axis != 1 :
@@ -1421,7 +1442,7 @@ def load_tiled(
14211442 tiled_shape = ref_ty .shape
14221443 if len (tiled_shape ) % 2 :
14231444 raise ValueError ("Tiled reference must have even rank" )
1424- tiling = Tiling ((tiled_shape [len (tiled_shape ) // 2 :],))
1445+ tiling = Tiling ((tiled_shape [len (tiled_shape ) // 2 :],))
14251446 shape = tiling .untile_shape (tiled_shape )
14261447 registers = np .full (layout .registers_shape (shape ), None , dtype = object )
14271448 reg_ty = ir .VectorType .get ((layout .vector_length ,), ref_ty .element_type )
0 commit comments