Skip to content

Commit f20aba3

Browse files
sayakpaulDN6
andauthored
[GGUF] feat: support loading diffusers format gguf checkpoints. (#11684)
* feat: support loading diffusers format gguf checkpoints. * update * update * qwen --------- Co-authored-by: DN6 <[email protected]>
1 parent ccf2c31 commit f20aba3

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,17 @@
153153
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
154154
"default_subfolder": "transformer",
155155
},
156+
"QwenImageTransformer2DModel": {
157+
"checkpoint_mapping_fn": lambda x: x,
158+
"default_subfolder": "transformer",
159+
},
156160
}
157161

158162

163+
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
164+
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
165+
166+
159167
def _get_single_file_loadable_mapping_class(cls):
160168
diffusers_module = importlib.import_module(__name__.split(".")[0])
161169
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -381,19 +389,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
381389
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
382390
diffusers_model_config.update(model_kwargs)
383391

392+
ctx = init_empty_weights if is_accelerate_available() else nullcontext
393+
with ctx():
394+
model = cls.from_config(diffusers_model_config)
395+
384396
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
385-
diffusers_format_checkpoint = checkpoint_mapping_fn(
386-
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
387-
)
397+
398+
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
399+
diffusers_format_checkpoint = checkpoint_mapping_fn(
400+
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
401+
)
402+
else:
403+
diffusers_format_checkpoint = checkpoint
404+
388405
if not diffusers_format_checkpoint:
389406
raise SingleFileComponentError(
390407
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
391408
)
392-
393-
ctx = init_empty_weights if is_accelerate_available() else nullcontext
394-
with ctx():
395-
model = cls.from_config(diffusers_model_config)
396-
397409
# Check if `_keep_in_fp32_modules` is not None
398410
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
399411
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")

src/diffusers/loaders/single_file_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6161

6262
CHECKPOINT_KEY_NAMES = {
63+
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
6364
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
6465
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
6566
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",

tests/quantization/gguf/test_gguf.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def _check_for_gguf_linear(model):
212212

213213
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
214214
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
215+
diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
215216
torch_dtype = torch.bfloat16
216217
model_cls = FluxTransformer2DModel
217218
expected_memory_use_in_gb = 5
@@ -296,6 +297,16 @@ def test_pipeline_inference(self):
296297
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
297298
assert max_diff < 1e-4
298299

300+
def test_loading_gguf_diffusers_format(self):
301+
model = self.model_cls.from_single_file(
302+
self.diffusers_ckpt_path,
303+
subfolder="transformer",
304+
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
305+
config="black-forest-labs/FLUX.1-dev",
306+
)
307+
model.to("cuda")
308+
model(**self.get_dummy_inputs())
309+
299310

300311
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
301312
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"

0 commit comments

Comments
 (0)