diff --git a/sentence_transformers/peft_mixin.py b/sentence_transformers/peft_mixin.py index da7f2494b..fdd4cd44d 100644 --- a/sentence_transformers/peft_mixin.py +++ b/sentence_transformers/peft_mixin.py @@ -141,3 +141,24 @@ def get_adapter_state_dict(self, *args, **kwargs) -> dict: https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict """ ... # Implementation handled by the wrapper + + def merge_adapter(self, *args, **kwargs): + """ + Merges the adapter into the base model and unloads it by calling the underlying auto_model's + `merge_and_unload` method. It also sets the `_hf_peft_config_loaded` flag to False after merging. + + Returns: + The merged auto_model. + + Raises: + ValueError: If the current model does not support merging using the `merge_and_unload` method. + """ + self.check_peft_compatible_model() + if not (hasattr(self[0].auto_model, "merge_and_unload") and callable(self[0].auto_model.merge_and_unload)): + raise ValueError( + "The current model does not support merging using merge_and_unload. " + "Please ensure that you have added a PEFT adapter using model.add_adapter(...) before merging." + ) + merged_model = self[0].auto_model.merge_and_unload(*args, **kwargs) + merged_model._hf_peft_config_loaded = False + return merged_model diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 1336ba9e5..bda42cab5 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -16,6 +16,7 @@ import pytest import torch from huggingface_hub import CommitInfo, HfApi, RepoUrl +from peft import LoraConfig, TaskType from torch import nn from transformers.utils import is_peft_available @@ -441,7 +442,7 @@ def transformers_init(*args, **kwargs): @pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test PEFT support.") def test_load_checkpoint_with_peft_and_lora() -> None: - from peft import LoraConfig, PeftModel, TaskType + from peft import PeftModel peft_config = LoraConfig( target_modules=["query", "key", "value"], @@ -748,7 +749,7 @@ def test_multiple_adapters() -> None: text = "Hello, World!" model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") vec_initial = model.encode(text) - from peft import LoraConfig, TaskType, get_model_status + from peft import get_model_status # Adding a fresh adapter peft_config = LoraConfig( @@ -834,3 +835,56 @@ def test_clip(): tokenized = model.tokenize(["This is my text sentence"]) assert "input_ids" in tokenized assert tokenized["input_ids"].shape == (1, 5) + + +@pytest.mark.skipif(not pytest.importorskip("peft"), reason="PEFT must be available to test merge_adapter.") +def test_merge_adapter_success(): + # Load a model from the hub; note that by default the model may not be seen as PEFT-compatible, + # so we monkey-patch it to force compatibility. + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-lora") + + # Bind the merge_adapter method from PeftAdapterMixin onto the instance, if not already available. + if not hasattr(model, "merge_adapter"): + from sentence_transformers.peft_mixin import PeftAdapterMixin + + model.merge_adapter = PeftAdapterMixin.merge_adapter.__get__(model, type(model)) + + # Monkey-patch has_peft_compatible_model to always return True. + model.has_peft_compatible_model = lambda: True + + # If the underlying auto_model does not have merge_and_unload, + # add a dummy implementation that sets _hf_peft_config_loaded to False. + if not (hasattr(model[0].auto_model, "merge_and_unload") and callable(model[0].auto_model.merge_and_unload)): + + def dummy_merge_and_unload(*args, **kwargs): + model[0].auto_model._hf_peft_config_loaded = False + return model[0].auto_model + + model[0].auto_model.merge_and_unload = dummy_merge_and_unload + # Initialize the flag as True so we can verify the change. + model[0].auto_model._hf_peft_config_loaded = True + + # Call merge_adapter; it should call dummy_merge_and_unload. + merged_auto_model = model.merge_adapter() + # After merging, the _hf_peft_config_loaded flag should be False. + assert not merged_auto_model._hf_peft_config_loaded, "Expected _hf_peft_config_loaded to be False after merging." + + +@pytest.mark.skipif(not pytest.importorskip("peft"), reason="PEFT must be available to test merge_adapter.") +def test_merge_adapter_incompatible(): + # Load a model from the hub. + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-lora") + + # Bind the merge_adapter method from PeftAdapterMixin if not already bound. + if not hasattr(model, "merge_adapter"): + from sentence_transformers.peft_mixin import PeftAdapterMixin + + model.merge_adapter = PeftAdapterMixin.merge_adapter.__get__(model, type(model)) + + # Force the compatibility check to pass. + model.has_peft_compatible_model = lambda: True + # Simulate an incompatible auto_model by removing (or nullifying) merge_and_unload. + setattr(model[0].auto_model, "merge_and_unload", None) + + with pytest.raises(ValueError, match="The current model does not support merging using merge_and_unload"): + model.merge_adapter()