37 changes: 37 additions & 0 deletions docs/source/en/quantization/gguf.md
@@ -77,3 +77,40 @@ Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels
- Q5_K
- Q6_K
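For example, once the required dependency is installed, the optimized kernels can be switched on through the environment before loading any models (a minimal sketch; only the variable name `DIFFUSERS_GGUF_CUDA_KERNELS` is taken from the docs above):

```py
import os

# Opt in to the optimized GGUF CUDA kernels via the environment.
os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "true"
```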

## Using Diffusers checkpoints

You can convert a Diffusers checkpoint to GGUF and run inference with the result. Use the Space below to
run the conversion:

<script
type="module"
src="https://gradio.s3-us-west-2.amazonaws.com/5.41.1/gradio.js"
></script>

<gradio-app src="https://diffusers-internal-dev-diffusers-to-gguf.hf.space"></gradio-app>

Once you have the converted GGUF checkpoint, you can run inference with it:

```py
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

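# URL of a Diffusers-format GGUF checkpoint on the Hub (e.g. one produced by the Space above).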
ckpt_path = (
"https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
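# Plug the GGUF-quantized transformer into the full Flux pipeline.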
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
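# Keep components on CPU and move each to the GPU only while it runs, reducing peak VRAM.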
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
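If you prefer to work from a local file, the same loading path accepts a plain filesystem path. A sketch using `hf_hub_download` (the repo and filename below are taken from the URL in the example above):

```py
from huggingface_hub import hf_hub_download

# Download the GGUF file once and reuse the cached local copy.
local_path = hf_hub_download(
    repo_id="sayakpaul/different-lora-from-civitai",
    filename="flux_dev_diffusers-q4_0.gguf",
)
transformer = FluxTransformer2DModel.from_single_file(
    local_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```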
28 changes: 20 additions & 8 deletions src/diffusers/loaders/single_file_model.py
@@ -153,9 +153,17 @@
         "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "QwenImageTransformer2DModel": {
+        "checkpoint_mapping_fn": lambda x: x,
+        "default_subfolder": "transformer",
+    },
 }
 
 
+def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
+    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
+
+
 def _get_single_file_loadable_mapping_class(cls):
     diffusers_module = importlib.import_module(__name__.split(".")[0])
     for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -381,19 +389,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
         diffusers_model_config.update(model_kwargs)
 
+        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        with ctx():
+            model = cls.from_config(diffusers_model_config)
+
         checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-        diffusers_format_checkpoint = checkpoint_mapping_fn(
-            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
-        )
+
+        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
+            diffusers_format_checkpoint = checkpoint_mapping_fn(
+                config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+            )
+        else:
+            diffusers_format_checkpoint = checkpoint
 
         if not diffusers_format_checkpoint:
             raise SingleFileComponentError(
                 f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
             )
 
-        ctx = init_empty_weights if is_accelerate_available() else nullcontext
-        with ctx():
-            model = cls.from_config(diffusers_model_config)
-
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
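The reordering above makes the empty-weights model available before the checkpoint is mapped, so its expected keys can be compared against the checkpoint. The new `_should_convert_state_dict_to_diffusers` helper drives the branch: if the checkpoint already contains every key the model expects, it is loaded as-is; otherwise the mapping function runs. A standalone sketch with toy key names (not taken from any real checkpoint):

```py
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
    return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))

# The model expects Diffusers-style key names.
model_sd = {"transformer_blocks.0.attn.to_q.weight": 0}

# A Diffusers-format checkpoint already covers the expected keys: no conversion.
print(_should_convert_state_dict_to_diffusers(model_sd, dict(model_sd)))  # False

# An original-format checkpoint uses different names: conversion is needed.
original_sd = {"model.diffusion_model.double_blocks.0.img_attn.qkv.weight": 0}
print(_should_convert_state_dict_to_diffusers(model_sd, original_sd))  # True
```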
1 change: 1 addition & 0 deletions src/diffusers/loaders/single_file_utils.py
@@ -60,6 +60,7 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 CHECKPOINT_KEY_NAMES = {
+    "v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
     "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
     "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
     "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
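Each entry in `CHECKPOINT_KEY_NAMES` is a fingerprint: the presence of a characteristic key in a checkpoint's state dict indicates which architecture produced it. A simplified sketch of the idea (hypothetical helper; the real detection logic in `single_file_utils.py` covers many more cases):

```py
CHECKPOINT_KEY_NAMES = {
    "v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
    "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
}

def infer_model_type(checkpoint: dict) -> str:
    # Return the first model type whose fingerprint key appears in the checkpoint.
    for model_type, key in CHECKPOINT_KEY_NAMES.items():
        if key in checkpoint:
            return model_type
    return "unknown"
```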
11 changes: 11 additions & 0 deletions tests/quantization/gguf/test_gguf.py
@@ -212,6 +212,7 @@ def _check_for_gguf_linear(model):
 
 class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+    diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
     torch_dtype = torch.bfloat16
     model_cls = FluxTransformer2DModel
     expected_memory_use_in_gb = 5
@@ -296,6 +297,16 @@ def test_pipeline_inference(self):
         max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
         assert max_diff < 1e-4
 
+    def test_loading_gguf_diffusers_format(self):
+        model = self.model_cls.from_single_file(
+            self.diffusers_ckpt_path,
+            subfolder="transformer",
+            quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+            config="black-forest-labs/FLUX.1-dev",
+        )
+        model.to("cuda")
+        model(**self.get_dummy_inputs())
+
 
 class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
     ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"