Skip to content

Commit 45f4b88

Browse files
authored
fix template bug2 (#44)
1 parent 9cf5e67 commit 45f4b88

File tree

18 files changed

+40
-26
lines changed

18 files changed

+40
-26
lines changed

examples/pytorch/llm/scripts/qwen_7b/qlora_ddp/infer.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ CUDA_VISIBLE_DEVICES=0 \
33
python src/llm_infer.py \
44
--model_type qwen-7b \
55
--sft_type lora \
6-
--template_type chatml \
6+
--template_type default \
77
--dtype bf16 \
88
--ckpt_dir "runs/qwen-7b/vx_xxx/checkpoint-xxx" \
99
--eval_human true \

examples/pytorch/llm/scripts/qwen_7b/qlora_ddp/sft.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ torchrun \
1818
--quantization_bit 4 \
1919
--bnb_4bit_comp_dtype bf16 \
2020
--lora_rank 64 \
21-
--lora_alpha 16 \
21+
--lora_alpha 32 \
2222
--lora_dropout_p 0.05 \
2323
--lora_target_modules ALL \
2424
--gradient_checkpointing true \

examples/pytorch/llm/scripts/qwen_7b_chat/qlora/sft.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ python src/llm_sft.py \
1313
--quantization_bit 4 \
1414
--bnb_4bit_comp_dtype bf16 \
1515
--lora_rank 64 \
16-
--lora_alpha 16 \
16+
--lora_alpha 32 \
1717
--lora_dropout_p 0.05 \
1818
--lora_target_modules ALL \
1919
--gradient_checkpointing true \

examples/pytorch/llm/scripts/qwen_7b_chat/qlora_ddp/sft.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ torchrun \
1818
--quantization_bit 4 \
1919
--bnb_4bit_comp_dtype bf16 \
2020
--lora_rank 64 \
21-
--lora_alpha 16 \
21+
--lora_alpha 32 \
2222
--lora_dropout_p 0.05 \
2323
--lora_target_modules ALL \
2424
--gradient_checkpointing true \

examples/pytorch/llm/scripts/qwen_agent/qlora_ddp/sft.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ torchrun \
1818
--quantization_bit 4 \
1919
--bnb_4bit_comp_dtype bf16 \
2020
--lora_rank 64 \
21-
--lora_alpha 16 \
21+
--lora_alpha 32 \
2222
--lora_dropout_p 0.05 \
2323
--lora_target_modules ALL \
2424
--gradient_checkpointing true \

examples/pytorch/llm/scripts/qwen_vl/qlora_ddp/infer.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ CUDA_VISIBLE_DEVICES=0 \
33
python src/llm_infer.py \
44
--model_type qwen-vl \
55
--sft_type lora \
6-
--template_type chatml \
6+
--template_type default \
77
--dtype bf16 \
88
--ckpt_dir "runs/qwen-vl/vx_xxx/checkpoint-xxx" \
99
--eval_human false \

examples/pytorch/llm/scripts/qwen_vl/qlora_ddp/sft.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ torchrun \
1818
--quantization_bit 4 \
1919
--bnb_4bit_comp_dtype bf16 \
2020
--lora_rank 64 \
21-
--lora_alpha 16 \
21+
--lora_alpha 32 \
2222
--lora_dropout_p 0.05 \
2323
--lora_target_modules ALL \
2424
--gradient_checkpointing true \

examples/pytorch/llm/scripts/qwen_vl_chat/qlora/sft.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ python src/llm_sft.py \
1313
--quantization_bit 4 \
1414
--bnb_4bit_comp_dtype bf16 \
1515
--lora_rank 64 \
16-
--lora_alpha 16 \
16+
--lora_alpha 32 \
1717
--lora_dropout_p 0.05 \
1818
--lora_target_modules ALL \
1919
--gradient_checkpointing true \

examples/pytorch/llm/scripts/qwen_vl_chat/qlora_ddp/sft.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ torchrun \
1818
--quantization_bit 4 \
1919
--bnb_4bit_comp_dtype bf16 \
2020
--lora_rank 64 \
21-
--lora_alpha 16 \
21+
--lora_alpha 32 \
2222
--lora_dropout_p 0.05 \
2323
--lora_target_modules ALL \
2424
--gradient_checkpointing true \

examples/pytorch/llm/src/llm_infer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
12
import os
23
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
34
from dataclasses import dataclass, field
@@ -102,9 +103,12 @@ def llm_infer(args: InferArguments) -> None:
102103
print_model_info(model)
103104

104105
# ### Inference
105-
template_type = MODEL_MAPPING[args.model_type]['template']
106106
preprocess_func = get_preprocess(
107-
template_type, tokenizer, args.system, args.max_length, batched=False)
107+
args.template_type,
108+
tokenizer,
109+
args.system,
110+
args.max_length,
111+
batched=False)
108112
streamer = TextStreamer(
109113
tokenizer, skip_prompt=True, skip_special_tokens=True)
110114
generation_config = GenerationConfig(

0 commit comments

Comments (0)