156 | 156 | }
157 | 157 |
158 | 158 |
| 159 | +def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict): |
| 160 | +    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
| 161 | + |
| 162 | + |
159 | 163 | def _get_single_file_loadable_mapping_class(cls): |
160 | 164 |     diffusers_module = importlib.import_module(__name__.split(".")[0])
161 | 165 |     for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -381,19 +385,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = |
381 | 385 |         model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
382 | 386 |         diffusers_model_config.update(model_kwargs)
383 | 387 |
| 388 | +        ctx = init_empty_weights if is_accelerate_available() else nullcontext
| 389 | +        with ctx():
| 390 | +            model = cls.from_config(diffusers_model_config)
| 391 | +
384 | 392 |         checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
385 | | -        diffusers_format_checkpoint = checkpoint_mapping_fn(
386 | | -            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
387 | | -        )
| 393 | +
| 394 | +        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
| 395 | +            diffusers_format_checkpoint = checkpoint_mapping_fn(
| 396 | +                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
| 397 | +            )
| 398 | +        else:
| 399 | +            diffusers_format_checkpoint = checkpoint
| 400 | +
388 | 401 |         if not diffusers_format_checkpoint:
389 | 402 |             raise SingleFileComponentError(
390 | 403 |                 f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
391 | 404 |             )
392 | | -
393 | | -        ctx = init_empty_weights if is_accelerate_available() else nullcontext
394 | | -        with ctx():
395 | | -            model = cls.from_config(diffusers_model_config)
396 | | -
397 | 405 |         # Check if `_keep_in_fp32_modules` is not None
398 | 406 |         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
399 | 407 |             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
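
The new `_should_convert_state_dict_to_diffusers` helper skips the key-mapping step when the checkpoint is already in the diffusers layout: conversion only runs if the freshly instantiated model expects keys that the checkpoint does not contain. Below is a minimal, self-contained sketch of that check with toy state dicts; the key names are illustrative, not taken from a real checkpoint.

```python
# Minimal sketch of the key-subset check introduced in this diff
# (toy key names, not from a real checkpoint).

def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
    # Convert only if the checkpoint is missing keys the instantiated
    # diffusers model expects, i.e. it is not already in diffusers layout.
    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))


model_keys = {"conv_in.weight": None, "conv_in.bias": None}

# Checkpoint already in diffusers layout: every expected key is present,
# so the mapping function is skipped.
diffusers_ckpt = {"conv_in.weight": None, "conv_in.bias": None, "extra_ema.weight": None}
assert not _should_convert_state_dict_to_diffusers(model_keys, diffusers_ckpt)

# Original single-file layout: keys differ, so checkpoint_mapping_fn must run.
original_ckpt = {"model.diffusion_model.input_blocks.0.0.weight": None}
assert _should_convert_state_dict_to_diffusers(model_keys, original_ckpt)
```

This is also why the `cls.from_config` call moves ahead of the mapping step in the diff: the check needs `model.state_dict()` for the key comparison, and wrapping the instantiation in `init_empty_weights` (when accelerate is available) keeps that cheap by avoiding real weight allocation.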