diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 16bd0441072a..b53647d47630 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -22,6 +22,7 @@
 from typing_extensions import Self
 
 from .. import __version__
+from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
@@ -297,6 +298,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        device_map = kwargs.pop("device_map", None)
 
         user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
         # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
@@ -403,19 +405,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         with ctx():
             model = cls.from_config(diffusers_model_config)
 
-        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-
-        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
-            diffusers_format_checkpoint = checkpoint_mapping_fn(
-                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-            )
-        else:
-            diffusers_format_checkpoint = checkpoint
+        model_state_dict = model.state_dict()
 
-        if not diffusers_format_checkpoint:
-            raise SingleFileComponentError(
-                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
-            )
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -428,6 +419,26 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         else:
             keep_in_fp32_modules = []
 
+        # Now that the model is loaded, we can determine the `device_map`
+        device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer)
+        if device_map is not None:
+            expanded_device_map = _expand_device_map(device_map, model_state_dict.keys())
+            _caching_allocator_warmup(model, expanded_device_map, torch_dtype, hf_quantizer)
+
+        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+
+        if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
+
+        if not diffusers_format_checkpoint:
+            raise SingleFileComponentError(
+                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+            )
+
         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
                 model=model,
diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py
index a7e07e517f4d..8290c339b931 100644
--- a/tests/single_file/test_model_flux_transformer_single_file.py
+++ b/tests/single_file/test_model_flux_transformer_single_file.py
@@ -69,3 +69,11 @@ def test_checkpoint_loading(self):
         del model
         gc.collect()
         backend_empty_cache(torch_device)
+
+    def test_device_map_cuda(self):
+        backend_empty_cache(torch_device)
+        model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")
+
+        del model
+        gc.collect()
+        backend_empty_cache(torch_device)
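For reference, a minimal usage sketch of the `device_map` argument this diff wires into `from_single_file`, mirroring the new test. The checkpoint path and dtype below are illustrative placeholders, not part of the change.

import torch
from diffusers import FluxTransformer2DModel

# Hypothetical local path to a single-file FLUX transformer checkpoint.
ckpt_path = "flux1-dev.safetensors"

model = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    torch_dtype=torch.bfloat16,  # existing kwarg, shown for completeness
    device_map="cuda",           # new kwarg handled by the code path added above
)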