Skip to content

Commit af3ecea

Browse files
committed
Fix `low_cpu_mem_usage` defaulting and validation when the model declares `_keep_in_fp32_modules`.
1 parent c78dd0c commit af3ecea

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 10 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -624,10 +624,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
624624
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
625625
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
626626

627-
if (low_cpu_mem_usage is None or not low_cpu_mem_usage) and cls._keep_in_fp32_modules is not None:
628-
low_cpu_mem_usage = True
629-
logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
630-
631627
# Load config if we don't provide a configuration
632628
config_path = pretrained_model_name_or_path
633629

@@ -683,7 +679,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
683679
user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
684680

685681
# Force-set to `True` for more mem efficiency
686-
if not low_cpu_mem_usage:
682+
if low_cpu_mem_usage is None:
683+
low_cpu_mem_usage = True
684+
logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
685+
elif not low_cpu_mem_usage:
687686
raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
688687

689688
# Check if `_keep_in_fp32_modules` is not None
@@ -694,6 +693,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
694693
keep_in_fp32_modules = cls._keep_in_fp32_modules
695694
if not isinstance(keep_in_fp32_modules, list):
696695
keep_in_fp32_modules = [keep_in_fp32_modules]
696+
697+
if low_cpu_mem_usage is None:
698+
low_cpu_mem_usage = True
699+
logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
700+
elif not low_cpu_mem_usage:
701+
raise ValueError("`low_cpu_mem_usage` cannot be False when `_keep_in_fp32_modules` is set.")
697702
else:
698703
keep_in_fp32_modules = []
699704
#######################################

0 commit comments

Comments
 (0)