Skip to content

Commit 3fb4d69

Browse files
committed
Support Yi-6b sft (#134)
(cherry picked from commit 0b3f840)
1 parent 0bfc662 commit 3fb4d69

File tree

4 files changed

+59
-0
lines changed

4 files changed

+59
-0
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Experimental environment: A10
2+
PYTHONPATH=../../.. \
3+
CUDA_VISIBLE_DEVICES=0 \
4+
python llm_infer.py \
5+
--ckpt_dir "output/yi-6b/vx_xxx/checkpoint-xxx" \
6+
--load_args_from_ckpt_dir true \
7+
--eval_human false \
8+
--max_length 256 \
9+
--max_new_tokens 256 \
10+
--temperature 0.9 \
11+
--top_k 20 \
12+
--top_p 0.9 \
13+
--repetition_penalty 1.05 \
14+
--do_sample true \
15+
--merge_lora_and_save false \
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Experimental environment: A10
2+
# 15GB GPU memory
3+
PYTHONPATH=../../.. \
4+
CUDA_VISIBLE_DEVICES=0 \
5+
python llm_sft.py \
6+
--model_id_or_path 01ai/Yi-6B \
7+
--model_revision master \
8+
--sft_type lora \
9+
--tuner_backend swift \
10+
--template_type default-generation \
11+
--dtype bf16 \
12+
--output_dir output \
13+
--dataset dureader-robust-zh \
14+
--train_dataset_sample -1 \
15+
--num_train_epochs 1 \
16+
--max_length 2048 \
17+
--check_dataset_strategy warning \
18+
--lora_rank 8 \
19+
--lora_alpha 32 \
20+
--lora_dropout_p 0.05 \
21+
--lora_target_modules ALL \
22+
--gradient_checkpointing true \
23+
--batch_size 1 \
24+
--weight_decay 0.01 \
25+
--learning_rate 1e-4 \
26+
--gradient_accumulation_steps 16 \
27+
--max_grad_norm 0.5 \
28+
--warmup_ratio 0.03 \
29+
--eval_steps 100 \
30+
--save_steps 100 \
31+
--save_total_limit 2 \
32+
--logging_steps 10 \
33+
--push_to_hub false \
34+
--hub_model_id yi-6b-qlora \
35+
--hub_private_repo true \
36+
--hub_token 'your-sdk-token' \

swift/llm/utils/model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ class ModelType:
9292
# other
9393
polylm_13b = 'polylm-13b'
9494
seqgpt_560m = 'seqgpt-560m'
95+
yi_6b = 'yi-6b'
96+
yi_34b = 'yi-34b'
9597

9698

9799
class LoRATM(NamedTuple):
@@ -106,6 +108,7 @@ class LoRATM(NamedTuple):
106108
xverse = ['q_proj', 'k_proj', 'v_proj']
107109
mistral = ['q_proj', 'k_proj', 'v_proj']
108110
ziya = ['q_proj', 'k_proj', 'v_proj']
111+
yi = ['q_proj', 'k_proj', 'v_proj']
109112

110113

111114
GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel],
@@ -169,6 +172,10 @@ def _register_model(
169172
return _register_model
170173

171174

175+
@register_model(ModelType.yi_34b, '01ai/Yi-34B', LoRATM.yi,
176+
TemplateType.default_generation)
177+
@register_model(ModelType.yi_6b, '01ai/Yi-6B', LoRATM.yi,
178+
TemplateType.default_generation)
172179
@register_model(ModelType.seqgpt_560m, 'damo/nlp_seqgpt-560m', LoRATM.bloom,
173180
TemplateType.default_generation)
174181
@register_model(ModelType.ziya2_13b_chat, 'Fengshenbang/Ziya2-13B-Chat',

swift/llm/utils/template.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class TemplateType:
2222
xverse = 'xverse'
2323
ziya = 'ziya'
2424
skywork = 'skywork'
25+
yi = 'yi'
2526

2627

2728
Prompt = List[Union[str, List[Union[str, int]]]]

0 commit comments

Comments (0)