@@ -700,7 +700,7 @@ def __neg__(self):
700700
701701 def __add__ (self , other ):
702702 if ir .FloatType .isinstance (self .mlir_dtype ):
703- return self ._pointwise (arith . addf , other )
703+ return self ._pointwise (addf , other )
704704 elif ir .IntegerType .isinstance (self .mlir_dtype ):
705705 return self ._pointwise (arith .addi , other )
706706 else :
@@ -711,7 +711,7 @@ def __radd__(self, other):
711711
712712 def __mul__ (self , other ):
713713 if ir .FloatType .isinstance (self .mlir_dtype ):
714- return self ._pointwise (arith . mulf , other )
714+ return self ._pointwise (mulf , other )
715715 elif ir .IntegerType .isinstance (self .mlir_dtype ):
716716 return self ._pointwise (arith .muli , other )
717717 else :
@@ -722,15 +722,15 @@ def __rmul__(self, other):
722722
723723 def __sub__ (self , other ):
724724 if ir .FloatType .isinstance (self .mlir_dtype ):
725- return self ._pointwise (arith . subf , other )
725+ return self ._pointwise (subf , other )
726726 elif ir .IntegerType .isinstance (self .mlir_dtype ):
727727 return self ._pointwise (arith .subi , other )
728728 else :
729729 return NotImplemented
730730
731731 def __rsub__ (self , other ):
732732 if ir .FloatType .isinstance (self .mlir_dtype ):
733- return self ._pointwise (lambda s , o : arith . subf (o , s ), other )
733+ return self ._pointwise (lambda s , o : subf (o , s ), other )
734734 elif ir .IntegerType .isinstance (self .mlir_dtype ):
735735 return self ._pointwise (lambda s , o : arith .subi (o , s ), other )
736736 else :
@@ -904,16 +904,20 @@ def exp(self, *, approx: bool = False):
904904 if not ir .FloatType .isinstance (self .mlir_dtype ):
905905 raise NotImplementedError
906906 if approx :
907- f32 = ir .F32Type .get ()
908- if self .mlir_dtype != f32 :
909- raise NotImplementedError
910- log2e = arith .constant (f32 , ir .FloatAttr .get (f32 , 1.4426950408889634 ))
911- def fast_exp (x ):
912- scaled = arith .mulf (x , log2e )
913- return llvm .inline_asm (f32 , [scaled ], "ex2.approx.ftz.f32 $0, $1;" , "=f,f" )
914- return self ._pointwise (self ._lift_fast_instr (fast_exp ))
907+ dtype = self .mlir_dtype
908+ log2e = arith .constant (dtype , ir .FloatAttr .get (dtype , 1.4426950408889634 ))
909+ return (self * log2e ).exp2 ()
915910 return self ._pointwise (mlir_math .exp )
916911
912+ def exp2 (self , * , approx : bool = False ):
913+ if not ir .FloatType .isinstance (self .mlir_dtype ):
914+ raise NotImplementedError
915+ if approx :
916+ if not ir .F32Type .isinstance (self .mlir_dtype ):
917+ raise NotImplementedError (self .mlir_dtype )
918+ return self ._pointwise (self ._lift_fast_instr ("ex2.approx.ftz.f32" ))
919+ return self ._pointwise (mlir_math .exp2 )
920+
917921 def sin (self , * , approx : bool = False ):
918922 if not ir .FloatType .isinstance (self .mlir_dtype ):
919923 raise NotImplementedError
@@ -1125,7 +1129,7 @@ def upcast_to_bf16(reg, high):
11251129 # NOTE: scratch can be reused immediately once this function returns.
11261130 def reduce_sum (self , scratch ) -> ir .Value :
11271131 if ir .FloatType .isinstance (self .mlir_dtype ):
1128- op = arith . addf
1132+ op = addf
11291133 elif ir .IntegerType .isinstance (self .mlir_dtype ):
11301134 op = arith .addi
11311135 else :
@@ -1167,6 +1171,13 @@ def reduce_sum(self, scratch) -> ir.Value:
11671171 def reduce (self , op : str | Callable [[ir .Value , ir .Value ], ir .Value ], axis ):
11681172 if isinstance (op , str ):
11691173 match op :
1174+ case "add" :
1175+ if ir .FloatType .isinstance (self .mlir_dtype ):
1176+ op = addf
1177+ elif ir .IntegerType .isinstance (self .mlir_dtype ):
1178+ op = arith .addi
1179+ else :
1180+ raise NotImplementedError (self .mlir_dtype )
11701181 case "max" :
11711182 if ir .F32Type .isinstance (self .mlir_dtype ):
11721183 op = self ._lift_fast_instr ("max.NaN.f32" )
@@ -1653,3 +1664,15 @@ def tree_unflatten(cls, aux, flat_registers):
16531664 layout , reg_shape , is_signed = aux
16541665 registers = np .asarray (flat_registers , dtype = object ).reshape (reg_shape )
16551666 return cls (_registers = registers , _layout = layout , _is_signed = is_signed )
1667+
1668+
1669+ # We allow contractions, to potentially take advantage of FMA instructions.
1670+ # They can change the results, but the precision should only increase.
1671+ def addf (a : ir .Value , b : ir .Value ):
1672+ return arith .addf (a , b , fastmath = arith .FastMathFlags .contract )
1673+
1674+ def subf (a : ir .Value , b : ir .Value ):
1675+ return arith .subf (a , b , fastmath = arith .FastMathFlags .contract )
1676+
1677+ def mulf (a : ir .Value , b : ir .Value ):
1678+ return arith .mulf (a , b , fastmath = arith .FastMathFlags .contract )
0 commit comments