Skip to content

Commit 75db766

Browse files
fix scedit bug (#290)
1 parent 4c6a2c5 commit 75db766

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

swift/llm/sft.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,9 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
234234
class TrainerAdapterCallback(TrainerCallback):
235235

236236
def on_train_begin(*args, **kwargs):
237-
model.set_active_adapters(model.adapters.keys(), offload='meta')
237+
if hasattr(model, 'set_active_adapters'):
238+
model.set_active_adapters(
239+
model.adapters.keys(), offload='meta')
238240

239241
trainer = Seq2SeqTrainer(
240242
model=model,

swift/tuners/base.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -48,10 +48,12 @@ def __init__(self,
4848
extra_state_keys.extend(model.extra_state_keys)
4949
model = model.base_model
5050

51+
new_adapters = []
5152
if isinstance(config, SwiftConfig):
5253
if DEFAULT_ADAPTER not in self.adapters:
5354
self.adapters[DEFAULT_ADAPTER] = self._prepare_model(
5455
model, config, DEFAULT_ADAPTER)
56+
new_adapters.append(DEFAULT_ADAPTER)
5557
else:
5658
logger.warn(
5759
f'Adater {DEFAULT_ADAPTER} has been patched, skip.')
@@ -61,6 +63,7 @@ def __init__(self,
6163
if adapter_name not in self.adapters:
6264
self.adapters[adapter_name] = self._prepare_model(
6365
model, _config, adapter_name)
66+
new_adapters.append(adapter_name)
6467
else:
6568
logger.warn(
6669
f'Adater {adapter_name} has been patched, skip.')
@@ -76,14 +79,15 @@ def forward(self, *args, **kwargs):
7679
signature(self.base_model.forward).parameters.values())
7780
forward.__signature__ = Signature(_parameters)
7881
self.forward = MethodType(forward, self)
79-
for adapter_name in self.adapters:
82+
for adapter_name in new_adapters:
8083
self.activate_adapter(adapter_name)
8184

8285
if inference_mode:
8386
self.eval()
8487
else:
85-
for output in self.adapters.values():
86-
output.mark_trainable_callback(model)
88+
for key, output in self.adapters.items():
89+
if key in new_adapters:
90+
output.mark_trainable_callback(model)
8791
if self.extra_state_keys:
8892
for n, p in model.named_parameters():
8993
if any(

swift/tuners/scetuning/scetuning.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -205,6 +205,7 @@ def _forward_decoder_mode(self, *args, **kwargs):
205205
tuner_op = SCETunerModule(
206206
name=config.tuner_op,
207207
adapter_name=adapter_name,
208+
module_key=str(tuner_id),
208209
dim=dims[tuner_id],
209210
tuner_length=int(dims[tuner_id] * config.down_ratio))
210211
setattr(t_module, f'scetuner_{adapter_name}', tuner_op)
@@ -244,6 +245,7 @@ class SCETunerModule(nn.Module, ActivationMixin):
244245
def __init__(self,
245246
name,
246247
adapter_name,
248+
module_key,
247249
dim,
248250
tuner_length,
249251
tuner_type=None,
@@ -252,7 +254,7 @@ def __init__(self,
252254
zero_init_last=True,
253255
use_bias=True):
254256
super(SCETunerModule, self).__init__()
255-
super(nn.Module, self).__init__('')
257+
super(nn.Module, self).__init__(module_key)
256258
self.name = name
257259
self.adapter_name = adapter_name
258260
self.dim = dim
@@ -271,6 +273,7 @@ def forward(self, x, x_shortcut=None, use_shortcut=True, **kwargs):
271273
if not self.is_activated(self.adapter_name):
272274
return x
273275
if self.name == 'SCEAdapter':
276+
self.tuner_op.to(x.device)
274277
out = self.tuner_op(x)
275278
else:
276279
raise Exception(f'Error tuner op {self.name}')

0 commit comments

Comments
 (0)