 from typing_extensions import Self

 from .. import __version__
+from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
@@ -297,6 +298,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        device_map = kwargs.pop("device_map", None)

         user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
         # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
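
With this hunk, `from_single_file` pops a `device_map` keyword the same way it already pops `device` and `disable_mmap`. A minimal usage sketch, assuming a placeholder checkpoint path and that plain string maps such as "cuda" are accepted (the diff itself only shows the kwarg being read):

import torch
from diffusers import UNet2DConditionModel

# Placeholder checkpoint path; any single-file-loadable class would do.
model = UNet2DConditionModel.from_single_file(
    "path/to/checkpoint.safetensors",
    torch_dtype=torch.float16,
    device_map="cuda",  # newly accepted kwarg; resolved later by _determine_device_map
)
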
@@ -403,19 +405,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         with ctx():
             model = cls.from_config(diffusers_model_config)

-        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-
-        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
-            diffusers_format_checkpoint = checkpoint_mapping_fn(
-                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-            )
-        else:
-            diffusers_format_checkpoint = checkpoint
+        model_state_dict = model.state_dict()

-        if not diffusers_format_checkpoint:
-            raise SingleFileComponentError(
-                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
-            )
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
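
Note that the conversion logic deleted above is not dropped: it is re-added in nearly identical form in the next hunk, after the device map has been resolved. The substantive change is that `model.state_dict()` is now called once, stored in `model_state_dict`, and reused both for the allocator warmup below and for the `_should_convert_state_dict_to_diffusers` check.
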
@@ -428,6 +419,26 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         else:
             keep_in_fp32_modules = []

+        # Now that the model is loaded, we can determine the `device_map`
+        device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer)
+        if device_map is not None:
+            expanded_device_map = _expand_device_map(device_map, model_state_dict.keys())
+            _caching_allocator_warmup(model, expanded_device_map, torch_dtype, hf_quantizer)
+
+        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+
+        if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
+
+        if not diffusers_format_checkpoint:
+            raise SingleFileComponentError(
+                f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+            )
+
         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
                 model=model,
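
For context on the two helpers imported at the top: `_expand_device_map` turns a module-level device map into a per-parameter one, and `_caching_allocator_warmup` reserves accelerator memory up front so that the subsequent weight loading does not pay for incremental allocator growth. Below is a simplified, illustrative sketch of the idea; it is not the diffusers implementation, and the real helpers take the additional `torch_dtype`/`hf_quantizer` arguments shown in the diff:

import torch

def expand_device_map_sketch(device_map, param_names):
    # Map every parameter name to the device of its longest matching prefix.
    # An empty-string key acts as a catch-all, e.g. {"": "cuda:0"}.
    expanded = {}
    for name in param_names:
        best = max((p for p in device_map if name.startswith(p)), key=len, default=None)
        if best is not None:
            expanded[name] = device_map[best]
    return expanded

def caching_allocator_warmup_sketch(model, expanded_device_map, dtype):
    # Sum the bytes destined for each CUDA device, then allocate (and drop) one
    # block of that size per device. CUDA's caching allocator keeps the freed
    # block reserved, so later parameter copies reuse it instead of growing the
    # pool one allocation at a time.
    elem_size = torch.empty((), dtype=dtype).element_size()
    bytes_per_device = {}
    for name, param in model.named_parameters():
        device = torch.device(expanded_device_map.get(name, "cpu"))
        if device.type == "cuda":
            bytes_per_device[device] = bytes_per_device.get(device, 0) + param.numel() * elem_size
    for device, num_bytes in bytes_per_device.items():
        _ = torch.empty(num_bytes, dtype=torch.uint8, device=device)
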