diff --git a/docs/source/en/api/pipelines/aura_flow.md b/docs/source/en/api/pipelines/aura_flow.md
index c1cf6aa263a7..5d58690505b3 100644
--- a/docs/source/en/api/pipelines/aura_flow.md
+++ b/docs/source/en/api/pipelines/aura_flow.md
@@ -62,6 +62,33 @@ image = pipeline(prompt).images[0]
 image.save("auraflow.png")
 ```
 
+Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) is also supported:
+
+```py
+import torch
+from diffusers import (
+    AuraFlowPipeline,
+    GGUFQuantizationConfig,
+    AuraFlowTransformer2DModel,
+)
+
+transformer = AuraFlowTransformer2DModel.from_single_file(
+    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
+    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+    torch_dtype=torch.bfloat16,
+)
+
+pipeline = AuraFlowPipeline.from_pretrained(
+    "fal/AuraFlow-v0.3",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+)
+
+prompt = "a cute pony in a field of flowers"
+image = pipeline(prompt).images[0]
+image.save("auraflow.png")
+```
+
 ## AuraFlowPipeline
 
 [[autodoc]] AuraFlowPipeline
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 79dc2691b9e4..b65069e60d50 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -25,6 +25,7 @@
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
+    convert_auraflow_transformer_checkpoint_to_diffusers,
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
@@ -106,6 +107,10 @@
         "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "AuraFlowTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
 
 
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 1fa1bdf259cc..cefba48275cf 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -94,6 +94,12 @@
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
     "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
     "animatediff_rgb": "controlnet_cond_embedding.weight",
+    "auraflow": [
+        "double_layers.0.attn.w2q.weight",
+        "double_layers.0.attn.w1q.weight",
+        "cond_seq_linear.weight",
+        "t_embedder.mlp.0.weight",
+    ],
     "flux": [
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -154,6 +160,7 @@
     "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
     "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
     "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+    "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
     "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
     "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
     "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -635,6 +642,9 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
         model_type = "hunyuan-video"
 
+    elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
+        model_type = "auraflow"
+
     elif (
         CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
         and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
@@ -2090,6 +2100,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2689,3 +2700,95 @@ def update_state_dict_(state_dict, old_key, new_key):
         handler_fn_inplace(key, checkpoint)
 
     return checkpoint
+
+
+def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+    state_dict_keys = list(checkpoint.keys())
+
+    # Handle register tokens and positional embeddings
+    converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
+
+    # Handle time step projection
+    converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
+    converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
+    converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
+    converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
+
+    # Handle context embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
+
+    # Calculate the number of layers
+    def calculate_layers(keys, key_prefix):
+        layers = set()
+        for k in keys:
+            if key_prefix in k:
+                layer_num = int(k.split(".")[1])  # get the layer number
+                layers.add(layer_num)
+        return len(layers)
+
+    mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+    single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+    # MMDiT blocks
+    for i in range(mmdit_layers):
+        # Feed-forward
+        path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+        weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for orig_k, diffuser_k in path_mapping.items():
+            for k, v in weight_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.{orig_k}.{k}.weight", None
+                )
+
+        # Norms
+        path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+        for orig_k, diffuser_k in path_mapping.items():
+            converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
+                f"double_layers.{i}.{orig_k}.1.weight", None
+            )
+
+        # Attentions
+        x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+        context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+        for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+            for k, v in attn_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.attn.{k}.weight", None
+                )
+
+    # Single-DiT blocks
+    for i in range(single_dit_layers):
+        # Feed-forward
+        mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for k, v in mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.mlp.{k}.weight", None
+            )
+
+        # Norms
+        converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+            f"single_layers.{i}.modCX.1.weight", None
+        )
+
+        # Attentions
+        x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+        for k, v in x_attn_mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.attn.{k}.weight", None
+            )
+    # Final blocks
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
+
+    # Handle the final norm layer
+    norm_weight = checkpoint.pop("modF.1.weight", None)
+    if norm_weight is not None:
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
+    else:
+        converted_state_dict["norm_out.linear.weight"] = None
+
+    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
+    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
+    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
+
+    return converted_state_dict
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index b3f29e6b6224..b35488a89282 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -20,6 +20,7 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
@@ -253,7 +254,7 @@ def forward(
         return encoder_hidden_states, hidden_states
 
 
-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
 
diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 35e5743fbcf0..9bbb5e4ca266 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -450,7 +450,7 @@ def __init__(
     def forward(self, inputs):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
-        bias = self.bias.to(self.compute_dtype)
+        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
 
         output = torch.nn.functional.linear(inputs, weight, bias)
         return output
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index 8ac4c9915c27..8f768b10e846 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -6,6 +6,8 @@
 import torch.nn as nn
 
 from diffusers import (
+    AuraFlowPipeline,
+    AuraFlowTransformer2DModel,
     FluxPipeline,
     FluxTransformer2DModel,
     GGUFQuantizationConfig,
@@ -54,7 +56,8 @@ def test_gguf_linear_layers(self):
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
                 assert module.weight.dtype == torch.uint8
-                assert module.bias.dtype == torch.float32
+                if module.bias is not None:
+                    assert module.bias.dtype == torch.float32
 
     def test_gguf_memory_usage(self):
         quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
@@ -377,3 +380,79 @@ def test_pipeline_inference(self):
         )
         max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
         assert max_diff < 1e-4
+
+
+class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = AuraFlowTransformer2DModel
+    expected_memory_use_in_gb = 4
+
+    def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                (1, 512, 2048),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
+
+    def test_pipeline_inference(self):
+        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+        transformer = self.model_cls.from_single_file(
+            self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
+        )
+        pipe = AuraFlowPipeline.from_pretrained(
+            "fal/AuraFlow-v0.3", transformer=transformer, torch_dtype=self.torch_dtype
+        )
+        pipe.enable_model_cpu_offload()
+
+        prompt = "a pony holding a sign that says hello"
+        output = pipe(
+            prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+        ).images[0]
+        output_slice = output[:3, :3, :].flatten()
+        expected_slice = np.array(
+            [
+                0.46484375,
+                0.546875,
+                0.64453125,
+                0.48242188,
+                0.53515625,
+                0.59765625,
+                0.47070312,
+                0.5078125,
+                0.5703125,
+                0.42773438,
+                0.50390625,
+                0.5703125,
+                0.47070312,
+                0.515625,
+                0.57421875,
+                0.45898438,
+                0.48632812,
+                0.53515625,
+                0.4453125,
+                0.5078125,
+                0.56640625,
+                0.47851562,
+                0.5234375,
+                0.57421875,
+                0.48632812,
+                0.5234375,
+                0.56640625,
+            ]
+        )
+        max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
+        assert max_diff < 1e-4
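Reviewer note (not part of the patch): a minimal sketch of the detection logic this diff adds to `infer_diffusers_model_type` — the `"auraflow"` model type is selected only when every key listed under `CHECKPOINT_KEY_NAMES["auraflow"]` is present in the checkpoint. The `state_dict_keys` set below is a hypothetical stand-in for the keys of a real AuraFlow single-file or GGUF state dict.

```py
from diffusers.loaders.single_file_utils import CHECKPOINT_KEY_NAMES

# Hypothetical stand-in for the keys of an AuraFlow single-file/GGUF checkpoint.
state_dict_keys = {
    "double_layers.0.attn.w2q.weight",
    "double_layers.0.attn.w1q.weight",
    "cond_seq_linear.weight",
    "t_embedder.mlp.0.weight",
    "register_tokens",
}

# Same predicate as the new branch added in infer_diffusers_model_type:
# every listed key must be present for the checkpoint to be treated as AuraFlow.
is_auraflow = all(key in state_dict_keys for key in CHECKPOINT_KEY_NAMES["auraflow"])
print(is_auraflow)  # True
```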