@@ -1358,14 +1358,30 @@ def load_lora_into_transformer(
13581358            inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name , ** peft_kwargs )
13591359            incompatible_keys  =  set_peft_model_state_dict (transformer , state_dict , adapter_name , ** peft_kwargs )
13601360
1361+             warn_msg  =  "" 
13611362            if  incompatible_keys  is  not None :
1362-                 # check  only for unexpected keys 
1363+                 # Check  only for unexpected keys.  
13631364                unexpected_keys  =  getattr (incompatible_keys , "unexpected_keys" , None )
13641365                if  unexpected_keys :
1365-                     logger .warning (
1366-                         f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " 
1367-                         f" { unexpected_keys }  
1368-                     )
1366+                     lora_unexpected_keys  =  [k  for  k  in  unexpected_keys  if  "lora_"  in  k  and  adapter_name  in  k ]
1367+                     if  lora_unexpected_keys :
1368+                         warn_msg  =  (
1369+                             f"Loading adapter weights from state_dict led to unexpected keys found in the model:" 
1370+                             f" { ', ' .join (lora_unexpected_keys )}  
1371+                         )
1372+ 
1373+                 # Filter missing keys specific to the current adapter. 
1374+                 missing_keys  =  getattr (incompatible_keys , "missing_keys" , None )
1375+                 if  missing_keys :
1376+                     lora_missing_keys  =  [k  for  k  in  missing_keys  if  "lora_"  in  k  and  adapter_name  in  k ]
1377+                     if  lora_missing_keys :
1378+                         warn_msg  +=  (
1379+                             f"Loading adapter weights from state_dict led to missing keys in the model:" 
1380+                             f" { ', ' .join (lora_missing_keys )}  
1381+                         )
1382+ 
1383+             if  warn_msg :
1384+                 logger .warning (warn_msg )
13691385
13701386            # Offload back. 
13711387            if  is_model_cpu_offload :
@@ -1932,14 +1948,30 @@ def load_lora_into_transformer(
19321948            inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name , ** peft_kwargs )
19331949            incompatible_keys  =  set_peft_model_state_dict (transformer , state_dict , adapter_name , ** peft_kwargs )
19341950
1951+             warn_msg  =  "" 
19351952            if  incompatible_keys  is  not None :
1936-                 # check  only for unexpected keys 
1953+                 # Check  only for unexpected keys.  
19371954                unexpected_keys  =  getattr (incompatible_keys , "unexpected_keys" , None )
19381955                if  unexpected_keys :
1939-                     logger .warning (
1940-                         f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " 
1941-                         f" { unexpected_keys }  
1942-                     )
1956+                     lora_unexpected_keys  =  [k  for  k  in  unexpected_keys  if  "lora_"  in  k  and  adapter_name  in  k ]
1957+                     if  lora_unexpected_keys :
1958+                         warn_msg  =  (
1959+                             f"Loading adapter weights from state_dict led to unexpected keys found in the model:" 
1960+                             f" { ', ' .join (lora_unexpected_keys )}  
1961+                         )
1962+ 
1963+                 # Filter missing keys specific to the current adapter. 
1964+                 missing_keys  =  getattr (incompatible_keys , "missing_keys" , None )
1965+                 if  missing_keys :
1966+                     lora_missing_keys  =  [k  for  k  in  missing_keys  if  "lora_"  in  k  and  adapter_name  in  k ]
1967+                     if  lora_missing_keys :
1968+                         warn_msg  +=  (
1969+                             f"Loading adapter weights from state_dict led to missing keys in the model:" 
1970+                             f" { ', ' .join (lora_missing_keys )}  
1971+                         )
1972+ 
1973+             if  warn_msg :
1974+                 logger .warning (warn_msg )
19431975
19441976            # Offload back. 
19451977            if  is_model_cpu_offload :
@@ -2279,14 +2311,30 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
22792311            inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name )
22802312            incompatible_keys  =  set_peft_model_state_dict (transformer , state_dict , adapter_name )
22812313
2314+             warn_msg  =  "" 
22822315            if  incompatible_keys  is  not None :
2283-                 # check  only for unexpected keys 
2316+                 # Check  only for unexpected keys.  
22842317                unexpected_keys  =  getattr (incompatible_keys , "unexpected_keys" , None )
22852318                if  unexpected_keys :
2286-                     logger .warning (
2287-                         f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " 
2288-                         f" { unexpected_keys }  
2289-                     )
2319+                     lora_unexpected_keys  =  [k  for  k  in  unexpected_keys  if  "lora_"  in  k  and  adapter_name  in  k ]
2320+                     if  lora_unexpected_keys :
2321+                         warn_msg  =  (
2322+                             f"Loading adapter weights from state_dict led to unexpected keys found in the model:" 
2323+                             f" { ', ' .join (lora_unexpected_keys )}  
2324+                         )
2325+ 
2326+                 # Filter missing keys specific to the current adapter. 
2327+                 missing_keys  =  getattr (incompatible_keys , "missing_keys" , None )
2328+                 if  missing_keys :
2329+                     lora_missing_keys  =  [k  for  k  in  missing_keys  if  "lora_"  in  k  and  adapter_name  in  k ]
2330+                     if  lora_missing_keys :
2331+                         warn_msg  +=  (
2332+                             f"Loading adapter weights from state_dict led to missing keys in the model:" 
2333+                             f" { ', ' .join (lora_missing_keys )}  
2334+                         )
2335+ 
2336+             if  warn_msg :
2337+                 logger .warning (warn_msg )
22902338
22912339            # Offload back. 
22922340            if  is_model_cpu_offload :
@@ -2717,14 +2765,30 @@ def load_lora_into_transformer(
27172765            inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name , ** peft_kwargs )
27182766            incompatible_keys  =  set_peft_model_state_dict (transformer , state_dict , adapter_name , ** peft_kwargs )
27192767
2768+             warn_msg  =  "" 
27202769            if  incompatible_keys  is  not None :
2721-                 # check  only for unexpected keys 
2770+                 # Check  only for unexpected keys.  
27222771                unexpected_keys  =  getattr (incompatible_keys , "unexpected_keys" , None )
27232772                if  unexpected_keys :
2724-                     logger .warning (
2725-                         f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " 
2726-                         f" { unexpected_keys }  
2727-                     )
2773+                     lora_unexpected_keys  =  [k  for  k  in  unexpected_keys  if  "lora_"  in  k  and  adapter_name  in  k ]
2774+                     if  lora_unexpected_keys :
2775+                         warn_msg  =  (
2776+                             f"Loading adapter weights from state_dict led to unexpected keys found in the model:" 
2777+                             f" { ', ' .join (lora_unexpected_keys )}  
2778+                         )
2779+ 
2780+                 # Filter missing keys specific to the current adapter. 
2781+                 missing_keys  =  getattr (incompatible_keys , "missing_keys" , None )
2782+                 if  missing_keys :
2783+                     lora_missing_keys  =  [k  for  k  in  missing_keys  if  "lora_"  in  k  and  adapter_name  in  k ]
2784+                     if  lora_missing_keys :
2785+                         warn_msg  +=  (
2786+                             f"Loading adapter weights from state_dict led to missing keys in the model:" 
2787+                             f" { ', ' .join (lora_missing_keys )}  
2788+                         )
2789+ 
2790+             if  warn_msg :
2791+                 logger .warning (warn_msg )
27282792
27292793            # Offload back. 
27302794            if  is_model_cpu_offload :
0 commit comments