@@ -13,15 +13,11 @@
 # limitations under the License.
 import importlib
 import inspect
-import re
-from contextlib import nullcontext
 from typing import Optional

-import torch
 from huggingface_hub.utils import validate_hf_hub_args

-from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
+from ..utils import deprecate, logging
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
@@ -49,12 +45,6 @@
 logger = logging.get_logger(__name__)


-if is_accelerate_available():
-    from accelerate import init_empty_weights
-
-    from ..models.modeling_utils import load_model_dict_into_meta
-
-
 SINGLE_FILE_LOADABLE_CLASSES = {
     "StableCascadeUNet": {
         "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
@@ -234,9 +224,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
         config_revision = kwargs.pop("config_revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
-        quantization_config = kwargs.pop("quantization_config", None)
-        device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

         if isinstance(pretrained_model_link_or_path_or_dict, dict):
@@ -252,12 +239,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 revision=revision,
                 disable_mmap=disable_mmap,
             )
-        if quantization_config is not None:
-            hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
-            hf_quantizer.validate_environment()
-
-        else:
-            hf_quantizer = None

         mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]

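(Review note: with `torch_dtype`, `quantization_config`, and `device` no longer popped from `kwargs` in the hunk above, this local `DiffusersAutoQuantizer` bootstrap has nothing left to consume and is removed; per this refactor those options remain in `**kwargs` and quantizer construction is presumably owned by `from_pretrained`, to which the method now delegates in the final hunk below.)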
@@ -336,62 +317,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
             )

-        ctx = init_empty_weights if is_accelerate_available() else nullcontext
-        with ctx():
-            model = cls.from_config(diffusers_model_config)
-
-        # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
-            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        return cls.from_pretrained(
+            pretrained_model_name_or_path=None,
+            state_dict=diffusers_format_checkpoint,
+            config=diffusers_model_config,
+            **kwargs,
         )
-        if use_keep_in_fp32_modules:
-            keep_in_fp32_modules = cls._keep_in_fp32_modules
-            if not isinstance(keep_in_fp32_modules, list):
-                keep_in_fp32_modules = [keep_in_fp32_modules]
-
-        else:
-            keep_in_fp32_modules = []
-
-        if hf_quantizer is not None:
-            hf_quantizer.preprocess_model(
-                model=model,
-                device_map=None,
-                state_dict=diffusers_format_checkpoint,
-                keep_in_fp32_modules=keep_in_fp32_modules,
-            )
-
-        if is_accelerate_available():
-            param_device = torch.device(device) if device else torch.device("cpu")
-            named_buffers = model.named_buffers()
-            unexpected_keys = load_model_dict_into_meta(
-                model,
-                diffusers_format_checkpoint,
-                dtype=torch_dtype,
-                device=param_device,
-                hf_quantizer=hf_quantizer,
-                keep_in_fp32_modules=keep_in_fp32_modules,
-                named_buffers=named_buffers,
-            )
-
-        else:
-            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
-        if model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n{[', '.join(unexpected_keys)]}"
-            )
-
-        if hf_quantizer is not None:
-            hf_quantizer.postprocess_model(model)
-            model.hf_quantizer = hf_quantizer
-
-        if torch_dtype is not None and hf_quantizer is None:
-            model.to(torch_dtype)
-
-        model.eval()
-
-        return model
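Net effect of this hunk: `from_single_file` now only resolves and converts the checkpoint, then hands weight loading, dtype casting, quantization, and `eval()` over to `cls.from_pretrained` via an in-memory `state_dict` and `config`; options such as `torch_dtype` and `quantization_config` are no longer popped above, so they travel through `**kwargs`. A minimal usage sketch under that assumption (the class is one listed in `SINGLE_FILE_LOADABLE_CLASSES`; the checkpoint path is a placeholder):

```python
import torch
from diffusers import StableCascadeUNet  # listed in SINGLE_FILE_LOADABLE_CLASSES above

# Hypothetical local checkpoint path; any single-file .safetensors checkpoint
# for this model would do. torch_dtype is no longer consumed by
# from_single_file itself -- it rides along in **kwargs and is applied by
# from_pretrained after this refactor.
unet = StableCascadeUNet.from_single_file(
    "./stable_cascade_unet.safetensors",
    torch_dtype=torch.float16,
)
```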