@@ -1358,14 +1358,30 @@ def load_lora_into_transformer(
13581358 inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name , ** peft_kwargs )
13591359 incompatible_keys = set_peft_model_state_dict (transformer , state_dict , adapter_name , ** peft_kwargs )
13601360
1361+ warn_msg = ""
13611362 if incompatible_keys is not None :
1362- # check only for unexpected keys
1363+ # Check only for unexpected keys.
13631364 unexpected_keys = getattr (incompatible_keys , "unexpected_keys" , None )
13641365 if unexpected_keys :
1365- logger .warning (
1366- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1367- f" { unexpected_keys } . "
1368- )
1366+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k ]
1367+ if lora_unexpected_keys :
1368+ warn_msg = (
1369+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
1370+ f" { ', ' .join (lora_unexpected_keys )} . "
1371+ )
1372+
1373+ # Filter missing keys specific to the current adapter.
1374+ missing_keys = getattr (incompatible_keys , "missing_keys" , None )
1375+ if missing_keys :
1376+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k ]
1377+ if lora_missing_keys :
1378+ warn_msg += (
1379+ f"Loading adapter weights from state_dict led to missing keys in the model:"
1380+ f" { ', ' .join (lora_missing_keys )} ."
1381+ )
1382+
1383+ if warn_msg :
1384+ logger .warning (warn_msg )
13691385
13701386 # Offload back.
13711387 if is_model_cpu_offload :
@@ -1932,14 +1948,30 @@ def load_lora_into_transformer(
19321948 inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name , ** peft_kwargs )
19331949 incompatible_keys = set_peft_model_state_dict (transformer , state_dict , adapter_name , ** peft_kwargs )
19341950
1951+ warn_msg = ""
19351952 if incompatible_keys is not None :
1936- # check only for unexpected keys
1953+ # Check only for unexpected keys.
19371954 unexpected_keys = getattr (incompatible_keys , "unexpected_keys" , None )
19381955 if unexpected_keys :
1939- logger .warning (
1940- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1941- f" { unexpected_keys } . "
1942- )
1956+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k ]
1957+ if lora_unexpected_keys :
1958+ warn_msg = (
1959+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
1960+ f" { ', ' .join (lora_unexpected_keys )} . "
1961+ )
1962+
1963+ # Filter missing keys specific to the current adapter.
1964+ missing_keys = getattr (incompatible_keys , "missing_keys" , None )
1965+ if missing_keys :
1966+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k ]
1967+ if lora_missing_keys :
1968+ warn_msg += (
1969+ f"Loading adapter weights from state_dict led to missing keys in the model:"
1970+ f" { ', ' .join (lora_missing_keys )} ."
1971+ )
1972+
1973+ if warn_msg :
1974+ logger .warning (warn_msg )
19431975
19441976 # Offload back.
19451977 if is_model_cpu_offload :
@@ -2279,14 +2311,30 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada
22792311 inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name )
22802312 incompatible_keys = set_peft_model_state_dict (transformer , state_dict , adapter_name )
22812313
2314+ warn_msg = ""
22822315 if incompatible_keys is not None :
2283- # check only for unexpected keys
2316+ # Check only for unexpected keys.
22842317 unexpected_keys = getattr (incompatible_keys , "unexpected_keys" , None )
22852318 if unexpected_keys :
2286- logger .warning (
2287- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
2288- f" { unexpected_keys } . "
2289- )
2319+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k ]
2320+ if lora_unexpected_keys :
2321+ warn_msg = (
2322+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
2323+ f" { ', ' .join (lora_unexpected_keys )} . "
2324+ )
2325+
2326+ # Filter missing keys specific to the current adapter.
2327+ missing_keys = getattr (incompatible_keys , "missing_keys" , None )
2328+ if missing_keys :
2329+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k ]
2330+ if lora_missing_keys :
2331+ warn_msg += (
2332+ f"Loading adapter weights from state_dict led to missing keys in the model:"
2333+ f" { ', ' .join (lora_missing_keys )} ."
2334+ )
2335+
2336+ if warn_msg :
2337+ logger .warning (warn_msg )
22902338
22912339 # Offload back.
22922340 if is_model_cpu_offload :
@@ -2717,14 +2765,30 @@ def load_lora_into_transformer(
27172765 inject_adapter_in_model (lora_config , transformer , adapter_name = adapter_name , ** peft_kwargs )
27182766 incompatible_keys = set_peft_model_state_dict (transformer , state_dict , adapter_name , ** peft_kwargs )
27192767
2768+ warn_msg = ""
27202769 if incompatible_keys is not None :
2721- # check only for unexpected keys
2770+ # Check only for unexpected keys.
27222771 unexpected_keys = getattr (incompatible_keys , "unexpected_keys" , None )
27232772 if unexpected_keys :
2724- logger .warning (
2725- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
2726- f" { unexpected_keys } . "
2727- )
2773+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k ]
2774+ if lora_unexpected_keys :
2775+ warn_msg = (
2776+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
2777+ f" { ', ' .join (lora_unexpected_keys )} . "
2778+ )
2779+
2780+ # Filter missing keys specific to the current adapter.
2781+ missing_keys = getattr (incompatible_keys , "missing_keys" , None )
2782+ if missing_keys :
2783+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k ]
2784+ if lora_missing_keys :
2785+ warn_msg += (
2786+ f"Loading adapter weights from state_dict led to missing keys in the model:"
2787+ f" { ', ' .join (lora_missing_keys )} ."
2788+ )
2789+
2790+ if warn_msg :
2791+ logger .warning (warn_msg )
27282792
27292793 # Offload back.
27302794 if is_model_cpu_offload :
0 commit comments