@@ -879,11 +879,11 @@ def __str__(self) -> str:
879879 def lhs_precision_type (self ) -> DTypeLike | tuple [DTypeLike , ...] | None :
880880 match self :
881881 case (
882- DotAlgorithmPreset .DEFAULT |
883- DotAlgorithmPreset .ANY_F8_ANY_F8_F32 |
884- DotAlgorithmPreset .ANY_F8_ANY_F8_F32_FAST_ACCUM |
885- DotAlgorithmPreset .ANY_F8_ANY_F8_ANY |
886- DotAlgorithmPreset .ANY_F8_ANY_F8_ANY_FAST_ACCUM
882+ DotAlgorithmPreset .DEFAULT
883+ | DotAlgorithmPreset .ANY_F8_ANY_F8_F32
884+ | DotAlgorithmPreset .ANY_F8_ANY_F8_F32_FAST_ACCUM
885+ | DotAlgorithmPreset .ANY_F8_ANY_F8_ANY
886+ | DotAlgorithmPreset .ANY_F8_ANY_F8_ANY_FAST_ACCUM
887887 ):
888888 return None
889889 case DotAlgorithmPreset .F16_F16_F16 | DotAlgorithmPreset .F16_F16_F32 :
@@ -906,31 +906,38 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
906906 return self .lhs_precision_type
907907
908908 @property
909- def accumulation_type (self ) -> DTypeLike | tuple [ DTypeLike , ...] | None :
909+ def accumulation_type (self ) -> DTypeLike | None :
910910 match self :
911911 case (
912- DotAlgorithmPreset .DEFAULT |
913- DotAlgorithmPreset .ANY_F8_ANY_F8_ANY |
914- DotAlgorithmPreset .ANY_F8_ANY_F8_ANY_FAST_ACCUM
912+ DotAlgorithmPreset .DEFAULT
913+ | DotAlgorithmPreset .ANY_F8_ANY_F8_ANY
914+ | DotAlgorithmPreset .ANY_F8_ANY_F8_ANY_FAST_ACCUM
915915 ):
916916 return None
917+ case DotAlgorithmPreset .F16_F16_F16 :
918+ return np .float16
919+ case DotAlgorithmPreset .BF16_BF16_BF16 :
920+ return dtypes .bfloat16
921+ case DotAlgorithmPreset .F64_F64_F64 :
922+ return np .float64
923+ case _:
924+ return np .float32
925+
926+ @property
927+ def supported_output_types (self ) -> tuple [DTypeLike , ...] | None :
928+ match self :
917929 case (
918930 DotAlgorithmPreset .ANY_F8_ANY_F8_F32 |
919931 DotAlgorithmPreset .ANY_F8_ANY_F8_F32_FAST_ACCUM
920932 ):
921933 return (np .float32 , np .float16 , dtypes .bfloat16 , dtypes .float8_e4m3fn ,
922934 dtypes .float8_e5m2 , dtypes .float8_e5m2fnuz ,
923935 dtypes .float8_e4m3fnuz , dtypes .float8_e4m3b11fnuz )
924- case DotAlgorithmPreset .F16_F16_F16 :
925- return np .float16
926936 case DotAlgorithmPreset .F16_F16_F32 :
927937 return (np .float32 , np .float16 )
928- case DotAlgorithmPreset .BF16_BF16_BF16 :
929- return dtypes .bfloat16
930- case DotAlgorithmPreset .F64_F64_F64 :
931- return np .float64
932938 case _:
933- return np .float32
939+ accumulation_type = self .accumulation_type
940+ return None if accumulation_type is None else (accumulation_type ,)
934941
935942 def _convert_to_hlo_attr (self , lhs_dtype : DTypeLike ,
936943 rhs_dtype : DTypeLike ) -> hlo .DotAlgorithm | None :
@@ -941,30 +948,39 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
941948 tf32 = ir .FloatTF32Type .get ()
942949 match self :
943950 case (
944- DotAlgorithmPreset .ANY_F8_ANY_F8_F32 |
945- DotAlgorithmPreset .ANY_F8_ANY_F8_F32_FAST_ACCUM |
946- DotAlgorithmPreset .ANY_F8_ANY_F8_ANY |
947- DotAlgorithmPreset .ANY_F8_ANY_F8_ANY_FAST_ACCUM
951+ DotAlgorithmPreset .ANY_F8_ANY_F8_F32
952+ | DotAlgorithmPreset .ANY_F8_ANY_F8_F32_FAST_ACCUM
953+ | DotAlgorithmPreset .ANY_F8_ANY_F8_ANY
954+ | DotAlgorithmPreset .ANY_F8_ANY_F8_ANY_FAST_ACCUM
948955 ):
949- fp8_dtypes = [np .dtype (dtypes .float8_e4m3b11fnuz ),
950- np .dtype (dtypes .float8_e4m3fn ),
951- np .dtype (dtypes .float8_e4m3fnuz ),
952- np .dtype (dtypes .float8_e5m2 ),
953- np .dtype (dtypes .float8_e5m2fnuz )]
956+ fp8_dtypes = [
957+ np .dtype (dtypes .float8_e4m3b11fnuz ),
958+ np .dtype (dtypes .float8_e4m3fn ),
959+ np .dtype (dtypes .float8_e4m3fnuz ),
960+ np .dtype (dtypes .float8_e5m2 ),
961+ np .dtype (dtypes .float8_e5m2fnuz ),
962+ ]
954963 if dtypes .float8_e3m4 is not None :
955964 fp8_dtypes += [np .dtype (dtypes .float8_e3m4 )]
956965 if dtypes .float8_e4m3 is not None :
957966 fp8_dtypes += [np .dtype (dtypes .float8_e4m3 )]
958967 if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes :
959968 raise ValueError (
960969 f"The dot algorithm '{ self } ' requires both inputs to have float8 "
961- f"dtypes. Got { lhs_dtype } and { rhs_dtype } instead." )
970+ f'dtypes. Got { lhs_dtype } and { rhs_dtype } instead.'
971+ )
962972 lhs = mlir .dtype_to_ir_type (dtypes .dtype (lhs_dtype ))
963973 rhs = mlir .dtype_to_ir_type (dtypes .dtype (rhs_dtype ))
964974 acc = ir .F32Type .get ()
965975 return hlo .DotAlgorithm .get (
966- lhs , rhs , acc , 1 , 1 , 1 ,
967- self == DotAlgorithmPreset .ANY_F8_ANY_F8_F32_FAST_ACCUM )
976+ lhs ,
977+ rhs ,
978+ acc ,
979+ 1 ,
980+ 1 ,
981+ 1 ,
982+ self == DotAlgorithmPreset .ANY_F8_ANY_F8_F32_FAST_ACCUM ,
983+ )
968984 case DotAlgorithmPreset .F16_F16_F16 :
969985 return hlo .DotAlgorithm .get (f16 , f16 , f16 , 1 , 1 , 1 , False )
970986 case DotAlgorithmPreset .F16_F16_F32 :
@@ -3649,9 +3665,8 @@ def maybe_convert_dtype(input_dtype, target_dtype):
36493665 return input_dtype
36503666 if not isinstance (target_dtype , tuple ):
36513667 target_dtype = (target_dtype ,)
3652- if any (input_dtype == d for d in target_dtype ):
3653- return input_dtype
3654- return target_dtype [0 ]
3668+ return input_dtype if input_dtype in target_dtype else target_dtype [0 ]
3669+
36553670 if algorithm == DotAlgorithmPreset .BF16_BF16_F32 :
36563671 lhs_dtype = maybe_convert_dtype (lhs_dtype , algorithm .lhs_precision_type )
36573672 rhs_dtype = maybe_convert_dtype (rhs_dtype , algorithm .rhs_precision_type )
@@ -3662,10 +3677,15 @@ def maybe_convert_dtype(input_dtype, target_dtype):
36623677 out_dtype = maybe_convert_dtype (out_dtype , np .float32 )
36633678 return lhs_dtype , rhs_dtype , out_dtype
36643679 else :
3680+ if isinstance (algorithm , DotAlgorithmPreset ):
3681+ supported_output_types = algorithm .supported_output_types
3682+ else :
3683+ supported_output_types = (algorithm .accumulation_type ,)
3684+
36653685 return (
36663686 maybe_convert_dtype (lhs_dtype , algorithm .lhs_precision_type ),
36673687 maybe_convert_dtype (rhs_dtype , algorithm .rhs_precision_type ),
3668- maybe_convert_dtype (out_dtype , algorithm . accumulation_type ),
3688+ maybe_convert_dtype (out_dtype , supported_output_types ),
36693689 )
36703690
36713691
0 commit comments