
Commit b142266

support additional_trainable_parameters (#295)
1 parent 98033fa commit b142266


7 files changed (+42 / -12 lines)


docs/source/LLM/命令行参数.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 - `--model_cache_dir`: Default is `None`. If the model is already cached locally and the cache path is not the default ModelScope cache path, you can use this argument to load the model and tokenizer from that cache_dir.
 - `--sft_type`: The fine-tuning method. Default is `'lora'`. Available values: 'lora', 'full', 'longlora', 'qalora'. To use qlora, set `--sft_type lora --quantization_bit 4`.
 - `--freeze_parameters`: When sft_type is set to 'full', freezes the bottommost portion of the model's parameters. The range is 0. ~ 1., default `0.`. This argument offers a compromise between lora and full-parameter fine-tuning.
+- `--additional_trainable_parameters`: A complement to freeze_parameters, allowed only when sft_type is 'full'. Default is `[]`. For example, if you train 50% of the parameters but also want to train the embedding layer, you can set `--freeze_parameters 0.5 --additional_trainable_parameters transformer.wte`; all parameters whose names start with `transformer.wte` will be activated.
 - `--tuner_backend`: The backend for lora and qlora. Default is `'swift'`. Available values: 'swift', 'peft'.
 - `--template_type`: The type of chat template to use. Default is `'AUTO'`, i.e. look up the `template` in `MODEL_MAPPING` based on `model_type`. The available `template_type` values can be found in `TEMPLATE_MAPPING.keys()`.
 - `--output_dir`: The directory where checkpoints are stored. Default is `'output'`. We append `model_type` and a fine-tuning version number to this directory, so users can run multiple comparison experiments on different models without changing the `output_dir` argument. If you do not want these suffixes appended, specify `--add_output_dir_suffix false`.
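
To make the combined effect of `--freeze_parameters` and `--additional_trainable_parameters` concrete, the following is a minimal PyTorch sketch of what `--freeze_parameters 0.5 --additional_trainable_parameters transformer.wte` amounts to. The helper name and the tensor-count freezing rule are illustrative assumptions; swift's actual logic lives in freeze_model_parameters and activate_model_parameters (see the torch_utils.py hunk further down).

# Illustrative sketch only; swift implements this in swift/utils/torch_utils.py.
import torch.nn as nn

def freeze_then_activate(model: nn.Module, freeze_ratio: float,
                         prefixes: list) -> None:
    params = list(model.named_parameters())
    # Freeze the bottom `freeze_ratio` fraction (counted here by tensors;
    # swift's freeze_model_parameters may weight this differently).
    for _, p in params[:int(len(params) * freeze_ratio)]:
        p.requires_grad = False
    # Re-activate every parameter whose name starts with one of the prefixes,
    # e.g. 'transformer.wte' to keep the embedding trainable.
    for name, p in params:
        if any(name.startswith(prefix) for prefix in prefixes):
            p.requires_grad = True

# Equivalent in spirit to:
#   --freeze_parameters 0.5 --additional_trainable_parameters transformer.wte
# freeze_then_activate(model, 0.5, ['transformer.wte'])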

examples/pytorch/llm/scripts/qwen_7b_chat/full_freeze_ddp/sft.sh

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 # Experimental environment: 2 * A100
-# 2 * 78GB GPU memory
+# 2 * 80GB GPU memory
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
 swift sft \
@@ -14,5 +14,6 @@ swift sft \
     --use_flash_attn true \
     --only_save_model true \
     --dataset codefuse-evol-instruction-zh \
-    --freeze_parameters 0.2 \
+    --freeze_parameters 0.25 \
+    --additional_trainable_parameters transformer.wte \
     --preprocess_num_proc 4 \

swift/llm/infer.py

Lines changed: 2 additions & 3 deletions
@@ -15,7 +15,7 @@
                          read_multi_line, seed_everything, show_layers)
 from .utils import (InferArguments, Template, get_additional_saved_files,
                     get_dataset, get_model_tokenizer, get_template, inference,
-                    inference_stream, set_generation_config)
+                    inference_stream, is_lora, set_generation_config)

 logger = get_logger()

@@ -138,8 +138,7 @@ def prepare_model_template(
     logger.info(f'generation_config: {generation_config}')
     set_generation_config(model, generation_config)
     # Preparing LoRA
-    if args.sft_type in ('lora', 'qalora',
-                         'longlora') and args.ckpt_dir is not None:
+    if is_lora(args.sft_type) and args.ckpt_dir is not None:
         model = Swift.from_pretrained(
             model, args.ckpt_dir, inference_mode=True)
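
The explicit tuple check is replaced by the `is_lora` helper imported from `.utils`. Judging only from the condition it replaces, such a helper presumably reduces to the sketch below; the actual definition lives in swift/llm/utils and may differ in detail.

# Sketch inferred from the replaced condition, not the verbatim swift helper.
def is_lora(sft_type: str) -> bool:
    return sft_type in ('lora', 'longlora', 'qalora')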

swift/llm/tuner.py

Lines changed: 5 additions & 1 deletion
@@ -5,7 +5,8 @@
 from swift.trainers import TrainerCallback
 from swift.tuners import (LongLoRAConfig, LongLoRAModelType, LoraConfig,
                           LoRAConfig, NEFTuneConfig, Swift)
-from swift.utils import freeze_model_parameters, get_logger
+from swift.utils import (activate_model_parameters, freeze_model_parameters,
+                         get_logger)
 from .utils import SftArguments, find_all_linear_for_lora, is_lora

 logger = get_logger()
@@ -76,6 +77,9 @@ def prepare_model(model, args: SftArguments):
     elif args.sft_type == 'full':
         if args.freeze_parameters > 0:
             freeze_model_parameters(model, args.freeze_parameters)
+        if len(args.additional_trainable_parameters) > 0:
+            activate_model_parameters(model,
+                                      args.additional_trainable_parameters)
     else:
         raise ValueError(f'args.sft_type: {args.sft_type}')

swift/llm/utils/argument.py

Lines changed: 9 additions & 1 deletion
@@ -39,6 +39,7 @@ class SftArguments:

     sft_type: Literal['lora', 'full', 'longlora', 'qalora'] = 'lora'
     freeze_parameters: float = 0.  # 0 ~ 1
+    additional_trainable_parameters: List[str] = field(default_factory=list)
     tuner_backend: Literal['swift', 'peft'] = 'swift'
     template_type: str = field(
         default='AUTO',
@@ -211,6 +212,9 @@ def __post_init__(self) -> None:
             assert self.freeze_parameters == 0., (
                 'lora does not support `freeze_parameters`, please set `--sft_type full`'
             )
+            assert len(self.additional_trainable_parameters) == 0, (
+                'lora does not support `additional_trainable_parameters`, please set `--sft_type full`'
+            )
             if 'int4' in self.model_type or 'int8' in self.model_type:
                 assert self.quantization_bit == 0, 'int4 and int8 models do not need to be quantized again.'
             if self.learning_rate is None:
@@ -221,12 +225,16 @@ def __post_init__(self) -> None:
             else:
                 self.only_save_model = True
         elif self.sft_type == 'full':
-            assert 0 <= self.freeze_parameters < 1
+            assert 0 <= self.freeze_parameters <= 1
             assert self.quantization_bit == 0, 'Full parameter fine-tuning does not support quantization.'
             assert self.dtype != 'fp16', (
                 "Fine-tuning with dtype=='fp16' can lead to NaN issues. "
                 'Please use fp32+AMP or bf16 to perform full parameter fine-tuning.'
             )
+            if isinstance(self.additional_trainable_parameters, str):
+                self.additional_trainable_parameters = [
+                    self.additional_trainable_parameters
+                ]
             if self.learning_rate is None:
                 self.learning_rate = 2e-5
             if self.only_save_model is None:
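
A single prefix can reach `SftArguments` as a plain string rather than a list (hence the isinstance check), so `__post_init__` normalizes it before the tuner code iterates over it. A standalone illustration of that normalization, mirroring the lines added above rather than the full `SftArguments` class:

# Mirrors the normalization added to SftArguments.__post_init__ for sft_type == 'full'.
additional_trainable_parameters = 'transformer.wte'  # e.g. a single value from the CLI
if isinstance(additional_trainable_parameters, str):
    additional_trainable_parameters = [additional_trainable_parameters]
assert additional_trainable_parameters == ['transformer.wte']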

swift/utils/__init__.py

Lines changed: 5 additions & 5 deletions
@@ -8,10 +8,10 @@
 from .run_utils import get_main
 from .tb_utils import (TB_COLOR, TB_COLOR_SMOOTH, plot_images,
                        read_tensorboard_file, tensorboard_smoothing)
-from .torch_utils import (broadcast_string, freeze_model_parameters,
-                          get_dist_setting, get_model_info, is_ddp_plus_mp,
-                          is_dist, is_local_master, is_master,
-                          is_on_same_device, seed_everything, show_layers,
-                          time_synchronize)
+from .torch_utils import (activate_model_parameters, broadcast_string,
+                          freeze_model_parameters, get_dist_setting,
+                          get_model_info, is_ddp_plus_mp, is_dist,
+                          is_local_master, is_master, is_on_same_device,
+                          seed_everything, show_layers, time_synchronize)
 from .utils import (add_version_to_work_dir, check_json_format, lower_bound,
                     parse_args, read_multi_line, test_time, upper_bound)

swift/utils/torch_utils.py

Lines changed: 17 additions & 0 deletions
@@ -131,6 +131,23 @@ def freeze_model_parameters(model: Module, freeze_parameters: float) -> None:
         p.requires_grad = False


+def activate_model_parameters(
+        model: Module, additional_trainable_parameters: List[str]) -> None:
+    if len(additional_trainable_parameters) == 0:
+        return
+    has_activate = False
+    for n, p in model.named_parameters():
+        for additional_tp in additional_trainable_parameters:
+            if n.startswith(additional_tp):
+                p.requires_grad = True
+                has_activate = True
+    if not has_activate:
+        logger.warning(
+            'len(additional_trainable_parameters) > 0 but no parameters are activated. '
+            f'additional_trainable_parameters: {additional_trainable_parameters}'
+        )
+
+
 def broadcast_string(string: Optional[str], buffer_size: int = 1024) -> str:
     """String broadcasting in case of DDP
     string: main rank: str
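
A short usage sketch of the new helper on a toy module. The toy model is hypothetical, and the example assumes a swift installation that includes this commit, so that `activate_model_parameters` is importable from `swift.utils` as shown in the __init__.py hunk above.

import torch.nn as nn
from swift.utils import activate_model_parameters  # re-exported by this commit

class ToyLM(nn.Module):  # hypothetical stand-in for a real transformer
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(100, 16)
        self.lm_head = nn.Linear(16, 100)

model = ToyLM()
for p in model.parameters():  # pretend freeze_model_parameters froze everything
    p.requires_grad = False

activate_model_parameters(model, ['wte'])  # prefix match against named_parameters()
print([n for n, p in model.named_parameters() if p.requires_grad])
# -> ['wte.weight']

# A prefix that matches nothing leaves gradients unchanged and only logs the warning above.
activate_model_parameters(model, ['transformer.wte'])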
