@@ -923,8 +923,9 @@ def accumulation_type(self) -> DTypeLike | None:
923923 case _:
924924 return np .float32
925925
926- @property
927- def supported_output_types (self ) -> tuple [DTypeLike , ...] | None :
926+ def supported_output_types (
927+ self , lhs_dtype : DTypeLike , rhs_dtype : DTypeLike
928+ ) -> tuple [DTypeLike , ...] | None :
928929 match self :
929930 case (
930931 DotAlgorithmPreset .ANY_F8_ANY_F8_F32
@@ -941,7 +942,17 @@ def supported_output_types(self) -> tuple[DTypeLike, ...] | None:
941942 dtypes .float8_e4m3b11fnuz ,
942943 )
943944 case DotAlgorithmPreset .F16_F16_F32 :
944- return (np .float32 , np .float16 )
945+ # F16 output is only supported with F16 inputs.
946+ if dtypes .promote_types (lhs_dtype , rhs_dtype ) == np .float16 :
947+ return (np .float32 , np .float16 )
948+ else :
949+ return (np .float32 ,)
950+ case DotAlgorithmPreset .BF16_BF16_F32 :
951+ # BF16 output is only supported with BF16 inputs.
952+ if dtypes .promote_types (lhs_dtype , rhs_dtype ) == dtypes .bfloat16 :
953+ return (np .float32 , dtypes .bfloat16 )
954+ else :
955+ return (np .float32 ,)
945956 case _:
946957 accumulation_type = self .accumulation_type
947958 return None if accumulation_type is None else (accumulation_type ,)
@@ -3713,25 +3724,19 @@ def get_algorithm_compute_types(
37133724 algorithm .accumulation_type ,
37143725 )
37153726
3716- supported_output_types = algorithm .supported_output_types
3717-
3718- if algorithm == DotAlgorithmPreset .BF16_BF16_F32 :
3719- # If dtype is anything other than float32, it will be cast to bfloat16.
3720- if np .dtype (lhs_dtype ) != np .float32 :
3721- supported_output_types = (np .float32 , dtypes .bfloat16 )
3722-
def maybe_convert_dtype(input_dtype, target_dtypes):
  """Pick the dtype to use given an optional set of allowed dtypes.

  Args:
    input_dtype: The requested dtype; anything accepted by ``np.dtype``.
    target_dtypes: Sequence of allowed dtypes, or ``None`` for "no
      restriction". The first entry is the fallback.

  Returns:
    ``input_dtype`` itself (the original object, not normalized) when
    ``target_dtypes`` is ``None`` or when ``input_dtype`` equals one of the
    allowed dtypes; otherwise ``target_dtypes[0]``.
  """
  if target_dtypes is None:
    # No restriction was given, so the requested dtype is kept as-is.
    return input_dtype
  # Compare as np.dtype objects so that different spellings of the same dtype
  # (scalar type, dtype instance, string name) are treated as equal.
  requested = np.dtype(input_dtype)
  if any(requested == np.dtype(allowed) for allowed in target_dtypes):
    return input_dtype
  # The requested dtype is not allowed: fall back to the first allowed entry.
  return target_dtypes[0]
37293733
3730- return (
3731- maybe_convert_dtype (lhs_dtype , algorithm .supported_lhs_types ),
3732- maybe_convert_dtype (rhs_dtype , algorithm . supported_rhs_types ),
3733- maybe_convert_dtype ( out_dtype , supported_output_types ),
3734+ lhs_dtype = maybe_convert_dtype ( lhs_dtype , algorithm . supported_lhs_types )
3735+ rhs_dtype = maybe_convert_dtype (rhs_dtype , algorithm .supported_rhs_types )
3736+ out_type = maybe_convert_dtype (
3737+ out_dtype , algorithm . supported_output_types ( lhs_dtype , rhs_dtype )
37343738 )
3739+ return lhs_dtype , rhs_dtype , out_type
37353740
37363741
37373742def _dot_general_lower (ctx , lhs , rhs , * , dimension_numbers ,
0 commit comments