35 changes: 23 additions & 12 deletions src/diffusers/loaders/single_file_model.py
@@ -22,6 +22,7 @@
from typing_extensions import Self

from .. import __version__
from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
@@ -297,6 +298,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
device_map = kwargs.pop("device_map", None)

user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
# In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
@@ -403,19 +405,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
with ctx():
model = cls.from_config(diffusers_model_config)

checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)

if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
else:
diffusers_format_checkpoint = checkpoint
model_state_dict = model.state_dict()

if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -428,6 +419,26 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
else:
keep_in_fp32_modules = []

# Now that the model is loaded, we can determine the `device_map`
device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer)
if device_map is not None:
expanded_device_map = _expand_device_map(device_map, model_state_dict.keys())
_caching_allocator_warmup(model, expanded_device_map, torch_dtype, hf_quantizer)

checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)

if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
else:
diffusers_format_checkpoint = checkpoint

if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)

if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model,
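For context, a minimal usage sketch of the new `device_map` argument added above (the model class and checkpoint URL are illustrative assumptions; the test below exercises the same path with `device_map="cuda"`):

import torch
from diffusers import FluxTransformer2DModel

# Hypothetical single-file checkpoint; any checkpoint supported by `from_single_file` follows the same path.
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"

# `device_map` is forwarded to `_determine_device_map`; the expanded map is then used by
# `_caching_allocator_warmup` to pre-allocate device memory before the state dict is loaded.
model = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)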
8 changes: 8 additions & 0 deletions tests/single_file/test_model_flux_transformer_single_file.py
@@ -69,3 +69,11 @@ def test_checkpoint_loading(self):
del model
gc.collect()
backend_empty_cache(torch_device)

def test_device_map_cuda(self):
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")

del model
gc.collect()
backend_empty_cache(torch_device)
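A natural follow-up assertion, sketched here rather than taken from the PR, would verify that parameters actually land on the accelerator when a `device_map` is passed; it assumes the same test class attributes (`self.model_class`, `self.ckpt_path`) and helpers used above:

def test_device_map_cuda_places_parameters(self):
    # Hypothetical extra check: load with an explicit device map and verify parameter placement.
    backend_empty_cache(torch_device)
    model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")
    assert all(p.device.type == "cuda" for p in model.parameters())

    del model
    gc.collect()
    backend_empty_cache(torch_device)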