Skip to content
28 changes: 20 additions & 8 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,17 @@
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"QwenImageTransformer2DModel": {
"checkpoint_mapping_fn": lambda x: x,
"default_subfolder": "transformer",
},
}


def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))


def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
Expand Down Expand Up @@ -381,19 +389,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)

ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)

checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)

if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
else:
diffusers_format_checkpoint = checkpoint

if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)

ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)

# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

CHECKPOINT_KEY_NAMES = {
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
Expand Down
11 changes: 11 additions & 0 deletions tests/quantization/gguf/test_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def _check_for_gguf_linear(model):

class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
torch_dtype = torch.bfloat16
model_cls = FluxTransformer2DModel
expected_memory_use_in_gb = 5
Expand Down Expand Up @@ -296,6 +297,16 @@ def test_pipeline_inference(self):
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4

def test_loading_gguf_diffusers_format(self):
model = self.model_cls.from_single_file(
self.diffusers_ckpt_path,
subfolder="transformer",
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
config="black-forest-labs/FLUX.1-dev",
)
model.to("cuda")
model(**self.get_dummy_inputs())


class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
Expand Down
Loading