diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
index 6db6577647..cbe61446ce 100644
--- a/swift/trainers/mixin.py
+++ b/swift/trainers/mixin.py
@@ -57,6 +57,21 @@
 logger = get_logger()
 
 
+def _get_peft_selected_adapters(model: PeftModel) -> List[str]:
+    """Resolve the active adapter name(s) for PeftModel.save_pretrained(selected_adapters=...).
+
+    When users pass a custom adapter_name (e.g. 'vision_only_lora') to Swift.prepare_model,
+    we must pass the same name(s) to save_pretrained; otherwise PEFT raises because
+    the model's adapter names do not include 'default'. See: gh#8336.
+    """
+    active = getattr(model, 'active_adapter', None) or getattr(model, 'active_adapters', None)
+    if active is None:
+        return ['default']
+    if isinstance(active, str):
+        return [active]
+    return list(active)
+
+
 class SwiftMixin:
     FLASH_CKPT_WAIT_TIMEOUT = 1800
 
@@ -284,7 +299,7 @@ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
         if isinstance(_unwrap_model, supported_classes):
             save_kwargs = {'state_dict': state_dict, 'max_shard_size': self.args.max_shard_size}
             if isinstance(_unwrap_model, PeftModel):
-                save_kwargs['selected_adapters'] = ['default']
+                save_kwargs['selected_adapters'] = _get_peft_selected_adapters(_unwrap_model)
             if use_flash_ckpt:
                 _unwrap_model.save_pretrained(
                     output_dir,
@@ -321,7 +336,7 @@ def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
             if self.model.__class__.__name__ != 'SentenceTransformer':
                 save_kwargs = {'state_dict': state_dict, 'max_shard_size': self.args.max_shard_size}
                 if isinstance(self.model, PeftModel):
-                    save_kwargs['selected_adapters'] = ['default']
+                    save_kwargs['selected_adapters'] = _get_peft_selected_adapters(self.model)
                 if use_flash_ckpt:
                     self.model.save_pretrained(
                         output_dir,
diff --git a/tests/tuners/test_peft.py b/tests/tuners/test_peft.py
index 6d9144fafc..6c9c45c44c 100644
--- a/tests/tuners/test_peft.py
+++ b/tests/tuners/test_peft.py
@@ -13,6 +13,7 @@
 from peft.utils import WEIGHTS_NAME
 from torch import nn
 
+from swift.trainers.mixin import _get_peft_selected_adapters
 from swift.tuners import AdaLoraConfig, LoraConfig, LoRAConfig, Swift, get_peft_model
 
 
@@ -157,3 +158,37 @@
         self.assertTrue(model3.base_model.model.bert.encoder.layer[0].attention.self.key.lora_A.default.weight.dtype
                         == torch.float32)
         self.assertTrue(isinstance(model3.peft_config['default'], peft.LoraConfig))
+
+    def test_get_peft_selected_adapters_custom_name(self):
+        """Check selected_adapters respects custom adapter_name (gh#8336)."""
+        model = SbertForSequenceClassification(SbertConfig())
+        lora_config = LoraConfig(target_modules=['query', 'key', 'value'])
+        model = Swift.prepare_model(model, lora_config, adapter_name='vision_only_lora')
+        # SwiftModel wraps base_model; the trainer unwraps to a PeftModel for saving.
+        inner = getattr(model, 'base_model', model)
+        self.assertEqual(_get_peft_selected_adapters(inner), ['vision_only_lora'])
+
+    def test_get_peft_selected_adapters_default(self):
+        """Check selected_adapters falls back to 'default' when no custom name is given."""
+        model = SbertForSequenceClassification(SbertConfig())
+        lora_config = LoraConfig(target_modules=['query', 'key', 'value'])
+        model = Swift.prepare_model(model, lora_config)
+        inner = getattr(model, 'base_model', model)
+        self.assertEqual(_get_peft_selected_adapters(inner), ['default'])
+
+    def test_get_peft_selected_adapters_mock(self):
+        """Check _get_peft_selected_adapters with active_adapter/active_adapters (gh#8336)."""
+        class MockPeft:
+            pass
+        # None -> default
+        m = MockPeft()
+        m.active_adapter = None
+        m.active_adapters = None
+        self.assertEqual(_get_peft_selected_adapters(m), ['default'])
+        # str
+        m.active_adapter = 'my_lora'
+        self.assertEqual(_get_peft_selected_adapters(m), ['my_lora'])
+        # list (active_adapters is the fallback when active_adapter is None)
+        m.active_adapter = None
+        m.active_adapters = ['a', 'b']
+        self.assertEqual(_get_peft_selected_adapters(m), ['a', 'b'])
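
Note for reviewers: a minimal standalone sketch of the resolution order the new helper implements, runnable without Swift or PEFT installed. `FakePeft` and `resolve_selected_adapters` are hypothetical illustration names, not part of the patch; the body mirrors the helper added in swift/trainers/mixin.py above.

    from typing import List


    def resolve_selected_adapters(model) -> List[str]:
        # Prefer the single active_adapter, then the active_adapters list,
        # then fall back to PEFT's conventional 'default' name.
        active = getattr(model, 'active_adapter', None) or getattr(model, 'active_adapters', None)
        if active is None:
            return ['default']
        if isinstance(active, str):
            return [active]
        return list(active)


    class FakePeft:
        """Hypothetical stand-in for a PeftModel with a custom adapter name."""
        active_adapter = 'vision_only_lora'


    # The trainer forwards the resolved names via save_kwargs to
    # PeftModel.save_pretrained(..., selected_adapters=...).
    save_kwargs = {'selected_adapters': resolve_selected_adapters(FakePeft())}
    assert save_kwargs['selected_adapters'] == ['vision_only_lora']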