@@ -906,16 +906,25 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
     return self.lhs_precision_type

   @property
-  def accumulation_type(self) -> DTypeLike | None:
+  def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
     match self:
       case (
           DotAlgorithmPreset.DEFAULT |
           DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
           DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
       ):
         return None
+      case (
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+      ):
+        return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn,
+                dtypes.float8_e5m2, dtypes.float8_e5m2fnuz,
+                dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz)
       case DotAlgorithmPreset.F16_F16_F16:
         return np.float16
+      case DotAlgorithmPreset.F16_F16_F32:
+        return (np.float32, np.float16)
       case DotAlgorithmPreset.BF16_BF16_BF16:
         return dtypes.bfloat16
       case DotAlgorithmPreset.F64_F64_F64:
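(Commentary, not part of the diff: a minimal standalone sketch of how a tuple-valued `accumulation_type` is meant to be read. The first entry is the preferred accumulation dtype and the remaining entries are also acceptable, so a requested dtype is only kept when it appears in the tuple. `pick_accumulation_dtype` is an illustrative name, not a JAX function.)

```python
import numpy as np

def pick_accumulation_dtype(requested, allowed):
  # Keep the requested dtype if the preset allows it; otherwise fall back to
  # the preset's preferred (first) accumulation type.
  if allowed is None:  # the preset imposes no constraint
    return requested
  if not isinstance(allowed, tuple):
    allowed = (allowed,)
  if any(requested == d for d in allowed):
    return requested
  return allowed[0]

# F16_F16_F32 now advertises (np.float32, np.float16): an f16 request is kept,
# anything else falls back to f32.
assert pick_accumulation_dtype(np.float16, (np.float32, np.float16)) == np.float16
assert pick_accumulation_dtype(np.float64, (np.float32, np.float16)) == np.float32
```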
@@ -3619,6 +3628,37 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
   return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype)


+def get_algorithm_compute_types(
+    algorithm: DotAlgorithm | DotAlgorithmPreset,
+    lhs_dtype: DTypeLike,
+    rhs_dtype: DTypeLike,
+    out_dtype: DTypeLike | None = None,
+) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]:
+  def maybe_convert_dtype(input_dtype, target_dtype):
+    if target_dtype is None:
+      return input_dtype
+    if not isinstance(target_dtype, tuple):
+      target_dtype = (target_dtype,)
+    if any(input_dtype == d for d in target_dtype):
+      return input_dtype
+    return target_dtype[0]
+  if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
+    lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
+    rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
+    if lhs_dtype == dtypes.bfloat16:
+      out_dtype = maybe_convert_dtype(out_dtype,
+                                      (np.float32, dtypes.bfloat16))
+    else:
+      out_dtype = maybe_convert_dtype(out_dtype, np.float32)
+    return lhs_dtype, rhs_dtype, out_dtype
+  else:
+    return (
+        maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type),
+        maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type),
+        maybe_convert_dtype(out_dtype, algorithm.accumulation_type),
+    )
+
+
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
                        out_type, platform: str = "default"):
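(Commentary, not part of the diff: a hedged usage sketch of the new `get_algorithm_compute_types` helper. The import path assumes the helper lives in `jax._src.lax.lax`, which may differ across versions; `DotAlgorithmPreset` is defined in the same module and also exported as `jax.lax.DotAlgorithmPreset`.)

```python
import numpy as np
from jax._src import dtypes
# Assumed location of the new helper; adjust to wherever it actually lands.
from jax._src.lax.lax import DotAlgorithmPreset, get_algorithm_compute_types

# F16_F16_F32 keeps f16 operands, and an f16 output request is honored because
# f16 is now in the preset's accumulation-type tuple...
lhs, rhs, acc = get_algorithm_compute_types(
    DotAlgorithmPreset.F16_F16_F32, np.float16, np.float16, np.float16)
assert (lhs, rhs, acc) == (np.float16, np.float16, np.float16)

# ...while an output dtype outside the tuple falls back to the preferred
# (first) accumulation type, f32.
_, _, acc = get_algorithm_compute_types(
    DotAlgorithmPreset.F16_F16_F32, np.float16, np.float16, dtypes.bfloat16)
assert acc == np.float32
```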
@@ -3656,20 +3696,17 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     # If an explicit algorithm was specified, we always cast the input types to
     # the correct types.
     def maybe_convert_dtype(operand, operand_aval, target_dtype):
-      if target_dtype is None:
-        return operand, operand_aval.dtype
-      if not isinstance(target_dtype, tuple):
-        target_dtype = (target_dtype,)
-      if any(operand_aval.dtype == d for d in target_dtype):
-        return operand, operand_aval.dtype
-      aval = core.ShapedArray(operand_aval.shape, target_dtype[0])
-      return mlir.convert_hlo(ctx, operand, operand_aval, aval), target_dtype[0]
-
-    lhs, lhs_dtype = maybe_convert_dtype(lhs, lhs_aval, precision.lhs_precision_type)
-    rhs, rhs_dtype = maybe_convert_dtype(rhs, rhs_aval, precision.rhs_precision_type)
-    accumulation_type = precision.accumulation_type
-    if accumulation_type is not None:
-      accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_type)
+      if target_dtype is None or operand_aval.dtype == target_dtype:
+        return operand
+      aval = core.ShapedArray(operand_aval.shape, target_dtype)
+      return mlir.convert_hlo(ctx, operand, operand_aval, aval)
+
+    lhs_dtype, rhs_dtype, accumulation_dtype = get_algorithm_compute_types(
+        precision, lhs_dtype, rhs_dtype, aval_out.dtype)
+    lhs = maybe_convert_dtype(lhs, lhs_aval, lhs_dtype)
+    rhs = maybe_convert_dtype(rhs, rhs_aval, rhs_dtype)
+    if accumulation_dtype is not None:
+      accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_dtype)

     if precision != DotAlgorithmPreset.DEFAULT:
       algorithm_kwarg = {
@@ -3690,15 +3727,13 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
                                core.ShapedArray(lhs_aval.shape, aval_out.dtype))
         rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
                                core.ShapedArray(rhs_aval.shape, aval_out.dtype))
-        lhs_dtype = rhs_dtype = aval_out.dtype
     else:  # cpu and gpu
       # Do not convert mixed fp8 types to output type.
       if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype):
         lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
                                core.ShapedArray(lhs_aval.shape, aval_out.dtype))
         rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
                                core.ShapedArray(rhs_aval.shape, aval_out.dtype))
-        lhs_dtype = rhs_dtype = aval_out.dtype

   result = hlo.dot_general(
       mlir.aval_to_ir_type(accumulation_aval),
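(Commentary, not part of the diff: an end-to-end sketch of the code path this lowering change affects. Passing a `DotAlgorithmPreset` via `precision=` should emit a `dot_general` carrying a `#stablehlo.dot_algorithm` attribute; the exact StableHLO text is illustrative, and backend support for a given algorithm is still validated later by XLA at compile time.)

```python
import jax
import jax.numpy as jnp

def f(x, y):
  # F16_F16_F32: f16 operands with an f32-capable accumulator; the operand and
  # accumulation dtypes in the emitted StableHLO are the ones resolved by
  # get_algorithm_compute_types in the change above.
  return jax.lax.dot(x, y, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32)

x = jnp.ones((8, 16), dtype=jnp.float16)
y = jnp.ones((16, 4), dtype=jnp.float16)

# Inspect the lowered module; look for the dot_general op and its
# algorithm = #stablehlo.dot_algorithm<...> attribute.
print(jax.jit(f).lower(x, y).as_text())
```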