Commit ff25ae0

set training_args._frozen=False (#70)
1 parent 45256a0 commit ff25ae0

File tree

4 files changed: +7 additions, −5 deletions

examples/pytorch/llm/README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -34,7 +34,7 @@
 5. supported templates: chatml(qwen), baichuan, chatglm2, llama, openbuddy-llama, default, default-generation
 
 ## Prepare the Environment
-Experimental environment: A10, 3090, A100, ... (V100 does not support bf16, quantization)
+Experimental environment: V100, A10, 3090, A100, ... (V100 does not support bf16, quantization)
 ```bash
 # Installing miniconda
 wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
````

examples/pytorch/llm/README_CN.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -35,7 +35,7 @@
 5. 支持的对话模板: chatml(qwen), baichuan, chatglm2, llama, openbuddy-llama, default, default-generation
 
 ## 准备实验环境
-实验环境: A10, 3090, A100均可. (V100不支持bf16, 量化)
+实验环境: V100, A10, 3090, A100均可. (V100不支持bf16, 量化)
 ```bash
 # 安装miniconda
 wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
````

(The changed line translates to: "Experimental environment: V100, A10, 3090, A100 all work. (V100 does not support bf16, quantization)". The heading and comment translate to "Prepare the Experimental Environment" and "Installing miniconda".)
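Both README changes add V100 to the supported environments while noting that it does not support bf16. As an illustration, a hypothetical helper (not part of this repo) could select a mixed-precision dtype at runtime; `torch.cuda.is_bf16_supported()` is a real PyTorch API, and the sketch falls back to fp16 when torch or a bf16-capable GPU is unavailable:

```python
def pick_mixed_precision_dtype() -> str:
    """Return 'bf16' on GPUs that support it (e.g. A100), else 'fp16'.

    Hypothetical helper, for illustration only: a V100 lands in the
    'fp16' branch because its hardware predates bf16 support.
    """
    try:
        import torch  # optional dependency; fall back to fp16 if missing
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            return 'bf16'
    except ImportError:
        pass
    return 'fp16'

print(pick_mixed_precision_dtype())
```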

examples/pytorch/llm/scripts/qwen_7b_chat/lora_mp_ddp/sft.sh

Lines changed: 2 additions & 2 deletions

```diff
@@ -19,8 +19,8 @@ torchrun \
     --lora_rank 8 \
     --lora_alpha 32 \
     --lora_dropout_p 0. \
-    --lora_target_modules c_attn c_proj \
-    --gradient_checkpointing true \
+    --lora_target_modules c_attn \
+    --gradient_checkpointing false \
     --batch_size 1 \
     --weight_decay 0. \
     --learning_rate 1e-4 \
```
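The script passes `--lora_target_modules` as space-separated values (previously `c_attn c_proj`, now just `c_attn`). The repo's own `SftArguments` dataclass handles the real parsing; purely as a sketch, the same flag shape can be modeled with stdlib `argparse` (hypothetical parser, for illustration only):

```python
import argparse

# Hypothetical stand-in for the repo's argument parsing: a
# space-separated flag value becomes a list of module names.
parser = argparse.ArgumentParser()
parser.add_argument('--lora_target_modules', nargs='+', default=['c_attn'])
parser.add_argument('--gradient_checkpointing',
                    type=lambda s: s.lower() == 'true', default=True)

# Flags as in this commit's version of sft.sh:
ns = parser.parse_args(
    ['--lora_target_modules', 'c_attn', '--gradient_checkpointing', 'false'])
print(ns.lora_target_modules, ns.gradient_checkpointing)
```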

examples/pytorch/llm/src/llm_sft.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -115,7 +115,7 @@ class SftArguments:
         default=None,
         metadata={
             'help':
-            "This parameter is used only when model_type.startswith('qwen-7b')"
+            "This parameter is used only when model_type.startswith('qwen')"
         })
 
     def __post_init__(self):
@@ -316,6 +316,8 @@ def llm_sft(args: SftArguments) -> None:
     model.config.use_cache = False
     model.enable_input_require_grads()
     if is_dist():
+        # Compatible with https://github.com/huggingface/transformers/pull/25903
+        training_args._frozen = False
         if args.gradient_checkpointing:
             training_args.ddp_find_unused_parameters = False
             training_args.ddp_broadcast_buffers = False
```
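The linked transformers PR (#25903) makes `TrainingArguments` reject attribute assignment after initialization unless `_frozen` is set back to `False`, which is why this commit flips `training_args._frozen = False` before mutating the DDP fields. A minimal pure-Python sketch of that freeze mechanism (a toy `FrozenArgs` class, not the real `TrainingArguments`):

```python
import dataclasses


@dataclasses.dataclass
class FrozenArgs:
    """Toy stand-in for transformers' post-init freezing of TrainingArguments."""
    ddp_find_unused_parameters: bool = True
    ddp_broadcast_buffers: bool = True

    def __post_init__(self):
        self._frozen = True  # from here on, assignment raises

    def __setattr__(self, name, value):
        # Allow writes during __init__/__post_init__ and to _frozen itself.
        if getattr(self, '_frozen', False) and name != '_frozen':
            raise dataclasses.FrozenInstanceError(
                f'cannot assign to field {name!r}')
        super().__setattr__(name, value)


args = FrozenArgs()
args._frozen = False  # what this commit does before mutating the args
args.ddp_find_unused_parameters = False
args.ddp_broadcast_buffers = False
```

Without the `args._frozen = False` line, the two assignments below it would raise, matching the failure this commit fixes for distributed training.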
