
Commit 34ec333

Support lora regex (#1375)
1 parent 05719cc commit 34ec333

7 files changed: +29 -23 lines changed


docs/source/LLM/命令行参数.md

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@
 - `--bnb_4bit_use_double_quant`: Whether to enable double quantization during 4-bit quantization, default is `True`. Has no effect when quantization_bit is 0.
 - `--bnb_4bit_quant_storage`: Default is `None`. The storage type of the quantized parameters. Has no effect when `quantization_bit` is set to 0.
 - `--lora_target_modules`: Specifies the lora modules, default is `['DEFAULT']`. If lora_target_modules is passed `'DEFAULT'` or `'AUTO'`, `lora_target_modules` is looked up in `MODEL_MAPPING` based on `model_type` (qkv by default). If `'ALL'` is passed, all Linear layers (excluding the head) are used as lora modules. If `'EMBEDDING'` is passed, the Embedding layer is used as a lora module. If memory allows, setting 'ALL' is recommended. You can also set `['ALL', 'EMBEDDING']` to use all Linear and Embedding layers as lora modules. This parameter only takes effect when `sft_type` is 'lora'.
+- `--lora_target_regex`: A regex expression specifying the lora modules, of type `Optional[str]`. Default is `None`; if a value is passed, lora_target_modules does not take effect.
 - `--lora_rank`: Default is `8`. Only takes effect when `sft_type` is 'lora'.
 - `--lora_alpha`: Default is `32`. Only takes effect when `sft_type` is 'lora'.
 - `--lora_dropout_p`: Default is `0.05`. Only takes effect when `sft_type` is 'lora'.

docs/source_en/LLM/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@
 - `--bnb_4bit_use_double_quant`: Whether to enable double quantization for 4bit quantization, default is `True`. Has no effect when quantization_bit is 0.
 - `--bnb_4bit_quant_storage`: Default value `None`. This sets the storage type to pack the quantized 4-bit params. Has no effect when quantization_bit is 0.
 - `--lora_target_modules`: Specify lora modules, default is `['DEFAULT']`. If lora_target_modules is passed `'DEFAULT'` or `'AUTO'`, look up `lora_target_modules` in `MODEL_MAPPING` based on `model_type` (default specifies qkv). If passed `'ALL'`, all Linear layers (excluding head) will be specified as lora modules. If passed `'EMBEDDING'`, the Embedding layer will be specified as a lora module. If memory allows, setting to 'ALL' is recommended. You can also set `['ALL', 'EMBEDDING']` to specify all Linear and embedding layers as lora modules. This parameter only takes effect when `sft_type` is 'lora'.
+- `--lora_target_regex`: The lora target regex, of type `Optional[str]`. Default is `None`. If this argument is specified, `lora_target_modules` has no effect.
 - `--lora_rank`: Default is `8`. Only takes effect when `sft_type` is 'lora'.
 - `--lora_alpha`: Default is `32`. Only takes effect when `sft_type` is 'lora'.
 - `--lora_dropout_p`: Default is `0.05`, only takes effect when `sft_type` is 'lora'.
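For readers who want to sanity-check a pattern before launching training, here is a minimal, self-contained sketch of how a candidate `--lora_target_regex` selects layers. The module paths below are hypothetical, and it assumes the regex is matched against the full dotted module name (as PEFT does when `target_modules` is a string):

```python
import re

# Hypothetical module names, roughly as they appear in a decoder-only transformer.
module_names = [
    'model.layers.0.self_attn.q_proj',
    'model.layers.0.self_attn.k_proj',
    'model.layers.0.self_attn.v_proj',
    'model.layers.0.mlp.down_proj',
    'lm_head',
]

# Candidate value for --lora_target_regex: attention projections only.
lora_target_regex = r'.*self_attn\.(q_proj|k_proj|v_proj)$'

# A string-valued target_modules is full-matched against each module name,
# so re.fullmatch mirrors which layers would receive LoRA adapters.
matched = [name for name in module_names if re.fullmatch(lora_target_regex, name)]
print(matched)  # the three attention projections; the MLP and lm_head are skipped
```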

examples/pytorch/llm/custom.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def _preprocess_stsb(dataset: HfDataset) -> HfDataset:
     return HfDataset.from_dict({'query': query, 'response': response})
 
 
-register_dataset(CustomDatasetName.stsb_en, 'huangjintao/stsb', None, _preprocess_stsb, get_dataset_from_repo)
+register_dataset(CustomDatasetName.stsb_en, 'swift/stsb', None, _preprocess_stsb, get_dataset_from_repo)
 
 
 if __name__ == '__main__':
     # The shell script can be found at `examples/pytorch/llm/scripts/custom`.

swift/llm/tuner.py

Lines changed: 4 additions & 1 deletion
@@ -103,9 +103,12 @@ def prepare_model(model, args: SftArguments):
         handle_modules_to_save(model, args)
         if args.init_lora_weights and args.init_lora_weights.lower() in ('true', 'false'):
             args.init_lora_weights = args.init_lora_weights.lower() in ('true', 'True')
+        if args.lora_target_regex:
+            logger.info(f'Value of lora_target_modules: {args.lora_target_modules} will have no effect '
+                        f'because lora_target_regex value: {args.lora_target_regex} exists.')
         lora_kwargs = {
             'r': args.lora_rank,
-            'target_modules': args.lora_target_modules,
+            'target_modules': args.lora_target_regex or args.lora_target_modules,
             'lora_alpha': args.lora_alpha,
             'lora_dropout': args.lora_dropout_p,
             'bias': args.lora_bias_trainable,
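The change forwards the regex, when present, straight into the `target_modules` entry of `lora_kwargs`. This relies on the fact that PEFT's `LoraConfig` accepts either a list of module-name suffixes or a single regex string for `target_modules`. A hedged sketch of the equivalent PEFT configuration (the diff does not show whether swift constructs this exact class from `lora_kwargs`):

```python
from peft import LoraConfig

# Roughly what lora_kwargs expresses when --lora_target_regex is set:
# a regex string in target_modules selects modules by full name instead of
# suffix-matching a list such as ['q_proj', 'k_proj', 'v_proj'].
lora_config = LoraConfig(
    r=8,                # --lora_rank
    lora_alpha=32,      # --lora_alpha
    lora_dropout=0.05,  # --lora_dropout_p
    target_modules=r'.*self_attn\.(q_proj|k_proj|v_proj)$',
)
```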

swift/llm/utils/argument.py

Lines changed: 1 addition & 0 deletions
@@ -492,6 +492,7 @@ class SftArguments(ArgumentsBase):
     bnb_4bit_quant_storage: Optional[str] = None
     # lora
     lora_target_modules: List[str] = field(default_factory=lambda: ['DEFAULT'])
+    lora_target_regex: Optional[str] = None
     lora_rank: int = 8
     lora_alpha: int = 32
     lora_dropout_p: float = 0.05
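Because the new field defaults to `None`, the existing `lora_target_modules` default remains the fallback. A small illustrative sketch of the resulting precedence (the helper below is hypothetical, not part of swift):

```python
from typing import List, Optional, Union


def resolve_target_modules(lora_target_regex: Optional[str],
                           lora_target_modules: List[str]) -> Union[str, List[str]]:
    # Mirrors the precedence in swift/llm/tuner.py: a non-empty regex string
    # wins; otherwise the module-name list is used.
    return lora_target_regex or lora_target_modules


print(resolve_target_modules(None, ['DEFAULT']))            # ['DEFAULT']
print(resolve_target_modules(r'.*\.q_proj$', ['DEFAULT']))  # .*\.q_proj$
```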

swift/llm/utils/dataset.py

Lines changed: 4 additions & 4 deletions
@@ -719,7 +719,7 @@ def _preprocess_aishell1_dataset(dataset: HfDataset) -> HfDataset:
 
 
 def _preprocess_video_chatgpt(dataset: HfDataset) -> HfDataset:
-    url = 'https://modelscope.cn/datasets/huangjintao/VideoChatGPT/resolve/master/videos.zip'
+    url = 'https://modelscope.cn/datasets/swift/VideoChatGPT/resolve/master/videos.zip'
     local_dir = MediaCache.download(url, 'video_chatgpt')
     local_dir = os.path.join(local_dir, 'Test_Videos')
     # only `.mp4`
@@ -742,7 +742,7 @@ def _preprocess_video_chatgpt(dataset: HfDataset) -> HfDataset:
 
 register_dataset(
     DatasetName.video_chatgpt,
-    'huangjintao/VideoChatGPT', ['Generic', 'Temporal', 'Consistency'],
+    'swift/VideoChatGPT', ['Generic', 'Temporal', 'Consistency'],
     _preprocess_video_chatgpt,
     get_dataset_from_repo,
     split=['test'],
@@ -1832,7 +1832,7 @@ def preprocess(row):
 
 register_dataset(
     DatasetName.sharegpt,
-    'huangjintao/sharegpt', ['common-zh', 'computer-zh', 'unknow-zh', 'common-en', 'computer-en'],
+    'swift/sharegpt', ['common-zh', 'computer-zh', 'unknow-zh', 'common-en', 'computer-en'],
     preprocess_sharegpt,
     get_dataset_from_repo,
     tags=['chat', 'general', 'multi-round'])
@@ -1977,7 +1977,7 @@ def _repair_conversations_agent_instruct(s: str) -> List[Dict[str, Any]]:
 
 register_dataset(
     DatasetName.agent_instruct_all_en,
-    'huangjintao/AgentInstruct_copy', ['alfworld', 'db', 'kg', 'mind2web', 'os', 'webshop'],
+    'swift/AgentInstruct_copy', ['alfworld', 'db', 'kg', 'mind2web', 'os', 'webshop'],
     ConversationsPreprocessor('human', 'gpt', repair_conversations=_repair_conversations_agent_instruct),
     get_dataset_from_repo,
     tags=['chat', 'agent', 'multi-round'])

swift/llm/utils/model.py

Lines changed: 17 additions & 17 deletions
@@ -2442,7 +2442,7 @@ def _output_device_map_hook(module, input, output):
     hf_model_id='mistralai/Mistral-7B-v0.1')
 @register_model(
     ModelType.codestral_22b,
-    'huangjintao/Codestral-22B-v0.1',
+    'swift/Codestral-22B-v0.1',
     LoRATM.llama,
     TemplateType.default_generation,
     requires=['transformers>=4.34'],
@@ -4033,7 +4033,7 @@ def get_model_tokenizer_deepseek_vl(model_dir: str,
 
 @register_model(
     ModelType.llama3_70b_instruct_awq,
-    'huangjintao/Meta-Llama-3-70B-Instruct-AWQ',
+    'swift/Meta-Llama-3-70B-Instruct-AWQ',
     LoRATM.llama,
     TemplateType.llama3,
     requires=['autoawq'],
@@ -4044,7 +4044,7 @@ def get_model_tokenizer_deepseek_vl(model_dir: str,
     hf_model_id='study-hjt/Meta-Llama-3-70B-Instruct-AWQ')
 @register_model(
     ModelType.llama3_70b_instruct_int8,
-    'huangjintao/Meta-Llama-3-70b-Instruct-GPTQ-Int8',
+    'swift/Meta-Llama-3-70b-Instruct-GPTQ-Int8',
     LoRATM.llama,
     TemplateType.llama3,
     requires=['auto_gptq'],
@@ -4055,7 +4055,7 @@ def get_model_tokenizer_deepseek_vl(model_dir: str,
     hf_model_id='study-hjt/Meta-Llama-3-70B-Instruct-GPTQ-Int8')
 @register_model(
     ModelType.llama3_70b_instruct_int4,
-    'huangjintao/Meta-Llama-3-70B-Instruct-GPTQ-Int4',
+    'swift/Meta-Llama-3-70B-Instruct-GPTQ-Int4',
     LoRATM.llama,
     TemplateType.llama3,
     requires=['auto_gptq'],
@@ -4066,7 +4066,7 @@ def get_model_tokenizer_deepseek_vl(model_dir: str,
     hf_model_id='study-hjt/Meta-Llama-3-70B-Instruct-GPTQ-Int4')
 @register_model(
     ModelType.llama3_8b_instruct_awq,
-    'huangjintao/Meta-Llama-3-8B-Instruct-AWQ',
+    'swift/Meta-Llama-3-8B-Instruct-AWQ',
     LoRATM.llama,
     TemplateType.llama3,
     requires=['autoawq'],
@@ -4077,7 +4077,7 @@ def get_model_tokenizer_deepseek_vl(model_dir: str,
     hf_model_id='study-hjt/Meta-Llama-3-8B-Instruct-AWQ')
 @register_model(
     ModelType.llama3_8b_instruct_int8,
-    'huangjintao/Meta-Llama-3-8B-Instruct-GPTQ-Int8',
+    'swift/Meta-Llama-3-8B-Instruct-GPTQ-Int8',
     LoRATM.llama,
     TemplateType.llama3,
     requires=['auto_gptq'],
@@ -4088,7 +4088,7 @@ def get_model_tokenizer_deepseek_vl(model_dir: str,
     hf_model_id='study-hjt/Meta-Llama-3-8B-Instruct-GPTQ-Int8')
 @register_model(
     ModelType.llama3_8b_instruct_int4,
-    'huangjintao/Meta-Llama-3-8B-Instruct-GPTQ-Int4',
+    'swift/Meta-Llama-3-8B-Instruct-GPTQ-Int4',
     LoRATM.llama,
     TemplateType.llama3,
     requires=['auto_gptq'],
@@ -5106,7 +5106,7 @@ def get_model_tokenizer_llava_hf(model_dir: str, *args, **kwargs):
 
 @register_model(
     ModelType.llava1_5_13b_instruct,
-    'huangjintao/llava-1.5-13b-hf',
+    'swift/llava-1.5-13b-hf',
     LoRATM.llama,
     TemplateType.llava1_5,
     eos_token='</s>',
@@ -5123,7 +5123,7 @@ def get_model_tokenizer_llava_hf(model_dir: str, *args, **kwargs):
     hf_model_id='llava-hf/llava-1.5-13b-hf')
 @register_model(
     ModelType.llava1_5_7b_instruct,
-    'huangjintao/llava-1.5-7b-hf',
+    'swift/llava-1.5-7b-hf',
     LoRATM.llama,
     TemplateType.llava1_5,
     eos_token='</s>',
@@ -5147,7 +5147,7 @@ def get_model_tokenizer_llava_1_5(*args, **kwargs):
 
 @register_model(
     ModelType.llava1_6_vicuna_7b_instruct,
-    'huangjintao/llava-v1.6-vicuna-7b-hf',
+    'swift/llava-v1.6-vicuna-7b-hf',
     LoRATM.llama,
     TemplateType.llava_vicuna,
     support_vllm=True,
@@ -5163,7 +5163,7 @@ def get_model_tokenizer_llava_1_5(*args, **kwargs):
     hf_model_id='llava-hf/llava-v1.6-vicuna-7b-hf')
 @register_model(
     ModelType.llava1_6_vicuna_13b_instruct,
-    'huangjintao/llava-v1.6-vicuna-13b-hf',
+    'swift/llava-v1.6-vicuna-13b-hf',
     LoRATM.llama,
     TemplateType.llava_vicuna,
     support_vllm=True,
@@ -5179,7 +5179,7 @@ def get_model_tokenizer_llava_1_5(*args, **kwargs):
     hf_model_id='llava-hf/llava-v1.6-vicuna-13b-hf')
 @register_model(
     ModelType.llava1_6_mistral_7b_instruct,
-    'huangjintao/llava-v1.6-mistral-7b-hf',
+    'swift/llava-v1.6-mistral-7b-hf',
     LoRATM.llama,
     TemplateType.llava_mistral,
     support_vllm=True,
@@ -5202,7 +5202,7 @@ def get_model_tokenizer_llava_next(*args, **kwargs):
 
 @register_model(
     ModelType.llava1_6_yi_34b_instruct,
-    'huangjintao/llava-v1.6-34b-hf',
+    'swift/llava-v1.6-34b-hf',
     LoRATM.llama,
     TemplateType.llava_yi,
     support_vllm=True,
@@ -5226,7 +5226,7 @@ def get_model_tokenizer_llava_next_yi(*args, **kwargs):
 
 @register_model(
     ModelType.llava_next_video_7b_dpo_instruct,
-    'huangjintao/LLaVA-NeXT-Video-7B-DPO-hf',
+    'swift/LLaVA-NeXT-Video-7B-DPO-hf',
     LoRATM.llama,
     TemplateType.llava_next_video,
     support_flash_attn=True,
@@ -5235,7 +5235,7 @@ def get_model_tokenizer_llava_next_yi(*args, **kwargs):
     hf_model_id='llava-hf/LLaVA-NeXT-Video-7B-DPO-hf')
 @register_model(
     ModelType.llava_next_video_7b_32k_instruct,
-    'huangjintao/LLaVA-NeXT-Video-7B-32K-hf',
+    'swift/LLaVA-NeXT-Video-7B-32K-hf',
     LoRATM.llama,
     TemplateType.llava_next_video,
     support_flash_attn=True,
@@ -5244,7 +5244,7 @@ def get_model_tokenizer_llava_next_yi(*args, **kwargs):
     hf_model_id='llava-hf/LLaVA-NeXT-Video-7B-32K-hf')
 @register_model(
     ModelType.llava_next_video_7b_instruct,
-    'huangjintao/LLaVA-NeXT-Video-7B-hf',
+    'swift/LLaVA-NeXT-Video-7B-hf',
     LoRATM.llama,
     TemplateType.llava_next_video,
     support_flash_attn=True,
@@ -5259,7 +5259,7 @@ def get_model_tokenizer_llava_next_video(*args, **kwargs):
 
 @register_model(
     ModelType.llava_next_video_34b_instruct,
-    'huangjintao/LLaVA-NeXT-Video-34B-hf',
+    'swift/LLaVA-NeXT-Video-34B-hf',
     LoRATM.llama,
     TemplateType.llava_next_video_yi,
     support_flash_attn=True,
