Commit 8100f2a

fit llama and to_static

1 parent 77680b6

2 files changed: +57 -1 lines

llm/llama/run.sh

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+# export FLAGS_prim_all=true
+# export FLAGS_enable_pir_api=True
+export FLAGS_cudnn_deterministic=1
+export CUDA_VISIBLE_DEVICES=6
+
+if [ ! -d ./data ]
+then
+    mkdir ./data
+    cd ./data
+    wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_ids.npy
+    wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k_idx.npz
+    cd ..
+fi
+
+if [ -d ./output ]
+then
+    rm -rf ./output
+fi
+
+task_name_or_path="llama_output"
+python run_pretrain.py \
+    --model_type "llama" \
+    --model_name_or_path "__internal_testing__/tiny-random-llama" \
+    --tokenizer_name_or_path "__internal_testing__/tiny-random-llama" \
+    --input_dir "./data" \
+    --output_dir "./output/$task_name_or_path" \
+    --split 949,50,1 \
+    --max_seq_length 2048 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 1 \
+    --use_flash_attention 0 \
+    --use_fused_rms_norm 0 \
+    --scale_loss 1024 \
+    --learning_rate 0.00001 \
+    --min_learning_rate 0.000005 \
+    --lr_scheduler_type "cosine" \
+    --max_steps 500 \
+    --save_steps 500 \
+    --weight_decay 0.01 \
+    --warmup_ratio 0.01 \
+    --max_grad_norm 1.0 \
+    --logging_steps 1 \
+    --dataloader_num_workers 1 \
+    --eval_steps 500 \
+    --report_to "visualdl" \
+    --disable_tqdm true \
+    --continue_training 0 \
+    --recompute 0 \
+    --do_train \
+    --do_eval \
+    --device "gpu" \
+    --seed 2023 \
+    --use_fused_rms_norm False
+    # --fp16 \
+    # --fp16_opt_level "O2"
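
Note: a minimal sketch (not part of this commit) of toggling the same switches from Python instead of environment variables, assuming paddle.set_flags / paddle.get_flags accept these flag names; some flags (e.g. FLAGS_enable_pir_api) may still need to be exported before paddle is imported.

import paddle

# Hedged sketch, not from the repository: set the determinism flag from Python.
# Equivalent in intent to `export FLAGS_cudnn_deterministic=1` in run.sh.
paddle.set_flags({"FLAGS_cudnn_deterministic": True})
print(paddle.get_flags(["FLAGS_cudnn_deterministic"]))
# FLAGS_prim_all / FLAGS_enable_pir_api (commented out above) are typically
# exported as environment variables before `import paddle`.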

paddlenlp/transformers/llama/modeling.py

Lines changed: 2 additions & 1 deletion
@@ -283,7 +283,8 @@ def _expand_2d_mask(mask, dtype, tgt_length):
     batch_size, src_length = mask.shape[0], mask.shape[-1]
     tgt_length = tgt_length if tgt_length is not None else src_length
 
-    mask = mask[:, None, None, :].astype("bool")
+    # mask = mask[:, None, None, :].astype("bool")
+    mask = paddle.cast(mask[:, None, None, :], "bool")
     mask.stop_gradient = True
     expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length])
 
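The edit replaces the Tensor.astype("bool") call with the functional paddle.cast, presumably because the functional form converts more reliably under dynamic-to-static transformation (the "to_static" in the commit message). A minimal, hedged sketch of the same pattern in isolation follows; the TinyMaskExpand layer is illustrative, not from the repository:

import paddle

# Illustrative only: a tiny layer reproducing the edited lines, wrapped with
# paddle.jit.to_static to mirror the commit's intent. paddle.cast is used in
# place of Tensor.astype, matching the diff above.
class TinyMaskExpand(paddle.nn.Layer):
    def forward(self, mask, tgt_length):
        batch_size, src_length = mask.shape[0], mask.shape[-1]
        mask = paddle.cast(mask[:, None, None, :], "bool")
        mask.stop_gradient = True
        return mask.expand([batch_size, 1, tgt_length, src_length])

static_layer = paddle.jit.to_static(TinyMaskExpand())
out = static_layer(paddle.ones([2, 16], dtype="int64"), 16)
print(out.shape)  # expected: [2, 1, 16, 16]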