@@ -558,70 +558,62 @@ def assign_remaining_weights(assignments, source):
558558                    ait_sd [target_key ] =  value 
559559
560560        if  any ("guidance_in"  in  k  for  k  in  sds_sd ):
561-             assign_remaining_weights (
562-                 [
563-                     (
564-                         "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight" ,
565-                         "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight" ,
566-                         None ,
567-                     ),
568-                     (
569-                         "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight" ,
570-                         "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight" ,
571-                         None ,
572-                     ),
573-                 ],
561+             _convert_to_ai_toolkit (
574562                sds_sd ,
563+                 ait_sd ,
564+                 "lora_unet_guidance_in_in_layer" ,
565+                 "time_text_embed.guidance_embedder.linear_1" ,
566+             )
567+ 
568+             _convert_to_ai_toolkit (
569+                 sds_sd ,
570+                 ait_sd ,
571+                 "lora_unet_guidance_in_out_layer" ,
572+                 "time_text_embed.guidance_embedder.linear_2" ,
575573            )
576574
577575        if  any ("img_in"  in  k  for  k  in  sds_sd ):
578-             assign_remaining_weights (
579-                 [
580-                     ("x_embedder.{lora_key}.weight" , "lora_unet_img_in.{orig_lora_key}.weight" , None ),
581-                 ],
576+             _convert_to_ai_toolkit (
582577                sds_sd ,
578+                 ait_sd ,
579+                 "lora_unet_img_in" ,
580+                 "x_embedder" ,
583581            )
584582
585583        if  any ("txt_in"  in  k  for  k  in  sds_sd ):
586-             assign_remaining_weights (
587-                 [
588-                     ("context_embedder.{lora_key}.weight" , "lora_unet_txt_in.{orig_lora_key}.weight" , None ),
589-                 ],
584+             _convert_to_ai_toolkit (
590585                sds_sd ,
586+                 ait_sd ,
587+                 "lora_unet_txt_in" ,
588+                 "context_embedder" ,
591589            )
592590
593591        if  any ("time_in"  in  k  for  k  in  sds_sd ):
594-             assign_remaining_weights (
595-                 [
596-                     (
597-                         "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight" ,
598-                         "lora_unet_time_in_in_layer.{orig_lora_key}.weight" ,
599-                         None ,
600-                     ),
601-                     (
602-                         "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight" ,
603-                         "lora_unet_time_in_out_layer.{orig_lora_key}.weight" ,
604-                         None ,
605-                     ),
606-                 ],
592+             _convert_to_ai_toolkit (
607593                sds_sd ,
594+                 ait_sd ,
595+                 "lora_unet_time_in_in_layer" ,
596+                 "time_text_embed.timestep_embedder.linear_1" ,
597+             )
598+             _convert_to_ai_toolkit (
599+                 sds_sd ,
600+                 ait_sd ,
601+                 "lora_unet_time_in_out_layer" ,
602+                 "time_text_embed.timestep_embedder.linear_2" ,
608603            )
609604
610605        if  any ("vector_in"  in  k  for  k  in  sds_sd ):
611-             assign_remaining_weights (
612-                 [
613-                     (
614-                         "time_text_embed.text_embedder.linear_1.{lora_key}.weight" ,
615-                         "lora_unet_vector_in_in_layer.{orig_lora_key}.weight" ,
616-                         None ,
617-                     ),
618-                     (
619-                         "time_text_embed.text_embedder.linear_2.{lora_key}.weight" ,
620-                         "lora_unet_vector_in_out_layer.{orig_lora_key}.weight" ,
621-                         None ,
622-                     ),
623-                 ],
606+             _convert_to_ai_toolkit (
607+                 sds_sd ,
608+                 ait_sd ,
609+                 "lora_unet_vector_in_in_layer" ,
610+                 "time_text_embed.text_embedder.linear_1" ,
611+             )
612+             _convert_to_ai_toolkit (
624613                sds_sd ,
614+                 ait_sd ,
615+                 "lora_unet_vector_in_out_layer" ,
616+                 "time_text_embed.text_embedder.linear_2" ,
625617            )
626618
627619        if  any ("final_layer"  in  k  for  k  in  sds_sd ):
0 commit comments