@@ -2082,6 +2082,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2681,3 +2682,95 @@ def update_state_dict_(state_dict, old_key, new_key):
             handler_fn_inplace(key, checkpoint)

     return checkpoint
+
+
+def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, config=None, **kwargs):
+    converted_state_dict = {}
+    state_dict_keys = list(checkpoint.keys())
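+    # NOTE: keys are popped with a default of None so that tensors missing from a
+    # particular checkpoint simply map to None instead of raising a KeyError.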
+
+    # Handle register tokens and positional embeddings
+    converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
+
+    # Handle time step projection
+    converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
+    converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
+    converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
+    converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
+
+    # Handle context embedder
+    converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
+
+    # Calculate the number of layers
+    def calculate_layers(keys, key_prefix):
+        layers = set()
+        for k in keys:
+            if key_prefix in k:
+                layer_num = int(k.split(".")[1])  # get the layer number
+                layers.add(layer_num)
+        return len(layers)
+
+    mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+    single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+    # MMDiT blocks
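+    # (In the original AuraFlow naming, the "X" suffix marks the latent/image stream
+    # and the "C" suffix the conditioning/text stream.)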
+    for i in range(mmdit_layers):
+        # Feed-forward
+        path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+        weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for orig_k, diffuser_k in path_mapping.items():
+            for k, v in weight_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.{orig_k}.{k}.weight", None
+                )
+
+        # Norms
+        path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+        for orig_k, diffuser_k in path_mapping.items():
+            converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
+                f"double_layers.{i}.{orig_k}.1.weight", None
+            )
+
+        # Attentions
+        x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+        context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+        for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+            for k, v in attn_mapping.items():
+                converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                    f"double_layers.{i}.attn.{k}.weight", None
+                )
+
+    # Single-DiT blocks
+    for i in range(single_dit_layers):
+        # Feed-forward
+        mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+        for k, v in mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.mlp.{k}.weight", None
+            )
+
+        # Norms
+        converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+            f"single_layers.{i}.modCX.1.weight", None
+        )
+
+        # Attentions
+        x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+        for k, v in x_attn_mapping.items():
+            converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+                f"single_layers.{i}.attn.{k}.weight", None
+            )
+    # Final blocks
+    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
+
+    # Handle the final norm layer
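+    # swap_scale_shift reorders the chunked [shift, scale] halves of the original
+    # weight into the [scale, shift] layout that the diffusers norm expects.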
+    norm_weight = checkpoint.pop("modF.1.weight", None)
+    if norm_weight is not None:
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
+    else:
+        converted_state_dict["norm_out.linear.weight"] = None
+
+    converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
+    converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
+    converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
+
+    return converted_state_dict
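
A minimal usage sketch (not part of the diff above): it assumes an original-format
AuraFlow transformer checkpoint saved as safetensors, with an illustrative file name;
in the library the converter is presumably invoked through the single-file loading
machinery rather than called directly.

    from safetensors.torch import load_file

    # Hypothetical path to an original-format AuraFlow transformer checkpoint.
    checkpoint = load_file("aura_flow_transformer.safetensors")

    converted = convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint)

    # Keys whose source tensor was absent from this checkpoint are mapped to None.
    unmatched = [k for k, v in converted.items() if v is None]
    print(f"converted {len(converted)} keys, {len(unmatched)} without a source tensor")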