@@ -1,6 +1,5 @@
 import argparse
 import os
-import pathlib
 from contextlib import nullcontext
 from typing import Any, Dict, Optional, Tuple
 
@@ -11,7 +10,6 @@
 
 from diffusers import Flux2Transformer2DModel
 from diffusers.utils.import_utils import is_accelerate_available
-from transformers import Mistral3ForConditionalGeneration, AutoProcessor
 
 
 """
@@ -22,7 +20,7 @@
 CTX = init_empty_weights if is_accelerate_available() else nullcontext
 
 
-FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
+FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
     # Image and text input projections
     "img_in": "x_embedder",
     "txt_in": "context_embedder",
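For orientation, a rename table like this is normally applied by rewriting every key of the loaded state dict. A minimal sketch of that pattern, assuming the usual substring-replacement approach; the helper names here are illustrative, not taken from this script:

```python
def rename_key(key: str) -> str:
    # Substring replacement against the table, e.g.
    # "img_in.weight" -> "x_embedder.weight".
    for old, new in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items():
        key = key.replace(old, new)
    return key


def apply_rename_dict(state_dict: dict) -> dict:
    return {rename_key(k): v for k, v in state_dict.items()}
```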
@@ -82,7 +80,7 @@ def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None
     # Skip if not a weight
     if ".weight" not in key:
         return
-
+
    # If adaLN_modulation is in the key, swap scale and shift parameters
     # Original implementation is (shift, scale); diffusers implementation is (scale, shift)
     if "adaLN_modulation" in key:
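The swap itself is a chunk-and-reorder along the output dimension. A minimal sketch, assuming the fused modulation parameter stacks shift before scale along dim 0:

```python
import torch


def swap_scale_shift(weight: torch.Tensor) -> torch.Tensor:
    # Upstream layout: [shift; scale]. Diffusers layout: [scale; shift].
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)
```

The same helper works for biases, since both halves live on dim 0 in either case.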
@@ -100,7 +98,7 @@ def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) ->
     # Skip if not a weight, bias, or scale
     if ".weight" not in key and ".bias" not in key and ".scale" not in key:
         return
-
+
     new_prefix = "transformer_blocks"
     if "double_blocks." in key:
         parts = key.split(".")
@@ -111,7 +109,7 @@ def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) ->
 
         if param_type == "scale":
             param_type = "weight"
-
+
         if "qkv" in within_block_name:
             fused_qkv_weight = state_dict.pop(key)
             to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
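For completeness, here is the fused-QKV split as a standalone sketch. It follows the `torch.chunk` call shown above; the example key layout and the helper name are assumptions, not code from this diff:

```python
import torch


def split_fused_qkv(state_dict: dict, key: str, new_prefix: str) -> None:
    # e.g. key = "double_blocks.0.img_attn.qkv.weight": Q, K, V stacked
    # along dim 0, so a [3*D, D] weight (or [3*D] bias) splits evenly.
    fused = state_dict.pop(key)
    q, k, v = torch.chunk(fused, 3, dim=0)
    suffix = key.rsplit(".", 1)[-1]  # "weight" or "bias"
    state_dict[f"{new_prefix}.to_q.{suffix}"] = q
    state_dict[f"{new_prefix}.to_k.{suffix}"] = k
    state_dict[f"{new_prefix}.to_v.{suffix}"] = v
```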
@@ -146,7 +144,7 @@ def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) ->
     # Skip if not a weight, bias, or scale
     if ".weight" not in key and ".bias" not in key and ".scale" not in key:
         return
-
+
     # Mapping:
     # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
     # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
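Applied to a concrete key, the mapping in those comments reduces to extracting the block index and re-prefixing, mirroring the `key.split(".")` pattern used for the double-stream blocks. A hedged sketch covering only the two renames listed (with the `scale` → `weight` normalization omitted):

```python
_SINGLE_BLOCK_RENAMES = {
    "linear1": "attn.to_qkv_mlp_proj",
    "linear2": "attn.to_out",
}


def rename_single_stream_key(key: str) -> str:
    # "single_blocks.3.linear1.weight"
    #   -> "single_transformer_blocks.3.attn.to_qkv_mlp_proj.weight"
    _, block_idx, within_block_name, param_type = key.split(".")
    new_name = _SINGLE_BLOCK_RENAMES[within_block_name]
    return f"single_transformer_blocks.{block_idx}.{new_name}.{param_type}"
```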
@@ -215,7 +213,7 @@ def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
             "axes_dims_rope": (32, 32, 32, 32),
             "rope_theta": 2000,
             "eps": 1e-6,
-        }
+        },
     }
     rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
     special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
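Downstream of this function, conversion scripts in diffusers typically instantiate the model under the `CTX` context defined earlier, so that with accelerate installed the parameters start on the meta device before the converted weights are assigned. A sketch under those assumptions; the `"flux2-dev"` model type, the `"diffusers_config"` key, and the `convert_checkpoint` helper are hypothetical stand-ins for names not visible in this diff:

```python
config, rename_dict, special_keys_remap = get_flux2_transformer_config("flux2-dev")

# Build the model skeleton without allocating real weight memory.
with CTX():
    transformer = Flux2Transformer2DModel(**config["diffusers_config"])

# convert_checkpoint would apply rename_dict plus the special-key handlers.
new_state_dict = convert_checkpoint(original_state_dict)
# assign=True replaces the meta tensors with the loaded ones in place.
transformer.load_state_dict(new_state_dict, strict=True, assign=True)
```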