@@ -1143,8 +1143,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         # set dtype to instantiate the model under:
         # 1. If torch_dtype is not None, we use that dtype
+        # 2. If torch_dtype is float8, we don't use _set_default_torch_dtype and we downcast after loading the model
         dtype_orig = None
-        if torch_dtype is not None:
+        if torch_dtype is not None and not torch_dtype == getattr(torch, "float8_e4m3fn", None):
             if not isinstance(torch_dtype, torch.dtype):
                 raise ValueError(
11501151                    f"{ torch_dtype } { type (torch_dtype )}  
@@ -1231,6 +1232,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             hf_quantizer.postprocess_model(model)
             model.hf_quantizer = hf_quantizer
 
+        if (
+            torch_dtype is not None
+            and torch_dtype == getattr(torch, "float8_e4m3fn", None)
+            and hf_quantizer is None
+            and not use_keep_in_fp32_modules
+        ):
+            model = model.to(torch_dtype)
+
         if hf_quantizer is not None:
             # We also make sure to purge `_pre_quantization_dtype` when we serialize
             # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
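Taken together, the two hunks defer the float8 cast until after loading: float8_e4m3fn is excluded from the default-dtype setup (per the comment, _set_default_torch_dtype is skipped for float8), and the model is downcast with model.to(torch_dtype) once the weights are in, as long as no quantizer is active and no modules are kept in fp32. A minimal usage sketch of this code path follows, assuming diffusers' ModelMixin.from_pretrained as shown above and a PyTorch build that exposes torch.float8_e4m3fn; the repo id is only an example checkpoint.

    # Sketch only: exercises the float8 loading path added in this commit.
    import torch
    from diffusers import UNet2DConditionModel

    # Same guard as in the diff; resolves to None on PyTorch builds without float8 support.
    fp8 = getattr(torch, "float8_e4m3fn", None)
    assert fp8 is not None, "this PyTorch build does not expose torch.float8_e4m3fn"

    # from_pretrained skips the default-dtype switch for float8 and instead calls
    # model.to(torch.float8_e4m3fn) after loading, provided hf_quantizer is None
    # and use_keep_in_fp32_modules is False.
    unet = UNet2DConditionModel.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",  # example repo id, not from the commit
        subfolder="unet",
        torch_dtype=fp8,
    )
    print(next(unet.parameters()).dtype)  # expected: torch.float8_e4m3fn

    # float8 serves as a storage dtype here; upcast before running inference,
    # e.g. unet.to(torch.bfloat16), since most kernels do not compute in float8.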