From e5ca3a61b4a076e6751a3271d2366a46fe473b05 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Tue, 10 Jun 2025 09:27:23 +0530
Subject: [PATCH 1/7] feat: support loading diffusers format gguf checkpoints.

---
 src/diffusers/loaders/single_file_utils.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 0f762b949d47..cd163d9acf2a 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -2145,6 +2145,12 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
+    original_flux = any(k.startswith("double_blocks.") for k in checkpoint) or any(
+        k.startswith("single_blocks.") for k in checkpoint
+    )
+    if not original_flux:
+        return checkpoint
+
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
     mlp_ratio = 4.0

From 4ced8799303b69e5dbe791194289dc08f73017b2 Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 5 Aug 2025 22:48:25 +0530
Subject: [PATCH 2/7] update

---
 src/diffusers/loaders/single_file_model.py | 24 ++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 76fefc1260d0..e1c87476aef1 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -156,6 +156,10 @@
 }
 
 
+def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
+    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
+
+
 def _get_single_file_loadable_mapping_class(cls):
     diffusers_module = importlib.import_module(__name__.split(".")[0])
     for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -381,19 +385,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
             diffusers_model_config.update(model_kwargs)
 
+        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        with ctx():
+            model = cls.from_config(diffusers_model_config)
+
         checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-        diffusers_format_checkpoint = checkpoint_mapping_fn(
-            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-        )
+
+        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
+
         if not diffusers_format_checkpoint:
             raise SingleFileComponentError(
                 f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
             )
 
-        ctx = init_empty_weights if is_accelerate_available() else nullcontext
-        with ctx():
-            model = cls.from_config(diffusers_model_config)
-
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
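A note on PATCH 2/7: the empty model is now built *before* the mapping step so that its expected keys can drive `_should_convert_state_dict_to_diffusers`; conversion runs only when the checkpoint does not already cover the model's keys. A minimal sketch of the heuristic, with hypothetical one-key state dicts standing in for real checkpoints:

```py
# Sketch of the conversion heuristic introduced in PATCH 2/7: conversion is
# skipped when the checkpoint already contains every key the model expects.
# The one-key state dicts below are hypothetical stand-ins, not real weights.

def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))

model_sd = {"transformer_blocks.0.attn.to_q.weight": None}

# Original-format Flux checkpoint: keys don't match the diffusers layout,
# so the mapping function must run.
original_sd = {"double_blocks.0.img_attn.qkv.weight": None}
assert _should_convert_state_dict_to_diffusers(model_sd, original_sd)

# Diffusers-format checkpoint: every expected key is present, so the
# checkpoint is used as-is.
diffusers_sd = {"transformer_blocks.0.attn.to_q.weight": None}
assert not _should_convert_state_dict_to_diffusers(model_sd, diffusers_sd)
```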
From 3f67ed08b4a41c183d48b70c059ae34b83587e5f Mon Sep 17 00:00:00 2001
From: DN6
Date: Tue, 5 Aug 2025 23:27:06 +0530
Subject: [PATCH 3/7] update

---
 src/diffusers/loaders/single_file_utils.py |  7 +------
 tests/quantization/gguf/test_gguf.py       | 11 +++++++++++
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index f3005e70ce25..723f0c136f48 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -60,6 +60,7 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 CHECKPOINT_KEY_NAMES = {
+    "v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
     "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
     "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
     "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
@@ -2196,12 +2197,6 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
-    original_flux = any(k.startswith("double_blocks.") for k in checkpoint) or any(
-        k.startswith("single_blocks.") for k in checkpoint
-    )
-    if not original_flux:
-        return checkpoint
-
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
     mlp_ratio = 4.0
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index e9d7034f0302..ea2a57bf1ee6 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -212,6 +212,7 @@ def _check_for_gguf_linear(model):
 
 class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+    diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
     torch_dtype = torch.bfloat16
     model_cls = FluxTransformer2DModel
     expected_memory_use_in_gb = 5
@@ -296,6 +297,16 @@ def test_pipeline_inference(self):
         max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
         assert max_diff < 1e-4
 
+    def test_loading_gguf_diffusers_format(self):
+        model = self.model_cls.from_single_file(
+            self.diffusers_ckpt_path,
+            subfolder="transformer",
+            quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+            config="black-forest-labs/FLUX.1-dev",
+        )
+        model.to("cuda")
+        model(**self.get_dummy_inputs())
+
 
 class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
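With the early-return dropped again in PATCH 3/7 (the subset check from PATCH 2/7 now gates conversion upstream), `convert_flux_transformer_checkpoint_to_diffusers` assumes original-format keys and infers model depth straight from them. A small sketch of that inference, over a hypothetical truncated key list:

```py
# Mirrors the num_layers/num_single_layers expressions in the hunk above.
# Taking [-1] of list(set(...)) assumes the set iterates in ascending order,
# which holds in CPython for small non-negative ints; max(...) + 1 would be
# the explicit equivalent. The key list is a hypothetical, truncated example.
checkpoint_keys = [
    "double_blocks.0.img_attn.qkv.weight",
    "double_blocks.1.img_attn.qkv.weight",
    "single_blocks.0.linear1.weight",
]

num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint_keys if "double_blocks." in k))[-1] + 1
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint_keys if "single_blocks." in k))[-1] + 1
assert (num_layers, num_single_layers) == (2, 1)
```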
From a85f597eded1f3fccced2c8a91bb5864e3a7cb7f Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 7 Aug 2025 11:26:39 +0530
Subject: [PATCH 4/7] qwen

---
 src/diffusers/loaders/single_file_model.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index e1c87476aef1..dcb00715d59e 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -153,6 +153,10 @@
         "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "QwenImageTransformer2DModel": {
+        "checkpoint_mapping_fn": lambda x: x,
+        "default_subfolder": "transformer",
+    },
 }

From 510494907fbf610c3b47cc7365c8003b39a247fc Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 7 Aug 2025 14:57:24 +0530
Subject: [PATCH 5/7] up

---
 docs/source/en/quantization/gguf.md | 37 +++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md
index 71321d556818..642b0707fb74 100644
--- a/docs/source/en/quantization/gguf.md
+++ b/docs/source/en/quantization/gguf.md
@@ -77,3 +77,40 @@ Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels
 - Q5_K
 - Q6_K
 
+## Using Diffusers checkpoints
+
+You can convert a Diffusers checkpoint to GGUF and use it to perform inference. Use the Space below to
+run conversion:
+
+<iframe
+	src="..."
+></iframe>
+
+Once it is obtained, you can run inference:
+
+```py
+import torch
+
+from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
+
+ckpt_path = (
+    "https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf"
+)
+transformer = FluxTransformer2DModel.from_single_file(
+    ckpt_path,
+    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+    torch_dtype=torch.bfloat16,
+)
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
+image.save("flux-gguf.png")
+```
\ No newline at end of file
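PATCH 4/7 registers `QwenImageTransformer2DModel` with an identity `checkpoint_mapping_fn` (`lambda x: x`), so a diffusers-format GGUF of that model should load through the same `from_single_file` path. A hedged sketch: the checkpoint URL below is a placeholder, and the config repo id is an assumption, not something the patch pins down:

```py
import torch

from diffusers import GGUFQuantizationConfig, QwenImageTransformer2DModel

# Hypothetical checkpoint path; substitute a real diffusers-format GGUF file.
ckpt_path = "https://huggingface.co/<user>/<repo>/blob/main/qwen-image-Q4_0.gguf"

transformer = QwenImageTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    config="Qwen/Qwen-Image",  # assumed config repo; adjust to the actual model
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
```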
From 51dac6f774022cb0e1039699884d964e1005fc77 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 7 Aug 2025 22:23:11 +0530
Subject: [PATCH 6/7] Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Dhruv Nair
---
 docs/source/en/quantization/gguf.md | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md
index 642b0707fb74..f2dba1771da8 100644
--- a/docs/source/en/quantization/gguf.md
+++ b/docs/source/en/quantization/gguf.md
@@ -77,19 +77,18 @@ Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels
 - Q5_K
 - Q6_K
 
-## Using Diffusers checkpoints
+## Convert to GGUF
 
-You can convert a Diffusers checkpoint to GGUF and use it to perform inference. Use the Space below to
-run conversion:
+Use the Space below to convert a Diffusers checkpoint into the GGUF format for inference.
 
 <iframe
 	src="..."
 ></iframe>
 
-Once it is obtained, you can run inference:
-
 ```py
 import torch
@@ -102,6 +101,8 @@ ckpt_path = (
 transformer = FluxTransformer2DModel.from_single_file(
     ckpt_path,
     quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+    config="black-forest-labs/FLUX.1-dev",
+    subfolder="transformer",
     torch_dtype=torch.bfloat16,
 )
 pipe = FluxPipeline.from_pretrained(

From c22779ab7339f6f1c3352f5238f6d9bc5cf65c90 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 8 Aug 2025 08:07:10 +0530
Subject: [PATCH 7/7] up

---
 docs/source/en/quantization/gguf.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md
index f2dba1771da8..47804c102da2 100644
--- a/docs/source/en/quantization/gguf.md
+++ b/docs/source/en/quantization/gguf.md
@@ -114,4 +114,7 @@ pipe.enable_model_cpu_offload()
 prompt = "A cat holding a sign that says hello world"
 image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
 image.save("flux-gguf.png")
-```
\ No newline at end of file
+```
+
+When using Diffusers format GGUF checkpoints, you must provide the model `config` path. If the
+model config resides in a `subfolder`, specify that as well.
\ No newline at end of file
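For converting locally instead of through the Space, a minimal sketch using the `gguf` package's `GGUFWriter` follows. This is an illustration under stated assumptions, not the converter behind the Space: tensors are written unquantized as F16, the `arch` string is arbitrary metadata, and tensor names simply keep the diffusers layout that `from_single_file` now accepts as-is.

```py
import gguf  # pip install gguf
import torch

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# "flux" is just a label stored as general.architecture metadata here.
writer = gguf.GGUFWriter("flux-transformer-f16.gguf", arch="flux")
for name, param in model.state_dict().items():
    # GGUF tensors are numpy arrays; bfloat16 has no numpy equivalent, so cast to fp16.
    writer.add_tensor(name, param.to(torch.float16).numpy())
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
```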