File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -3704,12 +3704,14 @@ def maybe_convert_dtype(input_dtype, target_dtype):
37043704 return input_dtype
37053705 if not isinstance (target_dtype , tuple ):
37063706 target_dtype = (target_dtype ,)
3707- return input_dtype if input_dtype in target_dtype else target_dtype [0 ]
3707+ if np .dtype (input_dtype ) in map (np .dtype , target_dtype ):
3708+ return input_dtype
3709+ return target_dtype [0 ]
37083710
37093711 if algorithm == DotAlgorithmPreset .BF16_BF16_F32 :
37103712 lhs_dtype = maybe_convert_dtype (lhs_dtype , algorithm .lhs_precision_type )
37113713 rhs_dtype = maybe_convert_dtype (rhs_dtype , algorithm .rhs_precision_type )
3712- if lhs_dtype == dtypes .bfloat16 :
3714+ if np . dtype ( lhs_dtype ) == dtypes .bfloat16 :
37133715 out_dtype = maybe_convert_dtype (out_dtype ,
37143716 (np .float32 , dtypes .bfloat16 ))
37153717 else :
You can’t perform that action at this time.
0 commit comments