153 | 153 |         "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
154 | 154 |         "default_subfolder": "transformer",
155 | 155 |     },
    | 156 | +    "QwenImageTransformer2DModel": {
    | 157 | +        "checkpoint_mapping_fn": lambda x: x,
    | 158 | +        "default_subfolder": "transformer",
    | 159 | +    },
156 | 160 | }
157 | 161 |
158 | 162 |
    | 163 | +def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
    | 164 | +    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
    | 165 | +
    | 166 | +
159 | 167 | def _get_single_file_loadable_mapping_class(cls):
160 | 168 |     diffusers_module = importlib.import_module(__name__.split(".")[0])
161 | 169 |     for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
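As a minimal illustration of what the new `_should_convert_state_dict_to_diffusers` helper decides: when every key the instantiated model expects is already present in the checkpoint, the checkpoint is treated as already being in the diffusers layout and the mapping function is skipped. The key names below are hypothetical and chosen only to show the two outcomes.

```python
# Hypothetical key names, purely to illustrate the subset check in the hunk above.
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))

model_keys = {"transformer_blocks.0.attn.to_q.weight": 0}

diffusers_style_ckpt = {"transformer_blocks.0.attn.to_q.weight": 0}          # keys already match the model
original_style_ckpt = {"model.diffusion_model.blocks.0.qkv.weight": 0}       # original/reference layout

print(_should_convert_state_dict_to_diffusers(model_keys, diffusers_style_ckpt))  # False -> use checkpoint as-is
print(_should_convert_state_dict_to_diffusers(model_keys, original_style_ckpt))   # True  -> run checkpoint_mapping_fn
```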
@@ -381,19 +389,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
381 | 389 |         model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
382 | 390 |         diffusers_model_config.update(model_kwargs)
383 | 391 |
    | 392 | +        ctx = init_empty_weights if is_accelerate_available() else nullcontext
    | 393 | +        with ctx():
    | 394 | +            model = cls.from_config(diffusers_model_config)
    | 395 | +
384 | 396 |         checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
385 |     | -        diffusers_format_checkpoint = checkpoint_mapping_fn(
386 |     | -            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
387 |     | -        )
    | 397 | +
    | 398 | +        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
    | 399 | +            diffusers_format_checkpoint = checkpoint_mapping_fn(
    | 400 | +                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
    | 401 | +            )
    | 402 | +        else:
    | 403 | +            diffusers_format_checkpoint = checkpoint
    | 404 | +
388 | 405 |         if not diffusers_format_checkpoint:
389 | 406 |             raise SingleFileComponentError(
390 | 407 |                 f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
391 | 408 |             )
392 |     | -
393 |     | -        ctx = init_empty_weights if is_accelerate_available() else nullcontext
394 |     | -        with ctx():
395 |     | -            model = cls.from_config(diffusers_model_config)
396 |     | -
397 | 409 |         # Check if `_keep_in_fp32_modules` is not None
398 | 410 |         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
399 | 411 |             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
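Taken together with the registry entry above, a Qwen-Image transformer stored as a single file with diffusers-style keys can now be loaded through the identity mapping without any conversion step. A rough usage sketch, assuming a diffusers-format checkpoint; the file path below is a placeholder:

```python
import torch
from diffusers import QwenImageTransformer2DModel

# Placeholder path: any single-file checkpoint whose keys already match the
# model's state dict is passed through the registered identity mapping (lambda x: x).
transformer = QwenImageTransformer2DModel.from_single_file(
    "path/to/qwen_image_transformer.safetensors",
    torch_dtype=torch.bfloat16,
)
```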