
Commit 48db6d4

Move LLaMA tipc benchmark (#6689)
* changes
* add tokens and label shift
* styles
* move benchmark
1 parent dad7b3b commit 48db6d4

3 files changed: +378 additions, -10 deletions


llm/llama/benchmark.py

Lines changed: 372 additions & 0 deletions
@@ -0,0 +1,372 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
import sys
from dataclasses import dataclass, field
from functools import partial

import paddle
from modeling_pp import LlamaForCausalLMPipe
from utils import LlamaTrainer, compute_metrics, compute_metrics_not_do_generation

from paddlenlp.data import DataCollatorForSeq2Seq
from paddlenlp.datasets import load_dataset
from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
from paddlenlp.peft.prefix import llama_postprocess_past_key_value
from paddlenlp.trainer import (
    PdArgumentParser,
    TrainingArguments,
    get_last_checkpoint,
    set_seed,
)
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
from paddlenlp.utils.log import logger

@dataclass
class DataArgument:
    data_name: str = field(default=None, metadata={"help": "The name of the dataset."})
    task_name_or_path: str = field(default=None, metadata={"help": "The name of the task."})
    src_length: int = field(default=512, metadata={"help": "The max length of the source text."})
    tgt_length: int = field(default=256, metadata={"help": "The max length of the target text."})


@dataclass
class ModelArgument:
    model_name_or_path: str = field(
        default="facebook/llama-7b", metadata={"help": "Built-in pretrained model name or the path to a local model."}
    )
    label_smoothing: float = field(default=0.1, metadata={"help": "The label smoothing parameter."})
    lr_decay_ratio: float = field(default=0.1, metadata={"help": "The ratio for learning rate decay."})
    use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention."})
    eval_with_do_generation: bool = field(
        default=False, metadata={"help": "Evaluate with generation instead of calculating the loss."}
    )
    profiler_options: str = field(
        default=None,
        metadata={"help": "Profiler options."},
    )
    # LoRA
    lora: bool = field(default=False, metadata={"help": "Whether to use the LoRA technique."})
    lora_path: str = field(default=None, metadata={"help": "Path used to initialize the LoRA state dict."})
    lora_rank: int = field(default=4, metadata={"help": "LoRA attention dimension."})
    merge_weights: bool = field(
        default=False, metadata={"help": "Merge the weights of the original model and the LoRA model."}
    )
    # prefix tuning
    prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use the prefix-tuning technique."})
    num_prefix_tokens: int = field(default=10, metadata={"help": "Number of prefix tokens."})
    prefix_projection: bool = field(default=False, metadata={"help": "Whether to project the prefix tokens."})
    # QAT
    qat: bool = field(default=False, metadata={"help": "Whether to use quantization-aware training (QAT)."})
    qat_type: str = field(default="A8W8", metadata={"help": "Quantization type. Supported values: A8W8, W4, A8W4."})

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


def read_local_dataset(path):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            json_line = json.loads(line)
            yield json_line


def custom_instruction_convert_example(example, tokenizer, data_args, is_test=False, model_max_length=512):
    """
    Convert an example into the necessary features.
    """

    prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]

    if example.get("input", "") != "":
        input_seq = prompt_input.format_map(example)
    else:
        input_seq = prompt_no_input.format_map(example)

    output_seq = example["output"] + tokenizer.eos_token

    # To stay compatible with the compiled training mode used in the benchmark,
    # inputs are padded to a fixed length.
    source_tokenized = tokenizer(
        input_seq,
        return_tensors="pd",
        max_length=model_max_length,
        truncation=True,
    )

    source_input_ids_len = (
        source_tokenized["input_ids"].not_equal(paddle.to_tensor(tokenizer.pad_token_id)).sum().item()
    )

    example_tokenized = tokenizer(
        input_seq + output_seq,
        return_tensors="pd",
        max_length=model_max_length,
        truncation=True,
    )

    input_ids = example_tokenized["input_ids"][0]
    labels = copy.deepcopy(input_ids)
    labels[:source_input_ids_len] = -100

    if is_test:
        return dict(
            input_ids=source_tokenized["input_ids"][0],
            labels=labels,
        )

    # Shift labels by one position.
    input_ids, labels = input_ids[:-1], labels[1:]

    return dict(
        input_ids=input_ids,
        labels=labels,
    )

def main():
    parser = PdArgumentParser((ModelArgument, DataArgument, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    data_args.always_pad_to_max_length = training_args.pipeline_parallel_degree > 1

    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")
    training_args.tgt_length = data_args.tgt_length

    training_args.profiler_options = model_args.profiler_options
    setattr(training_args, "label_smoothing", model_args.label_smoothing)
    setattr(training_args, "lr_decay_ratio", model_args.lr_decay_ratio)

    paddle.set_device(training_args.device)

    set_seed(args=training_args)

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
    )

    # Detect the last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set the dtype for loading the model.
    dtype = "float32"
    if training_args.fp16_opt_level == "O2":
        if training_args.fp16:
            dtype = "float16"
        if training_args.bf16:
            dtype = "bfloat16"

    model_class = AutoModelForCausalLM
    if training_args.pipeline_parallel_degree > 1:
        if model_args.eval_with_do_generation and training_args.do_eval:
            raise ValueError("Please set eval_with_do_generation to false in pipeline parallel mode.")
        model_class = LlamaForCausalLMPipe

    # Load the pretrained language model.
    model = model_class.from_pretrained(
        model_args.model_name_or_path,
        tensor_parallel_output=False,
        tensor_parallel_degree=training_args.tensor_parallel_degree,
        tensor_parallel_rank=training_args.tensor_parallel_rank,
        use_flash_attention=model_args.use_flash_attention,
        dtype=dtype,  # TODO: allow setting dtype to avoid additional memory usage
    )
    if model_args.lora:
        if model_args.lora_path is None:
            # RowParallelLinear is not yet supported.
            target_modules = [
                ".*q_proj.*",
                ".*v_proj.*",
                ".*k_proj.*",
                ".*gate_proj.*",
                ".*up_proj.*",
                ".*o_proj.*",
                ".*down_proj.*",
            ]

            lora_config = LoRAConfig(
                target_modules=target_modules,
                r=model_args.lora_rank,
                lora_alpha=2 * model_args.lora_rank,
                merge_weights=model_args.merge_weights,
                tensor_parallel_degree=training_args.tensor_parallel_degree,
                dtype=dtype,
            )
            model = LoRAModel(model, lora_config)
        else:
            model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)

        model.mark_only_lora_as_trainable()
        model.print_trainable_parameters()

    if model_args.qat:
        from paddle import nn
        from paddle.quantization import QAT, QuantConfig

        # FakeQuanterChannelWiseAbsMaxObserver is not yet merged into Paddle develop.
        from paddle.quantization.quanters import FakeQuanterChannelWiseAbsMaxObserver
        from paddle.quantization.quanters.abs_max import (
            FakeQuanterWithAbsMaxObserverLayer,
        )
        from paddleslim.quant.quanters import PACTQuanter

        # from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
        from paddlenlp.peft.lora import LoRALinear
        from paddlenlp.peft.lora.lora_quant_layers import QuantedLoRALinear

        q_config = QuantConfig(activation=None, weight=None)
        q_config.add_qat_layer_mapping(LoRALinear, QuantedLoRALinear)

        if model_args.qat_type == "A8W8":
            activation = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserverLayer, init_value=20, dtype=dtype)
            # activation = FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=8, dtype=dtype)
            weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype="float32")
        elif model_args.qat_type == "W4":
            activation = None
            weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype="float32")
        elif model_args.qat_type == "A8W4":
            activation = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserverLayer, init_value=20, dtype=dtype)
            # activation = FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=8, dtype=dtype)
            weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype="float32")
        else:
            raise ValueError("qat_type should be one of ['A8W8', 'W4', 'A8W4']")

        q_config.add_type_config(LoRALinear, weight=weight, activation=activation)
        q_config.add_type_config(nn.Linear, weight=weight, activation=activation)

        qat = QAT(q_config)
        model = qat.quantize(model, inplace=True)

    if model_args.prefix_tuning:
        prefix_config = PrefixConfig(
            num_prefix_tokens=model_args.num_prefix_tokens,
            num_attention_heads=model.config.n_head,
            num_hidden_layers=model.config.n_layer,
            hidden_size=model.config.hidden_size,
            prefix_projection=model_args.prefix_projection,
            prefix_projection_hidden_size=model.config.hidden_size,
            dtype=dtype,
        )
        model = PrefixModelForCausalLM(
            model=model,
            prefix_config=prefix_config,
            postprocess_past_key_value=llama_postprocess_past_key_value,
        )
        model.mark_only_prefix_as_trainable()
        model.print_trainable_parameters()

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        padding_side="left",  # Allow batch inference
    )
    tokenizer.pad_token = tokenizer.unk_token

    # Load the dataset.
    train_ds = load_dataset(read_local_dataset, path="./data/train.txt", lazy=False)
    training_args.do_eval = False
    data_args.always_pad_to_max_length = True
    trans_func = partial(custom_instruction_convert_example, tokenizer=tokenizer, data_args=data_args)

    train_ds = train_ds.map(partial(trans_func))

    model_max_length = 512
    collate_fn = DataCollatorForSeq2Seq(
        return_tensors="pd",
        tokenizer=tokenizer,
        max_length=model_max_length if data_args.always_pad_to_max_length else -1,
        padding="max_length" if data_args.always_pad_to_max_length else True,
        max_label_length=model_max_length if data_args.always_pad_to_max_length else None,
        return_attention_mask=True,
    )

    def compute_metrics_trainer(eval_preds, tokenizer):
        all_preds = []
        all_labels = []
        preds = eval_preds.predictions
        preds = [x[x != -100] for x in preds]
        all_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=False))
        labels = [x[x != -100] for x in eval_preds.label_ids]
        all_labels.extend(tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False))

        all_preds = [pred.strip() for pred in all_preds]
        all_labels = [label.strip() for label in all_labels]
        all_preds = [pred.strip("question:") for pred in all_preds]
        all_labels = [label.strip("question:") for label in all_labels]

        eval_result = compute_metrics(all_preds, all_labels)
        return eval_result

    compute_metrics_func = partial(
        compute_metrics_trainer,
        tokenizer=tokenizer,
    )

    trainer = LlamaTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds if training_args.do_train else None,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_func
        if model_args.eval_with_do_generation
        else compute_metrics_not_do_generation,
        do_generation=model_args.eval_with_do_generation,
        data_collator=collate_fn,
    )

    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
        trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    if training_args.do_eval:
        eval_result = trainer.evaluate()
        trainer.log_metrics("test", eval_result)


if __name__ == "__main__":
    main()
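
Note: the commit description mentions "add tokens and label shift". In custom_instruction_convert_example above, the prompt portion of labels is masked with -100 so it is excluded from the loss, and input_ids/labels are then shifted by one position. Below is a minimal standalone sketch of just that step, using made-up token ids and a hypothetical build_features helper instead of a real tokenizer; it is illustrative only and not part of this commit.

# Sketch of the prompt-masking and label-shift logic (illustrative only).
IGNORE_INDEX = -100  # positions with this label are ignored by the loss

def build_features(source_ids, target_ids):
    # The model sees the prompt followed by the response.
    input_ids = source_ids + target_ids
    # Labels copy the inputs, with the prompt positions masked out.
    labels = list(input_ids)
    labels[: len(source_ids)] = [IGNORE_INDEX] * len(source_ids)
    # Shift by one position: the token at step t predicts the token at step t + 1.
    return input_ids[:-1], labels[1:]

inputs, labels = build_features([11, 12, 13], [21, 22, 2])
print(inputs)  # [11, 12, 13, 21, 22]
print(labels)  # [-100, -100, 21, 22, 2]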

llm/llama/utils.py

Lines changed: 5 additions & 9 deletions
@@ -113,15 +113,11 @@ def on_epoch_end(self, args, state, control, **kwargs):
 class LlamaTrainer(Trainer):
     def __init__(self, do_generation: bool, **kwargs):
         super().__init__(**kwargs)
-        if self.args.benchmark or self.args.profiler_options is not None:
-            self.add_callback(
-                BenchmarkCallback(benchmark=self.args.benchmark, profiler_options=self.args.profiler_options)
-            )
-        if self.args.benchmark:
-            if self.args.disable_tqdm:
-                self.pop_callback(PrinterCallback)
-            else:
-                self.pop_callback(ProgressCallback)
+        self.add_callback(BenchmarkCallback(benchmark=True, profiler_options=self.args.profiler_options))
+        if self.args.disable_tqdm:
+            self.pop_callback(PrinterCallback)
+        else:
+            self.pop_callback(ProgressCallback)
         self.do_generation = do_generation
 
     def prediction_step(
