@@ -2392,52 +2392,62 @@ def test_non_existing_model_card(self):
         assert len(model_card) > 1000
 
     @pytest.mark.parametrize("save_embedding_layers", ["auto", True, False])
-    def test_targeting_lora_to_embedding_layer(self, save_embedding_layers):
+    @pytest.mark.parametrize(
+        "peft_config",
+        [
+            (LoraConfig(target_modules=["lin0", "embed_tokens"], init_lora_weights=False)),
+            (LoraConfig(target_modules=r"^embed_tokens", init_lora_weights=False)),
+        ],
+    )
+    def test_save_pretrained_targeting_lora_to_embedding_layer(self, save_embedding_layers, tmp_path, peft_config):
         model = ModelEmbWithEmbeddingUtils()
-        config = LoraConfig(target_modules=["embed_tokens", "lin0"], init_lora_weights=False)
-        model = get_peft_model(model, config)
+        model = get_peft_model(model, peft_config)
 
-        with tempfile.TemporaryDirectory() as tmp_dirname:
-            if save_embedding_layers == "auto":
-                # assert warning
-                msg_start = "Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`."
-                with pytest.warns(UserWarning, match=msg_start):
-                    model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
-            else:
-                model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
-            from safetensors.torch import load_file as safe_load_file
-
-            state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors"))
-            if save_embedding_layers in ["auto", True]:
-                assert "base_model.model.embed_tokens.base_layer.weight" in state_dict
-                assert torch.allclose(
-                    model.base_model.model.embed_tokens.base_layer.weight,
-                    state_dict["base_model.model.embed_tokens.base_layer.weight"],
-                )
-            else:
-                assert "base_model.model.embed_tokens.base_layer.weight" not in state_dict
-            del state_dict
+        if save_embedding_layers == "auto":
+            # assert warning
+            msg_start = "Setting `save_embedding_layers` to `True` as embedding layers found in `target_modules`."
+            with pytest.warns(UserWarning, match=msg_start):
+                model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)
+        else:
+            model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)
+
+        state_dict = safe_load_file(tmp_path / "adapter_model.safetensors")
+        contains_embedding = "base_model.model.embed_tokens.base_layer.weight" in state_dict
+
+        if save_embedding_layers in ["auto", True]:
+            assert contains_embedding
+            assert torch.allclose(
+                model.base_model.model.embed_tokens.base_layer.weight,
+                state_dict["base_model.model.embed_tokens.base_layer.weight"],
+            )
+        else:
+            assert not contains_embedding
 
     @pytest.mark.parametrize("save_embedding_layers", ["auto", True, False])
-    def test_targeting_lora_to_embedding_layer_non_transformers(self, save_embedding_layers):
+    @pytest.mark.parametrize(
+        "peft_config",
+        [
+            (LoraConfig(target_modules=["lin0", "emb"], init_lora_weights=False)),
+            (LoraConfig(target_modules=r"^emb", init_lora_weights=False)),
+        ],
+    )
+    def test_save_pretrained_targeting_lora_to_embedding_layer_non_transformers(
+        self, save_embedding_layers, tmp_path, peft_config
+    ):
         model = ModelEmbConv1D()
-        config = LoraConfig(target_modules=["emb", "lin0"], init_lora_weights=False)
-        model = get_peft_model(model, config)
-
-        with tempfile.TemporaryDirectory() as tmp_dirname:
-            if save_embedding_layers is True:
-                with pytest.warns(
-                    UserWarning,
-                    match=r"Could not identify embedding layer\(s\) because the model is not a 🤗 transformers model\.",
-                ):
-                    model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
-            else:
-                model.save_pretrained(tmp_dirname, save_embedding_layers=save_embedding_layers)
-            from safetensors.torch import load_file as safe_load_file
+        model = get_peft_model(model, peft_config)
+
+        if save_embedding_layers is True:
+            with pytest.warns(
+                UserWarning,
+                match=r"Could not identify embedding layer\(s\) because the model is not a 🤗 transformers model\.",
+            ):
+                model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)
+        else:
+            model.save_pretrained(tmp_path, save_embedding_layers=save_embedding_layers)
 
-            state_dict = safe_load_file(os.path.join(tmp_dirname, "adapter_model.safetensors"))
-            assert "base_model.model.emb.base_layer.weight" not in state_dict
-            del state_dict
+        state_dict = safe_load_file(tmp_path / "adapter_model.safetensors")
+        assert "base_model.model.emb.base_layer.weight" not in state_dict
 
     def test_load_resized_embedding_ignore_mismatched_sizes(self):
         # issue #1605
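Both tests exercise small fixture models that are defined earlier in this test module and not shown in the hunk above. As orientation only, here is a minimal sketch of what a fixture like ModelEmbWithEmbeddingUtils could look like, assuming its job is simply to pair an `embed_tokens` embedding with a `lin0` linear layer and to mimic the `get_input_embeddings`/`get_output_embeddings` helpers of 🤗 transformers models; the layer sizes, forward pass, and method bodies below are illustrative guesses, not the actual definitions.

import torch.nn as nn


class ModelEmbWithEmbeddingUtils(nn.Module):
    # Hypothetical sketch, not the real fixture: a tiny model whose
    # `embed_tokens` and `lin0` submodules match the `target_modules`
    # used in the parametrized LoraConfig above.
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 5)
        self.lin0 = nn.Linear(5, 2)

    def forward(self, x):
        return self.lin0(self.embed_tokens(x))

    # Transformers-style helpers; presumably what separates this fixture
    # from ModelEmbConv1D, which triggers the "not a 🤗 transformers model"
    # warning in the second test.
    def get_input_embeddings(self):
        return self.embed_tokens

    def get_output_embeddings(self):
        return None

ModelEmbConv1D, by contrast, would expose an `emb` embedding without these helpers, which is consistent with the second test expecting no embedding weights in the saved state dict.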