
Commit d3a7dc8

revert logic for single file
1 parent 0df7010 commit d3a7dc8

File tree

2 files changed: +188 -122 lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 77 additions & 6 deletions
@@ -13,11 +13,15 @@
 # limitations under the License.
 import importlib
 import inspect
+import re
+from contextlib import nullcontext
 from typing import Optional

+import torch
 from huggingface_hub.utils import validate_hf_hub_args

-from ..utils import deprecate, logging
+from ..quantizers import DiffusersAutoQuantizer
+from ..utils import deprecate, is_accelerate_available, logging
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
@@ -45,6 +49,12 @@
 logger = logging.get_logger(__name__)


+if is_accelerate_available():
+    from accelerate import init_empty_weights
+
+    from ..models.modeling_utils import load_model_dict_into_meta
+
+
 SINGLE_FILE_LOADABLE_CLASSES = {
     "StableCascadeUNet": {
         "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
@@ -224,6 +234,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
         config_revision = kwargs.pop("config_revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        quantization_config = kwargs.pop("quantization_config", None)
+        device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

         if isinstance(pretrained_model_link_or_path_or_dict, dict):
@@ -239,6 +252,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 revision=revision,
                 disable_mmap=disable_mmap,
             )
+        if quantization_config is not None:
+            hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
+            hf_quantizer.validate_environment()
+
+        else:
+            hf_quantizer = None

         mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]

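For reference, the branch added above resolves the quantizer eagerly, before any weights are materialized. A minimal standalone sketch of that resolution, assuming bitsandbytes and accelerate are installed (the BitsAndBytesConfig values are illustrative, not from this commit):

# Hypothetical sketch of the quantizer resolution performed above.
from diffusers import BitsAndBytesConfig
from diffusers.quantizers import DiffusersAutoQuantizer

quantization_config = BitsAndBytesConfig(load_in_8bit=True)  # illustrative config
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
hf_quantizer.validate_environment()  # fails fast if required backends are missing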
@@ -317,9 +336,61 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
             )

-        return cls.from_pretrained(
-            pretrained_model_name_or_path=None,
-            state_dict=diffusers_format_checkpoint,
-            config=diffusers_model_config,
-            **kwargs,
+        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        with ctx():
+            model = cls.from_config(diffusers_model_config)
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
         )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+
+        else:
+            keep_in_fp32_modules = []
+
+        if hf_quantizer is not None:
+            hf_quantizer.preprocess_model(
+                model=model,
+                device_map=None,
+                state_dict=diffusers_format_checkpoint,
+                keep_in_fp32_modules=keep_in_fp32_modules,
+            )
+
+        if is_accelerate_available():
+            param_device = torch.device(device) if device else torch.device("cpu")
+            unexpected_keys = [param_name for param_name in diffusers_format_checkpoint if param_name not in model.state_dict()]
+            load_model_dict_into_meta(
+                model,
+                diffusers_format_checkpoint,
+                dtype=torch_dtype,
+                device_map={"": param_device},
+                hf_quantizer=hf_quantizer,
+                keep_in_fp32_modules=keep_in_fp32_modules,
+                unexpected_keys=unexpected_keys,
+            )
+        else:
+            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
+        if model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in model._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
+
+        if hf_quantizer is not None:
+            hf_quantizer.postprocess_model(model)
+            model.hf_quantizer = hf_quantizer
+
+        if torch_dtype is not None and hf_quantizer is None:
+            model.to(torch_dtype)
+
+        model.eval()
+
+        return model
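Taken together, the new path initializes the model on the meta device and then streams the converted checkpoint in with the requested dtype, device, and quantizer. A hedged usage sketch of the reworked from_single_file (the model class, checkpoint URL, and config values here are illustrative, not part of this commit):

# Hypothetical usage; the URL and quantization settings are placeholders.
import torch

from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

model = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/some-org/some-model/blob/main/model.safetensors",  # placeholder URL
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # routed through DiffusersAutoQuantizer
    torch_dtype=torch.bfloat16,  # applied during load_model_dict_into_meta, not via model.to()
    device="cuda",  # becomes param_device for the meta-device load
)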

src/diffusers/models/modeling_utils.py

Lines changed: 111 additions & 116 deletions
@@ -795,8 +795,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         quantization_config = kwargs.pop("quantization_config", None)
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
-        state_dict = kwargs.pop("state_dict", None)
-        config = kwargs.pop("config", None)

         allow_pickle = False
         if use_safetensors is None:
@@ -867,39 +865,35 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
             raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

-        if (not config and state_dict) or (config and not state_dict):
-            raise ValueError("You need to pass both the config and the state dict to initalize the model.")
-
         user_agent = {
             "diffusers": __version__,
             "file_type": "model",
             "framework": "pytorch",
         }
         unused_kwargs = {}

-        if config is None:
-            # Load config if we don't provide a configuration
-            config_path = pretrained_model_name_or_path
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path

-            # TODO: We need to let the user pass a config in from_pretrained
-            # load config
-            config, unused_kwargs, commit_hash = cls.load_config(
-                config_path,
-                cache_dir=cache_dir,
-                return_unused_kwargs=True,
-                return_commit_hash=True,
-                force_download=force_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                token=token,
-                revision=revision,
-                subfolder=subfolder,
-                user_agent=user_agent,
-                dduf_entries=dduf_entries,
-                **kwargs,
-            )
-            # no in-place modification of the original config.
-            config = copy.deepcopy(config)
+        # TODO: We need to let the user pass a config in from_pretrained
+        # load config
+        config, unused_kwargs, commit_hash = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            dduf_entries=dduf_entries,
+            **kwargs,
+        )
+        # no in-place modification of the original config.
+        config = copy.deepcopy(config)

         # determine initial quantization config.
         #######################################
@@ -951,103 +945,79 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         is_sharded = False
         resolved_archive_file = None
-        if state_dict is None:
-            # Determine if we're loading from a directory of sharded checkpoints.
-            sharded_metadata = None
-            index_file = None
-            is_local = os.path.isdir(pretrained_model_name_or_path)
-            index_file_kwargs = {
-                "is_local": is_local,
-                "pretrained_model_name_or_path": pretrained_model_name_or_path,
-                "subfolder": subfolder or "",
-                "use_safetensors": use_safetensors,
-                "cache_dir": cache_dir,
-                "variant": variant,
-                "force_download": force_download,
-                "proxies": proxies,
-                "local_files_only": local_files_only,
-                "token": token,
-                "revision": revision,
-                "user_agent": user_agent,
-                "commit_hash": commit_hash,
-                "dduf_entries": dduf_entries,
-            }
-            index_file = _fetch_index_file(**index_file_kwargs)
-            # In case the index file was not found we still have to consider the legacy format.
-            # this becomes applicable when the variant is not None.
-            if variant is not None and (index_file is None or not os.path.exists(index_file)):
-                index_file = _fetch_index_file_legacy(**index_file_kwargs)
-            if index_file is not None and (dduf_entries or index_file.is_file()):
-                is_sharded = True
-
-            if is_sharded and from_flax:
-                raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
-
-            # load model
-            if from_flax:
-                resolved_archive_file = _get_model_file(
+
+        # Determine if we're loading from a directory of sharded checkpoints.
+        sharded_metadata = None
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file_kwargs = {
+            "is_local": is_local,
+            "pretrained_model_name_or_path": pretrained_model_name_or_path,
+            "subfolder": subfolder or "",
+            "use_safetensors": use_safetensors,
+            "cache_dir": cache_dir,
+            "variant": variant,
+            "force_download": force_download,
+            "proxies": proxies,
+            "local_files_only": local_files_only,
+            "token": token,
+            "revision": revision,
+            "user_agent": user_agent,
+            "commit_hash": commit_hash,
+            "dduf_entries": dduf_entries,
+        }
+        index_file = _fetch_index_file(**index_file_kwargs)
+        # In case the index file was not found we still have to consider the legacy format.
+        # this becomes applicable when the variant is not None.
+        if variant is not None and (index_file is None or not os.path.exists(index_file)):
+            index_file = _fetch_index_file_legacy(**index_file_kwargs)
+        if index_file is not None and (dduf_entries or index_file.is_file()):
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
+        # load model
+        if from_flax:
+            resolved_archive_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=FLAX_WEIGHTS_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+                commit_hash=commit_hash,
+            )
+            model = cls.from_config(config, **unused_kwargs)
+
+            # Convert the weights
+            from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
+
+            model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
+        else:
+            # in the case it is sharded, we have already the index
+            if is_sharded:
+                resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
-                    weights_name=FLAX_WEIGHTS_NAME,
+                    index_file,
                     cache_dir=cache_dir,
-                    force_download=force_download,
                     proxies=proxies,
                     local_files_only=local_files_only,
                     token=token,
-                    revision=revision,
-                    subfolder=subfolder,
                     user_agent=user_agent,
-                    commit_hash=commit_hash,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                    dduf_entries=dduf_entries,
                 )
-                model = cls.from_config(config, **unused_kwargs)
-
-                # Convert the weights
-                from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
-
-                model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
-            else:
-                # in the case it is sharded, we have already the index
-                if is_sharded:
-                    resolved_archive_file, sharded_metadata = _get_checkpoint_shard_files(
-                        pretrained_model_name_or_path,
-                        index_file,
-                        cache_dir=cache_dir,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        user_agent=user_agent,
-                        revision=revision,
-                        subfolder=subfolder or "",
-                        dduf_entries=dduf_entries,
-                    )
-                elif use_safetensors:
-                    try:
-                        resolved_archive_file = _get_model_file(
-                            pretrained_model_name_or_path,
-                            weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            proxies=proxies,
-                            local_files_only=local_files_only,
-                            token=token,
-                            revision=revision,
-                            subfolder=subfolder,
-                            user_agent=user_agent,
-                            commit_hash=commit_hash,
-                            dduf_entries=dduf_entries,
-                        )
-
-                    except IOError as e:
-                        logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
-                        if not allow_pickle:
-                            raise
-                        logger.warning(
-                            "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
-                        )
-
-                if resolved_archive_file is None and not is_sharded:
+            elif use_safetensors:
+                try:
                     resolved_archive_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=_add_variant(WEIGHTS_NAME, variant),
+                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                         cache_dir=cache_dir,
                         force_download=force_download,
                         proxies=proxies,
@@ -1060,6 +1030,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         dduf_entries=dduf_entries,
                     )

+                except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                    if not allow_pickle:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if resolved_archive_file is None and not is_sharded:
+                resolved_archive_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=_add_variant(WEIGHTS_NAME, variant),
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                    dduf_entries=dduf_entries,
+                )
+
         if not isinstance(resolved_archive_file, list):
             resolved_archive_file = [resolved_archive_file]

@@ -1084,7 +1078,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)

-        if not is_sharded and state_dict is None:
+        state_dict = None
+        if not is_sharded:
             # Time to load the checkpoint
             state_dict = load_state_dict(
                 resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries

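With the revert, from_pretrained no longer accepts state_dict/config, so the init-on-meta-then-materialize step lives entirely in from_single_file. A minimal standalone sketch of that pattern with a toy module and state dict, assuming accelerate is installed and PyTorch >= 2.1 (for assign=True):

# Toy illustration of the pattern used by the single-file loader above.
import torch
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():
    model = nn.Linear(4, 4)  # parameters are created on the meta device, no memory allocated

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# assign=True swaps the meta parameters for the checkpoint tensors instead of copying
# into them; this stands in for load_model_dict_into_meta, minus dtype/quantizer handling.
model.load_state_dict(state_dict, assign=True)
model.eval()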