
Commit 0e58759

add llama bench (#6119)
1 parent e274c44 commit 0e58759

5 files changed: +245, -25 lines


examples/language_model/llama/data.py

Lines changed: 55 additions & 17 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.

 import copy
+import json
 from dataclasses import dataclass
 from typing import Dict, List

@@ -22,6 +23,26 @@

 IGNORE_INDEX = -100

+PROMPT_DICT = {
+    "prompt_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+    ),
+    "prompt_no_input": (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Response:"
+    ),
+}
+
+
+def reader(data_path):
+    with open(data_path, "r", encoding="utf-8") as f:
+        for line in f:
+            json_line = json.loads(line)
+            yield json_line
+

 def convert_example(example, tokenizer, data_args, is_test=False):
     """
@@ -81,40 +102,57 @@ def convert_example(example, tokenizer, data_args, is_test=False):
     )


-def custom_instruction_convert_example(example, tokenizer, data_args, is_test=False):
+def custom_instruction_convert_example(
+    example, tokenizer, data_args, is_test=False, benchmark=False, model_max_length=512
+):
     """
     Convert an example into necessary features.
     """

-    instruction = ""
-    input = ""
-    output = ""
-    if "instruction" in example and "output" in example:
-        instruction = example["instruction"]
-        output = example["output"]
-    else:
-        assert False, "instruction and output are not in the input dictionary."
-    if "input" in example["input"]:
-        input = example["input"]
-
-    input_seq = instruction + input
-    output_seq = output
+    if benchmark:
+        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+
+        if example.get("input", "") != "":
+            input_seq = prompt_input.format_map(example)
+        else:
+            input_seq = prompt_no_input.format_map(example)
+
+        output_seq = example["output"]
+    else:
+        instruction = ""
+        input = ""
+        output = ""
+        if "instruction" in example and "output" in example:
+            instruction = example["instruction"]
+            output = example["output"]
+        else:
+            assert False, "instruction and output are not in the input dictionary."
+        if "input" in example["input"]:
+            input = example["input"]
+
+        input_seq = instruction + input
+        output_seq = output
+
+    # To be compatible with the compiled (to_static) training mode used by the benchmark, inputs are padded to a fixed length.
     source_tokenized = tokenizer(
         input_seq,
         return_tensors="pd",
-        max_length=data_args.src_length,
+        padding="longest" if not benchmark else "max_length",
+        max_length=data_args.src_length if not benchmark else model_max_length,
         truncation=True,
     )

     source_input_ids_len = (
         source_tokenized["input_ids"].not_equal(paddle.to_tensor(tokenizer.pad_token_id)).sum().item()
     )

+    total_length = data_args.src_length + data_args.tgt_length
+
     example_tokenized = tokenizer(
         input_seq + output_seq,
         return_tensors="pd",
-        max_length=data_args.src_length + data_args.tgt_length,
+        padding="longest" if not benchmark else "max_length",
+        max_length=total_length if not benchmark else model_max_length,
         truncation=True,
     )

@@ -134,7 +172,7 @@ def custom_instruction_convert_example(example, tokenizer, data_args, is_test=Fa
     )


-def left_padding(inputs, pad_id, max_length=0):
+def left_padding(inputs, pad_id, max_length=-1):
     for ids in inputs:
         max_length = max(max_length, len(ids))

@@ -156,7 +194,7 @@ class DataCollatorForSupervisedDataset(object):
     """Collate examples for supervised fine-tuning."""

     tokenizer: PretrainedTokenizerBase
-    max_length: 0
+    max_length: -1

     def __call__(self, features: List[Dict]) -> Dict[str, paddle.Tensor]:
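As a quick sanity check of the new benchmark preprocessing path, the prompt formatting can be exercised on its own. This is a minimal sketch, not part of the commit; it assumes it is run from examples/language_model/llama/ and that ./data/train.txt (the demo SFT data the benchmark downloads) holds one JSON object per line with instruction/input/output fields:

    from data import PROMPT_DICT, reader

    # Inspect the first sample and build the same prompt string the benchmark path tokenizes.
    example = next(reader("./data/train.txt"))
    template = PROMPT_DICT["prompt_input"] if example.get("input", "") else PROMPT_DICT["prompt_no_input"]
    input_seq = template.format_map(example)
    print(input_seq + example["output"])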

examples/language_model/llama/finetune_instruction_generation.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818

1919
import numpy as np
2020
import paddle
21-
from data import DataCollatorForSupervisedDataset, custom_instruction_convert_example
21+
from data import (
22+
DataCollatorForSupervisedDataset,
23+
custom_instruction_convert_example,
24+
reader,
25+
)
2226
from sklearn.metrics import accuracy_score
2327
from utils import LlamaTrainer, compute_metrics, save_infer_result
2428

@@ -57,6 +61,14 @@ class ModelArgument:
5761
prefix_projection: bool = field(default=True, metadata={"help": "Whether to project the prefix tokens"})
5862
use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
5963
do_generation: bool = field(default=False, metadata={"help": "Whether to do generation for evaluation"})
64+
benchmark: bool = field(
65+
default=False,
66+
metadata={"help": "Whether or not run benchmark."},
67+
)
68+
profiler_options: str = field(
69+
default=None,
70+
metadata={"help": "profiler_options."},
71+
)
6072

6173

6274
def main():
@@ -65,6 +77,8 @@ def main():
6577

6678
training_args.print_config(model_args, "Model")
6779
training_args.print_config(data_args, "Data")
80+
training_args.benchmark = model_args.benchmark
81+
training_args.profiler_options = model_args.profiler_options
6882
setattr(training_args, "label_smoothing", model_args.label_smoothing)
6983
setattr(training_args, "lr_decay_ratio", model_args.lr_decay_ratio)
7084

@@ -115,6 +129,7 @@ def main():
115129
model_args.model_name_or_path,
116130
padding_side="left", # Allow batch inference
117131
)
132+
tokenizer.pad_token = tokenizer.unk_token
118133

119134
if model_args.lora:
120135
# TODO: hardcode parameters for now. Change after MergedLoRA is introduced
@@ -149,12 +164,20 @@ def main():
149164
model.print_trainable_parameters()
150165

151166
# Load the dataset.
152-
train_ds, dev_ds = load_dataset(data_args.data_name, data_args.task_name, splits=["train", "dev"])
153-
154-
trans_func = partial(custom_instruction_convert_example, tokenizer=tokenizer, data_args=data_args)
167+
if training_args.benchmark:
168+
train_ds = load_dataset(reader, data_path="./data/train.txt", lazy=False)
169+
dev_ds = None
170+
else:
171+
train_ds, dev_ds = load_dataset(data_args.data_name, data_args.task_name, splits=["train", "dev"])
172+
173+
trans_func = partial(
174+
custom_instruction_convert_example, tokenizer=tokenizer, data_args=data_args, benchmark=training_args.benchmark
175+
)
155176
train_ds = train_ds.map(partial(trans_func))
156-
dev_ds = dev_ds.map(partial(trans_func))
157-
collate_fn = DataCollatorForSupervisedDataset(tokenizer)
177+
178+
if not training_args.benchmark:
179+
dev_ds = dev_ds.map(partial(trans_func))
180+
collate_fn = DataCollatorForSupervisedDataset(tokenizer, max_length=-1)
158181

159182
def compute_metrics_trainer(eval_preds, tokenizer):
160183
all_preds = []
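With these changes, benchmark mode is switched on purely through command-line flags. A rough example invocation, based on the norm_train line of the TIPC config added later in this commit (the model path and hyperparameters come from that config and may need adjusting locally):

    python finetune_instruction_generation.py \
        --model_name_or_path facebook/llama-7b-2l \
        --do_train --benchmark --overwrite_output_dir \
        --output_dir ./checkpoints/ --max_steps 500 --logging_steps 1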

examples/language_model/llama/utils.py

Lines changed: 95 additions & 1 deletion

@@ -14,6 +14,7 @@

 import json
 import os
+import time
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -24,7 +25,9 @@
 from rouge import Rouge

 from paddlenlp.metrics import BLEU
-from paddlenlp.trainer import Trainer
+from paddlenlp.trainer import PrinterCallback, ProgressCallback, Trainer
+from paddlenlp.trainer.integrations import TrainerCallback
+from paddlenlp.utils.log import logger


 def save_infer_result(trainer, dev_ds, k=100, src_length=256, tgt_length=512):
@@ -61,9 +64,100 @@ def save_infer_result(trainer, dev_ds, k=100, src_length=256, tgt_length=512):
             f.write(json.dumps(out, ensure_ascii=False) + "\n")


+class AverageStatistical(object):
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.total_cnt = 0
+        self.time = 0
+
+    def record(self, val, cnt=1):
+        self.time += val
+        self.total_cnt += cnt
+
+    def get_average(self):
+        if self.total_cnt == 0:
+            return 0
+
+        return self.time / self.total_cnt
+
+    def get_average_per_sec(self):
+        if self.time == 0.0:
+            return 0.0
+
+        return float(self.total_cnt) / self.time
+
+    def get_total_cnt(self):
+        return self.total_cnt
+
+    def get_total_time(self):
+        return self.time
+
+
+class BenchmarkCallback(TrainerCallback):
+    def __init__(self, benchmark=True, profiler_options=None):
+        self.benchmark = benchmark
+        self.profiler_options = profiler_options
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        assert args.gradient_accumulation_steps == 1 and not args.do_eval and not args.do_predict
+        if self.benchmark:
+            self.reader_cost_avg = AverageStatistical()
+
+    def on_epoch_begin(self, args, state, control, **kwargs):
+        if self.benchmark:
+            self.epoch_start = time.time()
+            self.batch_start = time.time()
+
+    def on_step_begin(self, args, state, control, **kwargs):
+        if self.benchmark:
+            self.reader_cost_avg.record(time.time() - self.batch_start)
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if self.benchmark:
+            self.batch_start = time.time()
+            if control.should_log:
+                self.maybe_log_save_evaluate_start = time.time()
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if self.benchmark:
+            if logs is not None and "interval_steps_per_second" in logs:
+                self.batch_start = self.batch_start + (time.time() - self.maybe_log_save_evaluate_start)
+                ips = logs["interval_steps_per_second"] * args.train_batch_size
+                avg_batch_cost = 1 / logs["interval_steps_per_second"]
+                logger.info(
+                    "global step %d / %d, loss: %f, avg_reader_cost: %.5f sec, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sample/sec"
+                    % (
+                        state.global_step,
+                        state.max_steps,
+                        logs["loss"],
+                        self.reader_cost_avg.get_average(),
+                        avg_batch_cost,
+                        args.train_batch_size,
+                        ips,
+                    )
+                )
+                self.reader_cost_avg.reset()
+
+    def on_epoch_end(self, args, state, control, **kwargs):
+        if self.benchmark:
+            train_epoch_cost = time.time() - self.epoch_start
+            logger.info("train epoch: %d, epoch_cost: %.5f s" % (state.epoch, train_epoch_cost))
+
+
 class LlamaTrainer(Trainer):
     def __init__(self, do_generation: bool, **kwargs):
         super().__init__(**kwargs)
+        if self.args.benchmark or self.args.profiler_options is not None:
+            self.add_callback(
+                BenchmarkCallback(benchmark=self.args.benchmark, profiler_options=self.args.profiler_options)
+            )
+            if self.args.benchmark:
+                if self.args.disable_tqdm:
+                    self.pop_callback(PrinterCallback)
+                else:
+                    self.pop_callback(ProgressCallback)
         self.do_generation = do_generation

     def prediction_step(
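BenchmarkCallback derives its throughput numbers from the trainer logs: ips = interval_steps_per_second * train_batch_size (samples per second) and avg_batch_cost = 1 / interval_steps_per_second. The reader cost comes from the AverageStatistical helper; a minimal usage sketch (assuming utils.py from this example directory is importable, with time.sleep standing in for dataloader work):

    import time
    from utils import AverageStatistical

    reader_cost_avg = AverageStatistical()
    for _ in range(3):
        batch_start = time.time()
        time.sleep(0.01)  # stand-in for fetching one batch
        reader_cost_avg.record(time.time() - batch_start)

    print(reader_cost_avg.get_average())          # average seconds per recorded step
    print(reader_cost_avg.get_average_per_sec())  # recorded steps per second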
Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+===========================train_params===========================
+model_name:llama
+python:python3.7
+gpu_list:0|0,1
+--device:gpu|gpu
+--fp16:null
+--max_steps:null
+null:null
+--per_device_train_batch_size:null
+null:null
+null:null
+null:null
+null:null
+##
+trainer:norm_train
+norm_train:../examples/language_model/llama/finetune_instruction_generation.py --model_name_or_path facebook/llama-7b-2l --do_train --max_steps 500 --recompute False --benchmark --overwrite_output_dir --output_dir ./checkpoints/ --fp16_opt_level O2 --learning_rate 3e-5 --lr_scheduler_type constant --warmup_steps 0 --seed 23 --logging_steps 1 --max_grad_norm -1
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+null:null
+null:null
+norm_export:null
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:null
+infer_export:null
+infer_quant:null
+inference:null
+null:null
+null:null
+null:null
+null:null
+null:null
+null:null
+null:null
+null:null
+null:null
+null:null
+null:null
+===========================to_static_train_benchmark_params===========================
+to_static_train:--to_static
+===========================train_benchmark_params==========================
+batch_size:8
+fp_items:fp32|fp16
+epoch:500
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
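Roughly, the TIPC harness builds the benchmark command from this config: norm_train is the base command, batch_size and fp_items expand into per-device batch size and fp16 variants, and to_static_train appends --to_static for the compiled-training run. An illustrative composed command (the composition is done by the harness scripts, so the exact flag spelling below is an assumption):

    # Illustrative only: fp16 + to_static variant composed from the config above.
    python ../examples/language_model/llama/finetune_instruction_generation.py \
        --model_name_or_path facebook/llama-7b-2l --do_train --max_steps 500 \
        --benchmark --overwrite_output_dir --output_dir ./checkpoints/ \
        --per_device_train_batch_size 8 --fp16 --to_static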

tests/test_tipc/prepare.sh

Lines changed: 7 additions & 1 deletion

@@ -400,12 +400,18 @@ elif [ ${MODE} = "benchmark_train" ];then
        tar -zxvf laion400m_demo_data.tar.gz
    fi

+    if [[ ${model_name} =~ "llama" ]]; then
+        rm -rf llama_sft_demo_data.tar.gz
+        wget https://paddlenlp.bj.bcebos.com/models/community/facebook/llama_sft_demo_data.tar.gz
+        tar -xvf llama_sft_demo_data.tar.gz
+    fi
+
    export PYTHONPATH=$(dirname "$PWD"):$PYTHONPATH
    python -m pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
    python -m pip install setuptools_scm
    python -m pip install Cython
    python -m pip install -r ../requirements.txt #-i https://pypi.tuna.tsinghua.edu.cn/simple
-    python -m pip install pybind11 regex sentencepiece tqdm visualdl attrdict pyyaml -i https://mirror.baidu.com/pypi/simple
+    python -m pip install pybind11 regex sentencepiece tqdm visualdl attrdict pyyaml rouge -i https://mirror.baidu.com/pypi/simple

    python -m pip install -e ../
    # python -m pip install paddlenlp # fails to install inside the PDC image
