@@ -706,16 +706,6 @@ def triton_compute_type(dtype: torch.dtype) -> str:
     return triton_type(upcast_compute_type(dtype))
 
 
-def _get_primitive_bitwidth(dtype: torch.dtype) -> int:
-    """Number of bits of triton_compute_type()"""
-    dtype = upcast_compute_type(dtype)
-    itemsize = getattr(dtype, "itemsize", None)
-    if itemsize:
-        return itemsize * 8
-    else:
-        return -1
-
-
 def triton_store_type(dtype: torch.dtype) -> str:
     """Convert torch.dtype to triton type, with fix for storing tl.bool"""
     if dtype == torch.bool:
@@ -887,30 +877,20 @@ def _get_min_elements_per_thread(
 
     @staticmethod
     def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
-        triton_dtype = triton_compute_type(dtype)
+        assert src_dtype.itemsize == dtype.itemsize
         # We may promote float16 or bfloat16 to float32 and cause the
         # bitwidth of dtype to be different from the input tensor (i.e. float32).
         # In such a case, we will have to convert the input tensor to
         # its src_dtype, perform bitcast, and then convert the bit-casted
         # tensor back to float to ensure we use values with the right precision.
-        if (
-            src_dtype in (torch.float16, torch.bfloat16)
-            and config.triton.codegen_upcast_to_fp32
-        ):
-            triton_src_dtype = str(src_dtype).split(".")[-1]
-            cast_x = f"{x}.to(tl.{triton_src_dtype})"
-            if dtype in (torch.float16, torch.bfloat16):
-                triton_type_name = str(dtype).split(".")[-1]
-                triton_dtype = f"tl.{triton_type_name}"
-            cast_x = f"{cast_x}.to({triton_dtype}, bitcast=True)"
-            if dtype in (torch.float16, torch.bfloat16):
-                return f"{cast_x}.to(tl.float32)"
-            return cast_x
-        else:
-            src_dtype_bitwidth = _get_primitive_bitwidth(src_dtype)
-            target_dtype_bitwidth = _get_primitive_bitwidth(dtype)
-            bitcast = "True" if src_dtype_bitwidth == target_dtype_bitwidth else "False"
-            return f"{x}.to({triton_dtype}, bitcast={bitcast})"
+        if x.dtype != src_dtype:
+            x = f"{x}.to({triton_type(src_dtype)})"
+
+        out = f"{x}.to({triton_type(dtype)}, bitcast=True)"
+        if upcast_compute_type(dtype) != dtype:
+            out = f"{out}.to({triton_type(upcast_compute_type(dtype))})"
+
+        return out
 
     @staticmethod
     def _shaped_constant(value, dtype, shape):
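Note on the new to_dtype_bitcast lowering: the sketch below is not part of the PR; it only mimics the strings the rewritten method would emit. It assumes simplified stand-ins for Inductor's triton_type and upcast_compute_type helpers, and it takes the current compute dtype as an explicit x_dtype argument, whereas in the real method x is a CSE variable whose .dtype attribute carries that information.

import torch


def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
    # Simplified stand-in: Inductor promotes fp16/bf16 compute to fp32
    # when config.triton.codegen_upcast_to_fp32 is enabled.
    return torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype


def triton_type(dtype: torch.dtype) -> str:
    # Simplified stand-in for the real mapping, e.g. torch.bfloat16 -> "tl.bfloat16".
    return f"tl.{str(dtype).split('.')[-1]}"


def to_dtype_bitcast(x: str, x_dtype: torch.dtype, dtype: torch.dtype, src_dtype: torch.dtype) -> str:
    # Bitcasting only makes sense between dtypes of equal size.
    assert src_dtype.itemsize == dtype.itemsize
    # If the value was promoted (e.g. a bf16 load held as fp32), first cast back
    # down to src_dtype so the bit pattern being reinterpreted is the original one.
    if x_dtype != src_dtype:
        x = f"{x}.to({triton_type(src_dtype)})"
    # Reinterpret the bits as the target dtype.
    out = f"{x}.to({triton_type(dtype)}, bitcast=True)"
    # Re-promote so downstream ops keep computing at the usual precision.
    if upcast_compute_type(dtype) != dtype:
        out = f"{out}.to({triton_type(upcast_compute_type(dtype))})"
    return out


# bf16 value held as fp32, bitcast to int16:
print(to_dtype_bitcast("tmp0", torch.float32, torch.int16, torch.bfloat16))
# -> tmp0.to(tl.bfloat16).to(tl.int16, bitcast=True)

# int16 value bitcast to bf16, then upcast back to fp32 for compute:
print(to_dtype_bitcast("tmp1", torch.int16, torch.bfloat16, torch.int16))
# -> tmp1.to(tl.bfloat16, bitcast=True).to(tl.float32)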