@@ -516,10 +516,47 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
516
516
f"transformer.single_transformer_blocks.{ i } .norm.linear" ,
517
517
)
518
518
519
+ remaining_keys = list (sds_sd .keys ())
520
+ te_state_dict = {}
521
+ if remaining_keys :
522
+ if not all (k .startswith ("lora_te1" ) for k in remaining_keys ):
523
+ raise ValueError (f"Incompatible keys detected: \n \n { ', ' .join (remaining_keys )} " )
524
+ for key in remaining_keys :
525
+ if not key .endswith ("lora_down.weight" ):
526
+ continue
527
+
528
+ lora_name = key .split ("." )[0 ]
529
+ lora_name_up = f"{ lora_name } .lora_up.weight"
530
+ lora_name_alpha = f"{ lora_name } .alpha"
531
+ diffusers_name = _convert_text_encoder_lora_key (key , lora_name )
532
+
533
+ if lora_name .startswith (("lora_te_" , "lora_te1_" )):
534
+ down_weight = sds_sd .pop (key )
535
+ sd_lora_rank = down_weight .shape [0 ]
536
+ te_state_dict [diffusers_name ] = down_weight
537
+ te_state_dict [diffusers_name .replace (".down." , ".up." )] = sds_sd .pop (lora_name_up )
538
+
539
+ if lora_name_alpha in sds_sd :
540
+ alpha = sds_sd .pop (lora_name_alpha ).item ()
541
+ scale = alpha / sd_lora_rank
542
+
543
+ scale_down = scale
544
+ scale_up = 1.0
545
+ while scale_down * 2 < scale_up :
546
+ scale_down *= 2
547
+ scale_up /= 2
548
+
549
+ te_state_dict [diffusers_name ] *= scale_down
550
+ te_state_dict [diffusers_name .replace (".down." , ".up." )] *= scale_up
551
+
519
552
if len (sds_sd ) > 0 :
520
- logger .warning (f"Unsuppored keys for ai-toolkit: { sds_sd .keys ()} " )
553
+ logger .warning (f"Unsupported keys for ai-toolkit: { sds_sd .keys ()} " )
554
+
555
+ if te_state_dict :
556
+ te_state_dict = {f"text_encoder.{ module_name } " : params for module_name , params in te_state_dict .items ()}
521
557
522
- return ait_sd
558
+ new_state_dict = {** ait_sd , ** te_state_dict }
559
+ return new_state_dict
523
560
524
561
return _convert_sd_scripts_to_ai_toolkit (state_dict )
525
562
0 commit comments