
Commit 4c6a2c5

Fix offload (#288)
1 parent 1463f74 commit 4c6a2c5

6 files changed: +22 additions, −17 deletions


swift/aigc/animatediff.py

Lines changed: 1 addition & 1 deletion
@@ -525,7 +525,7 @@ def state_dict(self,
                    **kwargs):
         state_dict = self.state_dict_origin()
         return {
-            key: value
+            key.replace('base_layer.', ''): value
             for key, value in state_dict.items()
             if 'lora' not in key
         }
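The change above matters because peft-style LoRA wrapping typically nests the original weight under a `base_layer.` prefix; the fixed `state_dict` both drops LoRA-specific tensors and restores the original key names, so the saved weights match the unwrapped module layout. A minimal sketch of the same filtering, with purely illustrative key names:

# Illustrative sketch of the key cleanup above; the keys are hypothetical.
raw_state_dict = {
    'attn.to_q.base_layer.weight': 'W_q',    # original weight, wrapped by LoRA
    'attn.to_q.lora_A.default.weight': 'A',  # LoRA-specific tensor, dropped
    'attn.to_q.lora_B.default.weight': 'B',  # LoRA-specific tensor, dropped
}

cleaned = {
    key.replace('base_layer.', ''): value
    for key, value in raw_state_dict.items()
    if 'lora' not in key
}
assert cleaned == {'attn.to_q.weight': 'W_q'}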

swift/llm/sft.py

Lines changed: 14 additions & 11 deletions
@@ -9,22 +9,19 @@
 from modelscope import BitsAndBytesConfig, GenerationConfig

 from swift.trainers import (IntervalStrategy, Seq2SeqTrainer,
-                            Seq2SeqTrainingArguments)
-from swift.tuners import (LongLoRAConfig, LongLoRAModelType, LoraConfig,
-                          LoRAConfig, NEFTuneConfig, Swift)
+                            Seq2SeqTrainingArguments, TrainerCallback)
 from swift.utils import (check_json_format, compute_acc_metrics,
-                         compute_nlg_metrics, freeze_model_parameters,
-                         get_dist_setting, get_logger, get_main,
-                         get_model_info, is_ddp_plus_mp, is_dist, is_master,
-                         plot_images, preprocess_logits_for_metrics,
+                         compute_nlg_metrics, get_dist_setting, get_logger,
+                         get_main, get_model_info, is_ddp_plus_mp, is_dist,
+                         is_master, plot_images, preprocess_logits_for_metrics,
                          seed_everything, show_layers)
 from .tuner import prepare_model
 from .utils import (LazyLLMDataset, SftArguments, Template,
                     add_self_cognition_dataset, data_collate_fn, dataset_map,
-                    find_all_linear_for_lora, get_additional_saved_files,
-                    get_dataset, get_model_tokenizer, get_template,
-                    get_time_info, print_example, set_generation_config,
-                    sort_by_max_length, stat_dataset)
+                    get_additional_saved_files, get_dataset,
+                    get_model_tokenizer, get_template, get_time_info,
+                    print_example, set_generation_config, sort_by_max_length,
+                    stat_dataset)

 logger = get_logger()
@@ -234,13 +231,19 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     if args.check_model_is_latest is False:
         trainer_kwargs['check_model'] = False

+    class TrainerAdapterCallback(TrainerCallback):
+
+        def on_train_begin(*args, **kwargs):
+            model.set_active_adapters(model.adapters.keys(), offload='meta')
+
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
         data_collator=data_collator,
         train_dataset=train_dataset,
         eval_dataset=val_dataset,
         tokenizer=tokenizer,
+        callbacks=[TrainerAdapterCallback()],
         **trainer_kwargs)
     trainer.sft_args = args
     if is_master():
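The new callback re-activates every registered adapter as soon as training begins, passing offload='meta' so inactive adapter copies are parked on the meta device rather than occupying GPU memory. In the diff it is declared inside llm_sft and closes over model; a rough, self-contained variant that passes the model in explicitly could look like this (the constructor and class name are assumptions, not part of the commit):

from transformers import TrainerCallback

class AdapterActivationCallback(TrainerCallback):
    """Sketch: activate all Swift adapters at train start, offloading inactive ones."""

    def __init__(self, model):
        # A Swift-wrapped model exposing .adapters and .set_active_adapters is assumed.
        self.model = model

    def on_train_begin(self, args, state, control, **kwargs):
        self.model.set_active_adapters(self.model.adapters.keys(), offload='meta')

An instance would then be handed to the trainer via callbacks=[AdapterActivationCallback(model)], mirroring the callbacks=[TrainerAdapterCallback()] line added above.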

swift/trainers/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from .dpo_trainers import DPOTrainer
 from .trainers import Seq2SeqTrainer, Trainer
 from .utils import EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, \
-    IntervalStrategy, SchedulerType, ShardedDDPOption
+    IntervalStrategy, SchedulerType, ShardedDDPOption, TrainerCallback
 else:
     _import_structure = {
         'arguments': ['Seq2SeqTrainingArguments', 'TrainingArguments'],
@@ -17,7 +17,7 @@
         'utils': [
             'EvaluationStrategy', 'FSDPOption', 'HPSearchBackend',
             'HubStrategy', 'IntervalStrategy', 'SchedulerType',
-            'ShardedDDPOption'
+            'ShardedDDPOption', 'TrainerCallback'
         ]
     }

swift/trainers/utils.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from typing import List, Union

 from torch.nn import Module
+from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import (EvaluationStrategy, FSDPOption,
                                         HPSearchBackend, HubStrategy,
                                         IntervalStrategy, SchedulerType)
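Together, the two changes above re-export TrainerCallback through swift.trainers, which is what lets sft.py subclass it from the swift namespace instead of importing transformers directly. A small usage sketch (the callback name and body are hypothetical; only the import path comes from the commit):

from swift.trainers import TrainerCallback

class LoggingCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):
        print('training started')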

swift/tuners/base.py

Lines changed: 3 additions & 3 deletions
@@ -464,20 +464,20 @@ def set_active_adapters(self,
         adapter_names = set(adapter_names)
         for adapter_name in (adapter_names & set(self.adapters.keys())):
-            self.activate_adapter(adapter_name)
+            self.activate_adapter(adapter_name, offload)

         for adapter_name in (set(self.adapters.keys()) - adapter_names):
             self.deactivate_adapter(adapter_name, offload)

-    def activate_adapter(self, adapter_name):
+    def activate_adapter(self, adapter_name, offload=None):
         if adapter_name not in self.adapters:
             logger.warning(
                 f'{adapter_name} not in adapters: {self.adapters.keys()}')
             return

         from .mapping import SWIFT_MAPPING
         SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
-            .activate_adapter(self.base_model, adapter_name, True)
+            .activate_adapter(self.base_model, adapter_name, True, offload)

     def deactivate_adapter(self, adapter_name, offload=None):
         if adapter_name not in self.adapters:
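This hunk is the core of the offload fix: set_active_adapters previously dropped its offload argument when re-activating adapters, so only the deactivation path honored it; now both paths forward it, and activate_adapter gains a matching offload=None parameter that is passed through to the concrete tuner. A simplified, self-contained sketch of that propagation (the classes are illustrative stand-ins, not the real Swift tuners):

class FakeAdapter:
    """Stand-in adapter that records where its weights live."""

    def __init__(self, name):
        self.name = name
        self.device = 'cpu'

    def activate(self, offload=None):
        self.device = 'cuda'       # active adapters are brought back for training

    def deactivate(self, offload=None):
        if offload is not None:
            self.device = offload  # e.g. park inactive weights on 'meta' or 'cpu'


class FakeModel:
    def __init__(self, adapters):
        self.adapters = {a.name: a for a in adapters}

    def set_active_adapters(self, adapter_names, offload=None):
        adapter_names = set(adapter_names)
        for name in adapter_names & set(self.adapters):
            self.adapters[name].activate(offload)    # offload is now forwarded here too
        for name in set(self.adapters) - adapter_names:
            self.adapters[name].deactivate(offload)


model = FakeModel([FakeAdapter('lora'), FakeAdapter('longlora')])
model.set_active_adapters(['lora'], offload='meta')
assert model.adapters['lora'].device == 'cuda'
assert model.adapters['longlora'].device == 'meta'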

swift/tuners/lora_layers.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 from peft.tuners.lora import Conv2d as _Conv2d
 from peft.tuners.lora import Embedding as _Embedding
 from peft.tuners.lora import Linear as _Linear
+from peft.tuners.lora import LoraLayer
 from peft.tuners.lora import LoraModel as _LoraModel
 from peft.tuners.lora.tp_layer import LoraParallelLinear as _LoraParallelLinear
 from peft.tuners.tuners_utils import BaseTunerLayer
