
Commit 1a8819b

fix bug: internlm-20b (#75)
1 parent 8e51ccc commit 1a8819b

File tree

6 files changed: +27 −31 lines changed


examples/pytorch/llm/README.md

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 4. chatglm2 series: chatglm2-6b, chatglm2-6b-32k
 5. llama series: llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat
 6. openbuddy-llama series: openbuddy-llama2-13b, openbuddy-llama-65b, openbuddy-llama2-70b
-7. internlm series: internlm-7b, internlm-7b-chat, internlm-7b-chat-8k
+7. internlm series: internlm-7b, internlm-7b-chat, internlm-7b-chat-8k, internlm-20b, internlm-20b-chat
 8. other: polylm-13b, seqgpt-560m
 3. supported features: quantization, DDP, model parallelism (device map), gradient checkpointing, gradient accumulation, pushing to ModelScope Hub, custom datasets, multimodal and agent SFT, multi-round chat, ...
 4. supported datasets:

examples/pytorch/llm/README_CN.md

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 4. chatglm2 series: chatglm2-6b, chatglm2-6b-32k
 5. llama series: llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat
 6. openbuddy-llama series: openbuddy-llama2-13b, openbuddy-llama-65b, openbuddy-llama2-70b
-7. internlm series: internlm-7b, internlm-7b-chat, internlm-7b-chat-8k
+7. internlm series: internlm-7b, internlm-7b-chat, internlm-7b-chat-8k, internlm-20b, internlm-20b-chat
 8. other: polylm-13b, seqgpt-560m
 3. supported features: model quantization, DDP, model parallelism (device_map), gradient checkpointing, gradient accumulation, pushing to ModelScope Hub, custom datasets, multimodal and agent SFT, multi-round chat, ...
 4. supported datasets:

examples/pytorch/llm/scripts/internlm_20b/qlora_ddp/infer.sh

Lines changed: 0 additions & 17 deletions
This file was deleted.
New file: Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+CUDA_VISIBLE_DEVICES=0 \
+python src/llm_infer.py \
+    --model_type internlm-20b-chat \
+    --sft_type lora \
+    --template_type internlm \
+    --dtype bf16 \
+    --ckpt_dir "output/internlm-20b-chat/vx_xxx/checkpoint-xxx" \
+    --eval_human false \
+    --dataset damo-agent-mini-zh \
+    --max_length 4096 \
+    --max_new_tokens 2048 \
+    --temperature 0.9 \
+    --top_k 20 \
+    --top_p 0.9 \
+    --do_sample true \
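
The new infer.sh turns on sampling (`--do_sample true`) and sets `--temperature 0.9`, `--top_k 20`, `--top_p 0.9`. For readers who want to see how those three knobs interact during decoding, here is a minimal, self-contained sketch of temperature + top-k + top-p (nucleus) sampling; it is an illustration only, not the decoding code used by src/llm_infer.py:

```python
import numpy as np

def sample_next_token(logits, temperature=0.9, top_k=20, top_p=0.9, rng=None):
    """Temperature + top-k + top-p sampling over one decoding step's logits.

    Mirrors the knobs set in infer.sh; illustration only, not swift's
    actual implementation.
    """
    rng = rng or np.random.default_rng()
    logits = np.asarray(logits, dtype=np.float64) / temperature  # temperature scaling
    if top_k > 0:  # keep only the top_k highest-scoring tokens
        kth = np.sort(logits)[-min(top_k, logits.size)]
        logits = np.where(logits < kth, -np.inf, logits)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]                 # tokens by descending probability
    cum = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cum, top_p)) + 1   # smallest nucleus with mass >= top_p
    keep = order[:cutoff]
    p = probs[keep] / probs[keep].sum()             # renormalize within the nucleus
    return int(rng.choice(keep, p=p))
```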
Lines changed: 8 additions & 10 deletions
@@ -1,28 +1,26 @@
-# Experimental environment: 2 * A10
-# 2 * 20GB GPU memory
+# Experimental environment: 2 * A100
+# 2 * 60GB GPU memory
 nproc_per_node=2
 CUDA_VISIBLE_DEVICES=0,1 \
 torchrun \
     --nproc_per_node=$nproc_per_node \
     --master_port 29500 \
     src/llm_sft.py \
-    --model_type internlm-20b \
+    --model_type internlm-20b-chat \
     --sft_type lora \
-    --template_type default-generation \
+    --template_type internlm \
     --dtype bf16 \
     --output_dir output \
     --ddp_backend nccl \
-    --dataset cmnli-zh \
+    --dataset damo-agent-mini-zh \
     --train_dataset_sample 20000 \
     --num_train_epochs 1 \
-    --max_length 2048 \
+    --max_length 4096 \
     --lora_rank 8 \
     --lora_alpha 32 \
     --lora_dropout_p 0. \
     --lora_target_modules ALL \
-    --quantization_bit 4 \
-    --bnb_4bit_comp_dtype bf16 \
-    --gradient_checkpointing false \
+    --gradient_checkpointing true \
     --batch_size 1 \
     --weight_decay 0. \
     --learning_rate 1e-4 \
@@ -34,6 +32,6 @@ torchrun \
     --save_total_limit 2 \
     --logging_steps 10 \
     --push_to_hub false \
-    --hub_model_id internlm-20b-lora \
+    --hub_model_id internlm-20b-chat-lora \
     --hub_private_repo true \
     --hub_token 'your-sdk-token' \
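
Compared with the old qlora_ddp script, this sft.sh drops the 4-bit settings (`--quantization_bit 4`, `--bnb_4bit_comp_dtype bf16`) and trains a plain bf16 LoRA on internlm-20b-chat, which is why the stated memory budget grows from 2 * 20GB (A10) to 2 * 60GB (A100) and gradient checkpointing is turned on. For orientation, the LoRA flags roughly correspond to the following HuggingFace peft config; this is only a familiar reference point (swift ships its own tuner implementation), and the `target_modules` list is a guess at what `--lora_target_modules ALL` expands to:

```python
from peft import LoraConfig

# Rough peft equivalent of the script's LoRA flags. Illustration only:
# swift has its own tuner, and the target_modules list is a guess at
# what `--lora_target_modules ALL` resolves to for InternLM.
lora_config = LoraConfig(
    r=8,                  # --lora_rank 8
    lora_alpha=32,        # --lora_alpha 32
    lora_dropout=0.0,     # --lora_dropout_p 0.
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],  # guess
    task_type='CAUSAL_LM',
)
```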

examples/pytorch/llm/src/llm_sft.py

Lines changed: 2 additions & 2 deletions
@@ -90,8 +90,8 @@ def llm_sft(args: SftArguments) -> None:
         args.dataset.split(','), args.dataset_test_ratio,
         args.dataset_split_seed)
     if args.train_dataset_sample >= 0:
-        val_dataset_sample = int(args.train_dataset_sample
-                                 * args.dataset_test_ratio)
+        val_dataset_sample = max(
+            int(args.train_dataset_sample * args.dataset_test_ratio), 1)
         train_idxs = np.random.permutation(args.train_dataset_sample)
         train_dataset = train_dataset.select(train_idxs)
     if val_dataset.shape[0] > val_dataset_sample:
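
This is the bug the commit title refers to: with a small `--train_dataset_sample`, the old formula could round the validation sample count down to zero, leaving an empty validation set. Wrapping it in `max(..., 1)` guarantees at least one validation sample. A minimal repro of the before/after arithmetic, with hypothetical values:

```python
# Hypothetical values: a small sample size and a small test ratio.
train_dataset_sample = 50
dataset_test_ratio = 0.01

old = int(train_dataset_sample * dataset_test_ratio)           # int(0.5) -> 0: empty val set
new = max(int(train_dataset_sample * dataset_test_ratio), 1)   # -> 1: at least one sample

print(old, new)  # 0 1
```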
