@@ -516,10 +516,47 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
516516 f"transformer.single_transformer_blocks.{ i } .norm.linear" ,
517517 )
518518
519+ remaining_keys = list (sds_sd .keys ())
520+ te_state_dict = {}
521+ if remaining_keys :
522+ if not all (k .startswith ("lora_te1" ) for k in remaining_keys ):
523+ raise ValueError (f"Incompatible keys detected: \n \n { ', ' .join (remaining_keys )} " )
524+ for key in remaining_keys :
525+ if not key .endswith ("lora_down.weight" ):
526+ continue
527+
528+ lora_name = key .split ("." )[0 ]
529+ lora_name_up = f"{ lora_name } .lora_up.weight"
530+ lora_name_alpha = f"{ lora_name } .alpha"
531+ diffusers_name = _convert_text_encoder_lora_key (key , lora_name )
532+
533+ if lora_name .startswith (("lora_te_" , "lora_te1_" )):
534+ down_weight = sds_sd .pop (key )
535+ sd_lora_rank = down_weight .shape [0 ]
536+ te_state_dict [diffusers_name ] = down_weight
537+ te_state_dict [diffusers_name .replace (".down." , ".up." )] = sds_sd .pop (lora_name_up )
538+
539+ if lora_name_alpha in sds_sd :
540+ alpha = sds_sd .pop (lora_name_alpha ).item ()
541+ scale = alpha / sd_lora_rank
542+
543+ scale_down = scale
544+ scale_up = 1.0
545+ while scale_down * 2 < scale_up :
546+ scale_down *= 2
547+ scale_up /= 2
548+
549+ te_state_dict [diffusers_name ] *= scale_down
550+ te_state_dict [diffusers_name .replace (".down." , ".up." )] *= scale_up
551+
519552 if len (sds_sd ) > 0 :
520- logger .warning (f"Unsuppored keys for ai-toolkit: { sds_sd .keys ()} " )
553+ logger .warning (f"Unsupported keys for ai-toolkit: { sds_sd .keys ()} " )
554+
555+ if te_state_dict :
556+ te_state_dict = {f"text_encoder.{ module_name } " : params for module_name , params in te_state_dict .items ()}
521557
522- return ait_sd
558+ new_state_dict = {** ait_sd , ** te_state_dict }
559+ return new_state_dict
523560
524561 return _convert_sd_scripts_to_ai_toolkit (state_dict )
525562
0 commit comments