File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -185,16 +185,15 @@ def load_model_dict_into_meta(
185185    accepts_dtype  =  "dtype"  in  set (inspect .signature (set_module_tensor_to_device ).parameters .keys ())
186186    empty_state_dict  =  model .state_dict ()
187187    unexpected_keys  =  [param_name  for  param_name  in  state_dict  if  param_name  not  in empty_state_dict ]
188-     is_torch_e4m3fn_available  =  hasattr (torch , "float8_e4m3fn" )
189188
190189    for  param_name , param  in  state_dict .items ():
191190        if  param_name  not  in empty_state_dict :
192191            continue 
193192
194-         # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type . We also want to keep the buffers/params 
193+         # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params 
195194        # in int/uint/bool and not cast them. 
196-         is_param_float8_e4m3fn   =   is_torch_e4m3fn_available   and  param .dtype  ==  torch .float8_e4m3fn 
197-         if  torch .is_floating_point (param )  and   not   is_param_float8_e4m3fn :
195+         # TODO: revisit cases when  param.dtype == torch.float8_e4m3fn
196+         if  torch .is_floating_point (param ):
198197            if  (
199198                keep_in_fp32_modules  is  not None 
200199                and  any (
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments