Skip to content

Commit 4ced879

Browse files
committed
update
1 parent ba2ba90 commit 4ced879

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@
156156
}
157157

158158

159+
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
160+
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
161+
162+
159163
def _get_single_file_loadable_mapping_class(cls):
160164
diffusers_module = importlib.import_module(__name__.split(".")[0])
161165
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -381,19 +385,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
381385
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
382386
diffusers_model_config.update(model_kwargs)
383387

388+
ctx = init_empty_weights if is_accelerate_available() else nullcontext
389+
with ctx():
390+
model = cls.from_config(diffusers_model_config)
391+
384392
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
385-
diffusers_format_checkpoint = checkpoint_mapping_fn(
386-
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
387-
)
393+
394+
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
395+
diffusers_format_checkpoint = checkpoint_mapping_fn(
396+
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
397+
)
398+
else:
399+
diffusers_format_checkpoint = checkpoint
400+
388401
if not diffusers_format_checkpoint:
389402
raise SingleFileComponentError(
390403
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
391404
)
392-
393-
ctx = init_empty_weights if is_accelerate_available() else nullcontext
394-
with ctx():
395-
model = cls.from_config(diffusers_model_config)
396-
397405
# Check if `_keep_in_fp32_modules` is not None
398406
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
399407
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")

0 commit comments

Comments
 (0)