diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md
index fbd822af783e..9134ccf86b79 100644
--- a/docs/source/en/api/pipelines/lumina2.md
+++ b/docs/source/en/api/pipelines/lumina2.md
@@ -26,6 +26,56 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
 
+## Using Single File Loading with Lumina Image 2.0
+
+Single file loading for Lumina Image 2.0 is supported for the `Lumina2Transformer2DModel`.
+
+```python
+import torch
+from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline
+
+ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
+transformer = Lumina2Transformer2DModel.from_single_file(
+    ckpt_path, torch_dtype=torch.bfloat16
+)
+
+pipe = Lumina2Text2ImgPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
+)
+pipe.enable_model_cpu_offload()
+image = pipe(
+    "a cat holding a sign that says hello",
+    generator=torch.Generator("cpu").manual_seed(0),
+).images[0]
+image.save("lumina-single-file.png")
+```
+
+## Using GGUF Quantized Checkpoints with Lumina Image 2.0
+
+GGUF-quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with a `GGUFQuantizationConfig`.
+
+```python
+import torch
+from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig
+
+ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
+transformer = Lumina2Transformer2DModel.from_single_file(
+    ckpt_path,
+    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+    torch_dtype=torch.bfloat16,
+)
+
+pipe = Lumina2Text2ImgPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
+)
+pipe.enable_model_cpu_offload()
+image = pipe(
+    "a cat holding a sign that says hello",
+    generator=torch.Generator("cpu").manual_seed(0),
+).images[0]
+image.save("lumina-gguf.png")
+```
+
 ## Lumina2Text2ImgPipeline
 
 [[autodoc]] Lumina2Text2ImgPipeline
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index c7d0fcb3046e..4a5c25676fb1 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -34,6 +34,7 @@
     convert_ldm_vae_checkpoint,
     convert_ltx_transformer_checkpoint_to_diffusers,
     convert_ltx_vae_checkpoint_to_diffusers,
+    convert_lumina2_to_diffusers,
     convert_mochi_transformer_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
@@ -111,6 +112,10 @@
         "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "Lumina2Transformer2DModel": {
+        "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 731b7b87f625..e18ea1374fb4 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -116,6 +116,7 @@
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
+    "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -174,6 +175,7 @@
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
+    "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -657,6 +659,9 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "instruct-pix2pix"
 
+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
+        model_type = "lumina2"
+
     else:
         model_type = "v1"
@@ -2798,3 +2803,75 @@ def calculate_layers(keys, key_prefix):
     converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
 
     return converted_state_dict
+
+
+def convert_lumina2_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+
+    # Original Lumina-Image-2 has an extra norm parameter that is unused
+    # We just remove it here
+    checkpoint.pop("norm_final.weight", None)
+
+    # Comfy checkpoints add this prefix
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    LUMINA_KEY_MAP = {
+        "cap_embedder": "time_caption_embed.caption_embedder",
+        "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
+        "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
+        "attention": "attn",
+        ".out.": ".to_out.0.",
+        "k_norm": "norm_k",
+        "q_norm": "norm_q",
+        "w1": "linear_1",
+        "w2": "linear_2",
+        "w3": "linear_3",
+        "adaLN_modulation.1": "norm1.linear",
+    }
+    ATTENTION_NORM_MAP = {
+        "attention_norm1": "norm1.norm",
+        "attention_norm2": "norm2",
+    }
+    CONTEXT_REFINER_MAP = {
+        "context_refiner.0.attention_norm1": "context_refiner.0.norm1",
+        "context_refiner.0.attention_norm2": "context_refiner.0.norm2",
+        "context_refiner.1.attention_norm1": "context_refiner.1.norm1",
+        "context_refiner.1.attention_norm2": "context_refiner.1.norm2",
+    }
+    FINAL_LAYER_MAP = {
+        "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+        "final_layer.linear": "norm_out.linear_2",
+    }
+
+    def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
+        q_dim = 2304
+        k_dim = v_dim = 768
+
+        to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)
+
+        return {
+            diffusers_key.replace("qkv", "to_q"): to_q,
+            diffusers_key.replace("qkv", "to_k"): to_k,
+            diffusers_key.replace("qkv", "to_v"): to_v,
+        }
+
+    # Iterate over the current keys so that Comfy-prefixed entries renamed above are picked up
+    for key in list(checkpoint.keys()):
+        diffusers_key = key
+        for k, v in CONTEXT_REFINER_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+        for k, v in FINAL_LAYER_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+        for k, v in ATTENTION_NORM_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+        for k, v in LUMINA_KEY_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+
+        if "qkv" in diffusers_key:
+            converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
+        else:
+            converted_state_dict[diffusers_key] = checkpoint.pop(key)
+
+    return converted_state_dict
diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py
index bd0848a2d63f..9a9aaa02d583 100644
--- a/src/diffusers/models/transformers/transformer_lumina2.py
+++ b/src/diffusers/models/transformers/transformer_lumina2.py
@@ -21,6 +21,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import logging
 from ..attention import LuminaFeedForward
 from ..attention_processor import Attention
@@ -333,7 +334,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
         )
 
 
-class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     Lumina2NextDiT: Diffusion model with a Transformer backbone.
diff --git a/tests/single_file/test_lumina2_transformer.py b/tests/single_file/test_lumina2_transformer.py
new file mode 100644
index 000000000000..78e68c4c2df0
--- /dev/null
+++ b/tests/single_file/test_lumina2_transformer.py
@@ -0,0 +1,74 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import torch
+
+from diffusers import (
+    Lumina2Transformer2DModel,
+)
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    require_torch_accelerator,
+    torch_device,
+)
+
+
+enable_full_determinism()
+
+
+@require_torch_accelerator
+class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
+    model_class = Lumina2Transformer2DModel
+    ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
+    alternate_keys_ckpt_paths = [
+        "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
+    ]
+
+    repo_id = "Alpha-VLLM/Lumina-Image-2.0"
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_single_file_components(self):
+        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
+        model_single_file = self.model_class.from_single_file(self.ckpt_path)
+
+        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+        for param_name, param_value in model_single_file.config.items():
+            if param_name in PARAMS_TO_IGNORE:
+                continue
+            assert (
+                model.config[param_name] == param_value
+            ), f"{param_name} differs between single file loading and pretrained loading"
+
+    def test_checkpoint_loading(self):
+        for ckpt_path in self.alternate_keys_ckpt_paths:
+            torch.cuda.empty_cache()
+            model = self.model_class.from_single_file(ckpt_path)
+
+            del model
+            gc.collect()
+            torch.cuda.empty_cache()
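
Reviewer note (not part of this diff): for a quick local sanity check of the key mapping outside the test suite, a minimal sketch along the lines below can be used. It assumes enough memory for two bf16 copies of the transformer and reuses the Comfy-Org checkpoint URL and Hub repo id already referenced in the tests above.

```python
# Minimal sketch: verify that from_single_file (added in this PR) produces the same
# parameter names as loading the transformer from the Hub repo.
import torch
from diffusers import Lumina2Transformer2DModel

ckpt_path = (
    "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/"
    "split_files/diffusion_models/lumina_2_model_bf16.safetensors"
)

single_file = Lumina2Transformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
pretrained = Lumina2Transformer2DModel.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", subfolder="transformer", torch_dtype=torch.bfloat16
)

# If convert_lumina2_to_diffusers mapped every original key, the two state dicts
# must expose identical parameter names.
missing = set(pretrained.state_dict()) - set(single_file.state_dict())
unexpected = set(single_file.state_dict()) - set(pretrained.state_dict())
assert not missing and not unexpected, (sorted(missing), sorted(unexpected))
print("single-file and pretrained parameter names match")
```

If the assertion passes, the converter has mapped every original key to a diffusers parameter name; a numerical comparison of outputs is left to the existing pipeline tests.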