@@ -1276,3 +1276,74 @@ def remap_single_transformer_blocks_(key, state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
+    # Remove the "diffusion_model." prefix from keys.
+    state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+    converted_state_dict = {}
+
+    def get_num_layers(keys, pattern):
+        layers = set()
+        for key in keys:
+            match = re.search(pattern, key)
+            if match:
+                layers.add(int(match.group(1)))
+        return len(layers)
+
+    def process_block(prefix, index, convert_norm):
+        # Process attention qkv: pop lora_A and lora_B weights.
+        lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
+        lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
+        for attn_key in ["to_q", "to_k", "to_v"]:
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
+        for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
+
+        # Process attention out weights.
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_A.weight"
+        )
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_B.weight"
+        )
+
+        # Process feed-forward weights for layers 1, 2, and 3.
+        for layer in range(1, 4):
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
+            )
+
+        if convert_norm:
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
+            )
+
+    noise_refiner_pattern = r"noise_refiner\.(\d+)\."
+    num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
+    for i in range(num_noise_refiner_layers):
+        process_block("noise_refiner", i, convert_norm=True)
+
+    context_refiner_pattern = r"context_refiner\.(\d+)\."
+    num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
+    for i in range(num_context_refiner_layers):
+        process_block("context_refiner", i, convert_norm=False)
+
+    core_transformer_pattern = r"layers\.(\d+)\."
+    num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
+    for i in range(num_core_transformer_layers):
+        process_block("layers", i, convert_norm=True)
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
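
A minimal usage sketch, not part of the commit: it assumes the module-level `re` and `torch` imports already present in this file, a raw checkpoint whose keys carry the "diffusion_model." prefix, and a placeholder file path.

    from safetensors.torch import load_file

    raw_state_dict = load_file("lumina2_lora.safetensors")  # placeholder path
    converted = _convert_non_diffusers_lumina2_lora_to_diffusers(raw_state_dict)
    # Every converted key ends up namespaced under "transformer." for the diffusers loader.
    assert all(key.startswith("transformer.") for key in converted)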