
Commit afbbeff

Add support for loading AuraFlow models from GGUF
https://huggingface.co/city96/AuraFlow-v0.3-gguf
1 parent 91008aa commit afbbeff
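
With this change, a GGUF-quantized AuraFlow transformer can be loaded through from_single_file and passed into AuraFlowPipeline. A minimal sketch, assuming the gguf package is installed; the exact .gguf filename and the fal/AuraFlow-v0.3 base repo are assumptions, not taken from this commit:

import torch
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, GGUFQuantizationConfig

# Hypothetical quant file; pick a real one from city96/AuraFlow-v0.3-gguf
ckpt_path = "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q4_K_M.gguf"

# GGUF weights stay quantized in memory and are dequantized to compute_dtype on the fly
transformer = AuraFlowTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

pipe = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",  # assumed base repo for the remaining components
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
image = pipe("a photo of a corgi astronaut").images[0]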

5 files changed: +103 −3 lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions

@@ -25,6 +25,7 @@
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
+    convert_auraflow_transformer_checkpoint_to_diffusers,
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
@@ -106,6 +107,10 @@
         "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "AuraFlowTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
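For context, from_single_file resolves the model class through this registry and applies the registered mapping function to translate the original state dict. A simplified sketch of the dispatch, not the verbatim library code:

# Simplified sketch (assumed shape of the dispatch, not verbatim)
mapping = SINGLE_FILE_LOADABLE_CLASSES["AuraFlowTransformer2DModel"]
diffusers_state_dict = mapping["checkpoint_mapping_fn"](checkpoint, config=config)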

src/diffusers/loaders/single_file_utils.py

Lines changed: 93 additions & 0 deletions

@@ -2082,6 +2082,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2681,3 +2682,95 @@ def update_state_dict_(state_dict, old_key, new_key):
         handler_fn_inplace(key, checkpoint)
 
     return checkpoint
+
+
+def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, config=None, **kwargs):
+    converted_state_dict = {}
+    state_dict_keys = list(checkpoint.keys())
+
+    # Handle register tokens and positional embeddings
+    converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
+
+    # Handle time step projection
+    converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
+    converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
+    converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
+    converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
+
+    # Handle context embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
+
+    # Calculate the number of layers
+    def calculate_layers(keys, key_prefix):
+        layers = set()
+        for k in keys:
+            if key_prefix in k:
+                layer_num = int(k.split(".")[1])  # get the layer number
+                layers.add(layer_num)
+        return len(layers)
+
+    mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+    single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+    # MMDiT blocks
+    for i in range(mmdit_layers):
+        # Feed-forward
+        path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+        weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for orig_k, diffuser_k in path_mapping.items():
+            for k, v in weight_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.{orig_k}.{k}.weight", None
+                )
+
+        # Norms
+        path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+        for orig_k, diffuser_k in path_mapping.items():
+            converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
+                f"double_layers.{i}.{orig_k}.1.weight", None
+            )
+
+        # Attentions
+        x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+        context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+        for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+            for k, v in attn_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.attn.{k}.weight", None
+                )
+
+    # Single-DiT blocks
+    for i in range(single_dit_layers):
+        # Feed-forward
+        mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for k, v in mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.mlp.{k}.weight", None
+            )
+
+        # Norms
+        converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+            f"single_layers.{i}.modCX.1.weight", None
+        )
+
+        # Attentions
+        x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+        for k, v in x_attn_mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.attn.{k}.weight", None
+            )
+
+    # Final blocks
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
+
+    # Handle the final norm layer
+    norm_weight = checkpoint.pop("modF.1.weight", None)
+    if norm_weight is not None:
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
+    else:
+        converted_state_dict["norm_out.linear.weight"] = None
+
+    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
+    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
+    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
+
+    return converted_state_dict
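One detail worth flagging: the final norm weight passes through swap_scale_shift because the original checkpoint stores the AdaLayerNorm modulation as [shift, scale] while the diffusers module expects [scale, shift]. The helper defined earlier in single_file_utils.py is essentially:

import torch

def swap_scale_shift(weight, dim):
    # reorder [shift, scale] -> [scale, shift] along the first dimension
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)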

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 2 additions & 1 deletion

@@ -20,6 +20,7 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ...utils import is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
@@ -253,7 +254,7 @@ def forward(
         return encoder_hidden_states, hidden_states
 
 
-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
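Mixing in FromOriginalModelMixin is what exposes from_single_file on the model class itself, so the transformer alone can be loaded from an original-layout checkpoint (safetensors or GGUF). A sketch with an assumed checkpoint URL:

from diffusers import AuraFlowTransformer2DModel

# Assumed single-file checkpoint in the original AuraFlow layout
transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/fal/AuraFlow-v0.3/blob/main/aura_flow_0.3.safetensors"
)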

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 2 additions & 1 deletion

@@ -18,6 +18,7 @@
 from transformers import T5Tokenizer, UMT5EncoderModel
 
 from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -111,7 +112,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class AuraFlowPipeline(DiffusionPipeline):
+class AuraFlowPipeline(DiffusionPipeline, FromSingleFileMixin):
     r"""
     Args:
         tokenizer (`T5TokenizerFast`):
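Likewise, FromSingleFileMixin adds from_single_file at the pipeline level, so a complete single-file checkpoint can be loaded in one call. Again a sketch; the URL is an assumption:

import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_single_file(
    "https://huggingface.co/fal/AuraFlow-v0.3/blob/main/aura_flow_0.3.safetensors",
    torch_dtype=torch.float16,
)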

src/diffusers/quantizers/gguf/utils.py

Lines changed: 1 addition & 1 deletion

@@ -450,7 +450,7 @@ def __init__(
     def forward(self, inputs):
        weight = dequantize_gguf_tensor(self.weight)
        weight = weight.to(self.compute_dtype)
-       bias = self.bias.to(self.compute_dtype)
+       bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
 
        output = torch.nn.functional.linear(inputs, weight, bias)
        return output
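The guard matters because some of AuraFlow's GGUF-quantized linear layers apparently carry no bias, so self.bias is None and the old code crashed on None.to(...). torch.nn.functional.linear itself accepts bias=None:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
w = torch.randn(4, 8)
y = F.linear(x, w, None)  # bias=None is valid; output is just x @ w.T
print(y.shape)  # torch.Size([2, 4])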
