@@ -624,10 +624,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
624624                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. 
625625                raise  ValueError ("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10." )
626626
627-         if  (low_cpu_mem_usage  is  None  or  not  low_cpu_mem_usage ) and  cls ._keep_in_fp32_modules  is  not None :
628-             low_cpu_mem_usage  =  True 
629-             logger .info ("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None." )
630- 
631627        # Load config if we don't provide a configuration 
632628        config_path  =  pretrained_model_name_or_path 
633629
@@ -683,7 +679,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
683679            user_agent ["quant" ] =  hf_quantizer .quantization_config .quant_method .value 
684680
685681            # Force-set to `True` for more mem efficiency 
686-             if  not  low_cpu_mem_usage :
682+             if  low_cpu_mem_usage  is  None :
683+                 low_cpu_mem_usage  =  True 
684+                 logger .info ("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None." )
685+             elif  not  low_cpu_mem_usage :
687686                raise  ValueError ("`low_cpu_mem_usage` cannot be False or None when using quantization." )
688687
689688        # Check if `_keep_in_fp32_modules` is not None 
@@ -694,6 +693,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
694693            keep_in_fp32_modules  =  cls ._keep_in_fp32_modules 
695694            if  not  isinstance (keep_in_fp32_modules , list ):
696695                keep_in_fp32_modules  =  [keep_in_fp32_modules ]
696+ 
697+             if  low_cpu_mem_usage  is  None :
698+                 low_cpu_mem_usage  =  True 
699+                 logger .info ("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None." )
700+             elif  not  low_cpu_mem_usage :
701+                 raise  ValueError ("`low_cpu_mem_usage` cannot be False when `_keep_in_fp32_modules` is set." )
697702        else :
698703            keep_in_fp32_modules  =  []
699704        ####################################### 
0 commit comments