Skip to content

Commit 6d16975

Browse files
committed
add sf support
1 parent ac4f61a commit 6d16975

File tree

2 files changed

+175
-0
lines changed

2 files changed

+175
-0
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
convert_chroma_transformer_checkpoint_to_diffusers,
3535
convert_controlnet_checkpoint,
3636
convert_cosmos_transformer_checkpoint_to_diffusers,
37+
convert_flux2_transformer_checkpoint_to_diffusers,
3738
convert_flux_transformer_checkpoint_to_diffusers,
3839
convert_hidream_transformer_to_diffusers,
3940
convert_hunyuan_video_transformer_to_diffusers,
@@ -162,6 +163,10 @@
162163
"checkpoint_mapping_fn": lambda x: x,
163164
"default_subfolder": "transformer",
164165
},
166+
"Flux2Transformer2DModel": {
167+
"checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
168+
"default_subfolder": "transformer",
169+
},
165170
}
166171

167172

src/diffusers/loaders/single_file_utils.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@
140140
"net.blocks.0.self_attn.q_proj.weight",
141141
"net.pos_embedder.dim_spatial_range",
142142
],
143+
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
143144
}
144145

145146
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -189,6 +190,7 @@
189190
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
190191
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
191192
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
193+
"flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"},
192194
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
193195
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
194196
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
@@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint):
649651
else:
650652
model_type = "animatediff_v3"
651653

654+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]):
655+
model_type = "flux-2-dev"
656+
652657
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
653658
if any(
654659
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
@@ -3647,3 +3652,168 @@ def rename_transformer_blocks_(key: str, state_dict):
36473652
handler_fn_inplace(key, converted_state_dict)
36483653

36493654
return converted_state_dict
3655+
3656+
3657+
def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Convert a Flux2 transformer checkpoint from the original (BFL) key layout to diffusers.

    The conversion runs in two passes over a copy of ``checkpoint``:

    1. A simple substring rename pass driven by ``FLUX2_TRANSFORMER_KEYS_RENAME_DICT``.
    2. A "special keys" pass for structural changes that a 1:1 rename cannot express:
       splitting fused QKV projections in the double-stream blocks, remapping the
       single-stream block internals, and swapping the (shift, scale) ordering of the
       final AdaLayerNorm modulation to diffusers' (scale, shift) convention.

    Args:
        checkpoint: State dict of the original Flux2 transformer. NOTE: it is drained
            (keys are popped) as a side effect, matching the other converters in this file.
        **kwargs: Unused; accepted for signature compatibility with the loader dispatch.

    Returns:
        A new state dict with diffusers-style keys.
    """
    FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
        # Image and text input projections
        "img_in": "x_embedder",
        "txt_in": "context_embedder",
        # Timestep and guidance embeddings
        "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
        "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
        "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
        "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
        # Modulation parameters
        "double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
        "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
        "single_stream_modulation.lin": "single_stream_modulation.linear",
        # Final output layer
        # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
        "final_layer.linear": "proj_out",
    }

    FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
        "final_layer.adaLN_modulation.1": "norm_out.linear",
    }

    FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
        # Handle fused QKV projections separately as we need to break into Q, K, V projections
        "img_attn.norm.query_norm": "attn.norm_q",
        "img_attn.norm.key_norm": "attn.norm_k",
        "img_attn.proj": "attn.to_out.0",
        "img_mlp.0": "ff.linear_in",
        "img_mlp.2": "ff.linear_out",
        "txt_attn.norm.query_norm": "attn.norm_added_q",
        "txt_attn.norm.key_norm": "attn.norm_added_k",
        "txt_attn.proj": "attn.to_add_out",
        "txt_mlp.0": "ff_context.linear_in",
        "txt_mlp.2": "ff_context.linear_out",
    }

    FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
        "linear1": "attn.to_qkv_mlp_proj",
        "linear2": "attn.to_out",
        "norm.query_norm": "attn.norm_q",
        "norm.key_norm": "attn.norm_k",
    }

    def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
        # Skip if not a weight, bias, or scale
        if ".weight" not in key and ".bias" not in key and ".scale" not in key:
            return

        # Mapping:
        #   - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
        #   - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
        #   - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
        #   - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
        new_prefix = "single_transformer_blocks"
        if "single_blocks." in key:
            parts = key.split(".")
            block_idx = parts[1]
            within_block_name = ".".join(parts[2:-1])
            param_type = parts[-1]

            # RMSNorm parameters are stored as `scale` in the original layout.
            if param_type == "scale":
                param_type = "weight"

            new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
            new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])

            state_dict[new_key] = state_dict.pop(key)

        return

    def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None:
        # Skip if not a weight
        if ".weight" not in key:
            return

        # If adaLN_modulation is in the key, swap scale and shift parameters.
        # Original implementation is (shift, scale); diffusers implementation is (scale, shift).
        if "adaLN_modulation" in key:
            key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
            # Assume all such keys are in the AdaLayerNorm key map
            new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
            new_key = ".".join([new_key_without_param_type, param_type])

            state_dict[new_key] = swap_scale_shift(state_dict.pop(key), 0)

        return

    def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
        # Skip if not a weight, bias, or scale
        if ".weight" not in key and ".bias" not in key and ".scale" not in key:
            return

        new_prefix = "transformer_blocks"
        if "double_blocks." in key:
            parts = key.split(".")
            block_idx = parts[1]
            modality_block_name = parts[2]  # img_attn, img_mlp, txt_attn, txt_mlp
            within_block_name = ".".join(parts[2:-1])
            param_type = parts[-1]

            # RMSNorm parameters are stored as `scale` in the original layout.
            if param_type == "scale":
                param_type = "weight"

            if "qkv" in within_block_name:
                # Fused QKV projection: split into separate Q, K, V projections.
                if "img" in modality_block_name:
                    # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
                    new_q_name, new_k_name, new_v_name = "attn.to_q", "attn.to_k", "attn.to_v"
                elif "txt" in modality_block_name:
                    # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
                    new_q_name, new_k_name, new_v_name = (
                        "attn.add_q_proj",
                        "attn.add_k_proj",
                        "attn.add_v_proj",
                    )
                else:
                    # Fail loudly instead of raising a NameError below on an unexpected key.
                    raise ValueError(f"Unexpected fused QKV key in Flux2 double stream block: {key}")

                # Chunk once (works for both weights and biases along dim 0).
                fused_qkv_weight = state_dict.pop(key)
                to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
                state_dict[".".join([new_prefix, block_idx, new_q_name, param_type])] = to_q_weight
                state_dict[".".join([new_prefix, block_idx, new_k_name, param_type])] = to_k_weight
                state_dict[".".join([new_prefix, block_idx, new_v_name, param_type])] = to_v_weight
            else:
                new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
                new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])

                state_dict[new_key] = state_dict.pop(key)
        return

    def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
        state_dict[new_key] = state_dict.pop(old_key)

    TRANSFORMER_SPECIAL_KEYS_REMAP = {
        "adaLN_modulation": convert_ada_layer_norm_weights,
        "double_blocks": convert_flux2_double_stream_blocks,
        "single_blocks": convert_flux2_single_stream_blocks,
    }

    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

    # Handle official code --> diffusers key remapping via the remap dict
    for key in list(converted_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)

        update_state_dict(converted_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
    # special_keys_remap
    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, converted_state_dict)

    return converted_state_dict

0 commit comments

Comments
 (0)