@@ -1276,3 +1276,74 @@ def remap_single_transformer_blocks_(key, state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
+    # Remove "diffusion_model." prefix from keys.
+    state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+    converted_state_dict = {}
+
+    def get_num_layers(keys, pattern):
+        layers = set()
+        for key in keys:
+            match = re.search(pattern, key)
+            if match:
+                layers.add(int(match.group(1)))
+        return len(layers)
+
+    def process_block(prefix, index, convert_norm):
+        # Process attention qkv: pop lora_A and lora_B weights.
+        lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
+        lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
+        for attn_key in ["to_q", "to_k", "to_v"]:
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
+        for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
+
+        # Process attention out weights.
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_A.weight"
+        )
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_B.weight"
+        )
+
+        # Process feed-forward weights for layers 1, 2, and 3.
+        for layer in range(1, 4):
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
+            )
+
+        if convert_norm:
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
+            )
+
+    noise_refiner_pattern = r"noise_refiner\.(\d+)\."
+    num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
+    for i in range(num_noise_refiner_layers):
+        process_block("noise_refiner", i, convert_norm=True)
+
+    context_refiner_pattern = r"context_refiner\.(\d+)\."
+    num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
+    for i in range(num_context_refiner_layers):
+        process_block("context_refiner", i, convert_norm=False)
+
+    core_transformer_pattern = r"layers\.(\d+)\."
+    num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
+    for i in range(num_core_transformer_layers):
+        process_block("layers", i, convert_norm=True)
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
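
As a sanity check (not part of the diff above), a minimal smoke test can exercise the converter on a synthetic single-block LoRA. The 2304/768/768 qkv split and the key layout follow the code above; the rank, shapes, and the assumption that the converter is in scope (e.g. imported from lora_conversion_utils) are illustrative only.

import torch

rank = 4
prefix = "diffusion_model.layers.0."
# Fake rank-4 LoRA weights for one transformer block, using the original (non-diffusers) key names.
fake = {
    prefix + "attention.qkv.lora_A.weight": torch.zeros(rank, 2304),
    prefix + "attention.qkv.lora_B.weight": torch.zeros(2304 + 768 + 768, rank),
    prefix + "attention.out.lora_A.weight": torch.zeros(rank, 2304),
    prefix + "attention.out.lora_B.weight": torch.zeros(2304, rank),
    prefix + "adaLN_modulation.1.lora_A.weight": torch.zeros(rank, 2304),
    prefix + "adaLN_modulation.1.lora_B.weight": torch.zeros(2304, rank),
}
for i in range(1, 4):
    fake[prefix + f"feed_forward.w{i}.lora_A.weight"] = torch.zeros(rank, 2304)
    fake[prefix + f"feed_forward.w{i}.lora_B.weight"] = torch.zeros(2304, rank)

converted = _convert_non_diffusers_lumina2_lora_to_diffusers(fake)
# Every remapped key should now carry the "transformer." prefix on the diffusers module path.
assert all(k.startswith("transformer.layers.0.") for k in converted)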