@@ -923,15 +923,8 @@ def accumulation_type(self) -> DTypeLike | None:
923923 case _:
924924 return np .float32
925925
926- def supported_output_types (
927- self , lhs_dtype : DTypeLike , rhs_dtype : DTypeLike
928- ) -> tuple [DTypeLike , ...] | None :
929- if np .dtype (lhs_dtype ) != np .dtype (rhs_dtype ):
930- raise ValueError (
931- f"The dot algorithm '{ self } ' requires both inputs to have the same "
932- f'dtypes. Got { lhs_dtype } and { rhs_dtype } instead.'
933- )
934-
926+ @property
927+ def supported_output_types (self ) -> tuple [DTypeLike , ...] | None :
935928 match self :
936929 case (
937930 DotAlgorithmPreset .ANY_F8_ANY_F8_F32
@@ -949,11 +942,6 @@ def supported_output_types(
949942 )
950943 case DotAlgorithmPreset .F16_F16_F32 :
951944 return (np .float32 , np .float16 )
952- case DotAlgorithmPreset .BF16_BF16_F32 :
953- if np .dtype (lhs_dtype ) == dtypes .bfloat16 :
954- return (np .float32 , dtypes .bfloat16 )
955- else :
956- return (np .float32 ,)
957945 case _:
958946 accumulation_type = self .accumulation_type
959947 return None if accumulation_type is None else (accumulation_type ,)
@@ -3725,19 +3713,25 @@ def get_algorithm_compute_types(
37253713 algorithm .accumulation_type ,
37263714 )
37273715
3716+ supported_output_types = algorithm .supported_output_types
3717+
3718+ if algorithm == DotAlgorithmPreset .BF16_BF16_F32 :
3719+ # If dtype is anything other than float32, it will be cast to bfloat16.
3720+ if np .dtype (lhs_dtype ) != np .float32 :
3721+ supported_output_types = (np .float32 , dtypes .bfloat16 )
3722+
37283723 def maybe_convert_dtype (input_dtype , target_dtypes ):
37293724 if target_dtypes is None :
37303725 return input_dtype
37313726 if np .dtype (input_dtype ) in map (np .dtype , target_dtypes ):
37323727 return input_dtype
37333728 return target_dtypes [0 ]
37343729
3735- lhs_dtype = maybe_convert_dtype ( lhs_dtype , algorithm . supported_lhs_types )
3736- rhs_dtype = maybe_convert_dtype (rhs_dtype , algorithm .supported_rhs_types )
3737- out_type = maybe_convert_dtype (
3738- out_dtype , algorithm . supported_output_types ( lhs_dtype , rhs_dtype )
3730+ return (
3731+ maybe_convert_dtype (lhs_dtype , algorithm .supported_lhs_types ),
3732+ maybe_convert_dtype (rhs_dtype , algorithm . supported_rhs_types ),
3733+ maybe_convert_dtype ( out_dtype , supported_output_types ),
37393734 )
3740- return lhs_dtype , rhs_dtype , out_type
37413735
37423736
37433737def _dot_general_lower (ctx , lhs , rhs , * , dimension_numbers ,
0 commit comments