Commit dba4c1f

Add script to convert Flux 2 transformer to diffusers
1 parent 54c6080 commit dba4c1f

1 file changed: 282 additions, 0 deletions
@@ -0,0 +1,282 @@
import argparse
import os
import pathlib
from contextlib import nullcontext
from typing import Any, Dict, Optional, Tuple

import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers import Flux2Transformer2DModel
from diffusers.utils.import_utils import is_accelerate_available
from transformers import Mistral3ForConditionalGeneration, AutoProcessor


"""
# Transformer
"""


CTX = init_empty_weights if is_accelerate_available() else nullcontext


FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
    # Image and text input projections
    "img_in": "x_embedder",
    "txt_in": "context_embedder",
    # Timestep and guidance embeddings
    "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
    "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
    "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
    "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
    # Modulation parameters
    "double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
    "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
    "single_stream_modulation.lin": "single_stream_modulation.linear",
    # Final output layer
    # "final_layer.adaLN_modulation.1": "norm_out.linear",  # handled separately since the modulation params must be swapped
    "final_layer.linear": "proj_out",
}


FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
    "final_layer.adaLN_modulation.1": "norm_out.linear",
}


FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
    # Fused QKV projections are handled separately since they need to be split into Q, K, V projections
    "img_attn.norm.query_norm": "attn.norm_q",
    "img_attn.norm.key_norm": "attn.norm_k",
    "img_attn.proj": "attn.to_out.0",
    "img_mlp.0": "ff.linear_in",
    "img_mlp.2": "ff.linear_out",
    "txt_attn.norm.query_norm": "attn.norm_added_q",
    "txt_attn.norm.key_norm": "attn.norm_added_k",
    "txt_attn.proj": "attn.to_add_out",
    "txt_mlp.0": "ff_context.linear_in",
    "txt_mlp.2": "ff_context.linear_out",
}


FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
    "linear1": "attn.to_qkv_mlp_proj",
    "linear2": "attn.to_out",
    "norm.query_norm": "attn.norm_q",
    "norm.key_norm": "attn.norm_k",
}


# In the original implementation of AdaLayerNormContinuous (as in SD3), the linear projection output is split into
# (shift, scale), while the diffusers implementation splits it into (scale, shift). Swap the halves of the linear
# projection weight so that the diffusers implementation can be used.
def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None:
    # Skip if not a weight
    if ".weight" not in key:
        return

    # If adaLN_modulation is in the key, swap the scale and shift parameters:
    # the original implementation uses (shift, scale); the diffusers implementation uses (scale, shift)
    if "adaLN_modulation" in key:
        key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
        # Assume all such keys are in the AdaLayerNorm key map
        new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
        new_key = ".".join([new_key_without_param_type, param_type])

        swapped_weight = swap_scale_shift(state_dict.pop(key))
        state_dict[new_key] = swapped_weight
    return


def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
    # Skip if not a weight, bias, or scale
    if ".weight" not in key and ".bias" not in key and ".scale" not in key:
        return

    new_prefix = "transformer_blocks"
    if "double_blocks." in key:
        parts = key.split(".")
        block_idx = parts[1]
        modality_block_name = parts[2]  # img_attn, img_mlp, txt_attn, txt_mlp
        within_block_name = ".".join(parts[2:-1])
        param_type = parts[-1]

        if param_type == "scale":
            param_type = "weight"

        if "qkv" in within_block_name:
            fused_qkv_weight = state_dict.pop(key)
            to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
            if "img" in modality_block_name:
                # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
                new_q_name = "attn.to_q"
                new_k_name = "attn.to_k"
                new_v_name = "attn.to_v"
            elif "txt" in modality_block_name:
                # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
                new_q_name = "attn.add_q_proj"
                new_k_name = "attn.add_k_proj"
                new_v_name = "attn.add_v_proj"
            new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
            new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
            new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
            state_dict[new_q_key] = to_q_weight
            state_dict[new_k_key] = to_k_weight
            state_dict[new_v_key] = to_v_weight
        else:
            new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
            new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])

            param = state_dict.pop(key)
            state_dict[new_key] = param
    return


def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
    # Skip if not a weight, bias, or scale
    if ".weight" not in key and ".bias" not in key and ".scale" not in key:
        return

    # Mapping:
    # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
    # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
    # - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
    # - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
    new_prefix = "single_transformer_blocks"
    if "single_blocks." in key:
        parts = key.split(".")
        block_idx = parts[1]
        within_block_name = ".".join(parts[2:-1])
        param_type = parts[-1]

        if param_type == "scale":
            param_type = "weight"

        new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
        new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])

        param = state_dict.pop(key)
        state_dict[new_key] = param
    return


TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "adaLN_modulation": convert_ada_layer_norm_weights,
    "double_blocks": convert_flux2_double_stream_blocks,
    "single_blocks": convert_flux2_single_stream_blocks,
}


def load_original_checkpoint(
    repo_id: Optional[str], model_file: Optional[str], checkpoint_path: Optional[str] = None
) -> Dict[str, torch.Tensor]:
    if repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=model_file)
    elif checkpoint_path is not None:
        ckpt_path = checkpoint_path
    else:
        raise ValueError("Please provide either `repo_id` or a local `checkpoint_path`")

    if "safetensors" in ckpt_path:
        original_state_dict = safetensors.torch.load_file(ckpt_path)
    else:
        original_state_dict = torch.load(ckpt_path, map_location="cpu")
    return original_state_dict


def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
    if model_type == "test" or model_type == "dummy-flux2":
        config = {
            "model_id": "diffusers-internal-dev/dummy-flux2",
            "diffusers_config": {
                "patch_size": 1,
                "in_channels": 128,
                "num_layers": 8,
                "num_single_layers": 48,
                "attention_head_dim": 128,
                "num_attention_heads": 48,
                "joint_attention_dim": 15360,
                "timestep_guidance_channels": 256,
                "mlp_ratio": 3.0,
                "axes_dims_rope": (32, 32, 32, 32),
                "rope_theta": 2000,
                "eps": 1e-6,
            },
        }
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
    special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
    return config, rename_dict, special_keys_remap


def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str):
    config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type)

    diffusers_config = config["diffusers_config"]

    with init_empty_weights():
        transformer = Flux2Transformer2DModel.from_config(diffusers_config)

    # Handle official code --> diffusers key renaming via the rename dict
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in rename_dict.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict(original_state_dict, key, new_key)

    # Handle any special logic which can't be expressed as a simple 1:1 remapping with the handlers in
    # special_keys_remap
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in special_keys_remap.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--original_state_dict_repo_id", default="diffusers-internal-dev/dummy-flux2", type=str)
    parser.add_argument("--filename", default="flux.safetensors", type=str)
    parser.add_argument("--checkpoint_path", default=None, type=str)

    parser.add_argument("--model_type", type=str, default="test")
    parser.add_argument("--vae", action="store_true")
    parser.add_argument("--transformer", action="store_true")

    parser.add_argument("--dtype", type=str, default="bf16")

    parser.add_argument("--output_path", type=str)

    args = parser.parse_args()
    args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32

    return args


def main(args):
    original_ckpt = load_original_checkpoint(args.original_state_dict_repo_id, args.filename, args.checkpoint_path)

    if args.transformer:
        transformer = convert_flux2_transformer_to_diffusers(original_ckpt, args.model_type)
        transformer.to(args.dtype).save_pretrained(os.path.join(args.output_path, "transformer"))


if __name__ == "__main__":
    args = parse_args()
    main(args)
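
The two non-trivial remappings in this script are the (shift, scale) -> (scale, shift) swap for the final AdaLayerNorm projection and the splitting of the fused QKV weights into separate Q, K, V projections. The toy sketch below illustrates what those handlers do to a single tensor; the dimension is made up for illustration and is not the real Flux 2 hidden size.

import torch

dim = 4  # illustrative only

# AdaLayerNorm projection: the original layout stacks (shift, scale) along dim 0,
# diffusers expects (scale, shift), so swap_scale_shift() swaps the two halves.
adaln_weight = torch.cat([torch.zeros(dim, dim), torch.ones(dim, dim)], dim=0)  # (shift, scale)
shift, scale = adaln_weight.chunk(2, dim=0)
swapped = torch.cat([scale, shift], dim=0)
assert torch.equal(swapped[:dim], scale) and torch.equal(swapped[dim:], shift)

# Fused QKV projection: double_blocks.{N}.img_attn.qkv.weight has shape (3 * dim, dim)
# and is split row-wise into the separate to_q / to_k / to_v weights.
fused_qkv = torch.arange(3 * dim * dim, dtype=torch.float32).reshape(3 * dim, dim)
to_q, to_k, to_v = torch.chunk(fused_qkv, 3, dim=0)
assert to_q.shape == to_k.shape == to_v.shape == (dim, dim)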

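When the script is run with --transformer and --output_path, the converted weights are saved in the standard diffusers layout under <output_path>/transformer, so they can be reloaded like any other diffusers model. A minimal reload sketch, assuming the converted checkpoint exists at that placeholder path:

import torch
from diffusers import Flux2Transformer2DModel

# "<output_path>" stands for whatever was passed as --output_path to the conversion script.
transformer = Flux2Transformer2DModel.from_pretrained(
    "<output_path>/transformer", torch_dtype=torch.bfloat16
)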