@@ -2137,9 +2137,18 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())

+    variant = "chroma" if "distilled_guidance_layer.in_proj.weight" in checkpoint else "flux"
+
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+        if variant == "chroma" and "distilled_guidance_layer." in k:
+            new_key = k
+            if k.startswith("distilled_guidance_layer.norms"):
+                new_key = k.replace(".scale", ".weight")
+            elif k.startswith("distilled_guidance_layer.layer"):
+                new_key = k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
+            converted_state_dict[new_key] = checkpoint.pop(k)

     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1  # noqa: C401
     num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1  # noqa: C401
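The hunk above routes Chroma's `distilled_guidance_layer.*` tensors straight into the converted state dict, renaming norm `.scale` parameters to `.weight` and the `in_layer`/`out_layer` projections to `linear_1`/`linear_2`. A minimal, self-contained sketch of that renaming; the example key names and indices below are hypothetical, not taken from a real checkpoint:

# Illustrative only: hypothetical Chroma-style keys and the names the remapping produces.
examples = {
    "distilled_guidance_layer.norms.0.scale": "distilled_guidance_layer.norms.0.weight",
    "distilled_guidance_layer.layers.0.in_layer.weight": "distilled_guidance_layer.layers.0.linear_1.weight",
    "distilled_guidance_layer.layers.0.out_layer.bias": "distilled_guidance_layer.layers.0.linear_2.bias",
}
for old_key, expected in examples.items():
    new_key = old_key
    if old_key.startswith("distilled_guidance_layer.norms"):
        new_key = old_key.replace(".scale", ".weight")
    elif old_key.startswith("distilled_guidance_layer.layer"):
        new_key = old_key.replace("in_layer", "linear_1").replace("out_layer", "linear_2")
    assert new_key == expected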
@@ -2153,40 +2162,49 @@ def swap_scale_shift(weight):
         new_weight = torch.cat([scale, shift], dim=0)
         return new_weight

-    ## time_text_embed.timestep_embedder <- time_in
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
-        "time_in.in_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
-        "time_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
-
-    ## time_text_embed.text_embedder <- vector_in
-    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
-    converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
-    converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
-        "vector_in.out_layer.weight"
-    )
-    converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
-
-    # guidance
-    has_guidance = any("guidance" in k for k in checkpoint)
-    if has_guidance:
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
-            "guidance_in.in_layer.weight"
+    if variant == "flux":
+        ## time_text_embed.timestep_embedder <- time_in
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
+            "time_in.in_layer.weight"
         )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
-            "guidance_in.in_layer.bias"
+        converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop(
+            "time_in.in_layer.bias"
         )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
-            "guidance_in.out_layer.weight"
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
+            "time_in.out_layer.weight"
         )
-        converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
-            "guidance_in.out_layer.bias"
+        converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop(
+            "time_in.out_layer.bias"
         )

+        ## time_text_embed.text_embedder <- vector_in
+        converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop(
+            "vector_in.in_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
+        converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
+            "vector_in.out_layer.weight"
+        )
+        converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop(
+            "vector_in.out_layer.bias"
+        )
+
+        # guidance
+        has_guidance = any("guidance" in k for k in checkpoint)
+        if has_guidance:
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
+                "guidance_in.in_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
+                "guidance_in.in_layer.bias"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
+                "guidance_in.out_layer.weight"
+            )
+            converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
+                "guidance_in.out_layer.bias"
+            )
+
     # context_embedder
     converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
     converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
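The `swap_scale_shift` helper, whose last two lines appear at the top of this hunk, reorders AdaLN parameters into the [scale, shift] layout before they are assigned further down (see the `final_layer.adaLN_modulation.1.*` keys in the last hunk). A small runnable sketch of that reordering; the helper's first line is not visible in the hunk and is assumed here to be a `chunk(2, dim=0)` split:

import torch

def swap_scale_shift(weight):
    # Assumed full body; only the cat/return lines are visible in the hunk above.
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

w = torch.cat([torch.zeros(2), torch.ones(2)])  # [shift, shift, scale, scale]
assert torch.equal(swap_scale_shift(w), torch.cat([torch.ones(2), torch.zeros(2)]))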
@@ -2199,20 +2217,21 @@ def swap_scale_shift(weight):
     for i in range(num_layers):
         block_prefix = f"transformer_blocks.{i}."
         # norms.
-        ## norm1
-        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.img_mod.lin.bias"
-        )
-        ## norm1_context
-        converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
-            f"double_blocks.{i}.txt_mod.lin.bias"
-        )
+        if variant == "flux":
+            ## norm1
+            converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.img_mod.lin.bias"
+            )
+            ## norm1_context
+            converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
+                f"double_blocks.{i}.txt_mod.lin.bias"
+            )
         # Q, K, V
         sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
         context_q, context_k, context_v = torch.chunk(
@@ -2285,13 +2304,15 @@ def swap_scale_shift(weight):
     # single transformer blocks
     for i in range(num_single_layers):
         block_prefix = f"single_transformer_blocks.{i}."
-        # norm.linear  <- single_blocks.0.modulation.lin
-        converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.weight"
-        )
-        converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
-            f"single_blocks.{i}.modulation.lin.bias"
-        )
+
+        if variant == "flux":
+            # norm.linear  <- single_blocks.0.modulation.lin
+            converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.weight"
+            )
+            converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
+                f"single_blocks.{i}.modulation.lin.bias"
+            )
         # Q, K, V, mlp
         mlp_hidden_dim = int(inner_dim * mlp_ratio)
         split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
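The unchanged context lines in these two hunks set up how the fused attention and MLP weights get separated: `torch.chunk` splits the double-block `qkv` weights into three equal parts, and `split_size` is used to split the single-block `linear1` weights into Q, K, V plus the MLP input projection. A shape-only sketch with illustrative dimensions, not Flux's real sizes:

import torch

# Illustrative sizes only; the real model derives inner_dim from its attention config.
inner_dim, mlp_ratio = 8, 4.0
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

# Double blocks: fused qkv weight -> three equal chunks (stand-in for img_attn.qkv.weight).
qkv_weight = torch.randn(3 * inner_dim, inner_dim)
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
assert q.shape == k.shape == v.shape == (inner_dim, inner_dim)

# Single blocks: linear1 fuses q, k, v and the MLP input projection (stand-in for linear1.weight).
linear1_weight = torch.randn(sum(split_size), inner_dim)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
assert mlp.shape == (mlp_hidden_dim, inner_dim)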
@@ -2320,12 +2341,14 @@ def swap_scale_shift(weight):

     converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
     converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
-    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
-    )
-    converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
-        checkpoint.pop("final_layer.adaLN_modulation.1.bias")
-    )
+
+    if variant == "flux":
+        converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.weight")
+        )
+        converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+            checkpoint.pop("final_layer.adaLN_modulation.1.bias")
+        )

     return converted_state_dict

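Taken together, the patch keeps the converter's existing Flux path intact and adds a Chroma path that carries the `distilled_guidance_layer.*` tensors through while skipping the embedder and modulation conversions now guarded by `variant == "flux"`. A minimal usage sketch, not part of the PR; the checkpoint file name, the import path, the default `FluxTransformer2DModel` config, and `strict=False` loading are all assumptions for illustration:

# Minimal usage sketch (illustrative). The checkpoint path is hypothetical; the import
# path assumes the converter lives in diffusers.loaders.single_file_utils, and strict=False
# is used because a default model config may not match every checkpoint variant exactly.
import safetensors.torch
from diffusers import FluxTransformer2DModel
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers

checkpoint = safetensors.torch.load_file("flux1-dev.safetensors")
converted = convert_flux_transformer_checkpoint_to_diffusers(checkpoint)

model = FluxTransformer2DModel()  # config must match the checkpoint being converted
model.load_state_dict(converted, strict=False)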