|
| 1 | +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import copy |
| 16 | +import json |
| 17 | +import os |
| 18 | +import sys |
| 19 | +from dataclasses import dataclass, field |
| 20 | +from functools import partial |
| 21 | + |
| 22 | +import paddle |
| 23 | +from modeling_pp import LlamaForCausalLMPipe |
| 24 | +from utils import LlamaTrainer, compute_metrics, compute_metrics_not_do_generation |
| 25 | + |
| 26 | +from paddlenlp.data import DataCollatorForSeq2Seq |
| 27 | +from paddlenlp.datasets import load_dataset |
| 28 | +from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM |
| 29 | +from paddlenlp.peft.prefix import llama_postprocess_past_key_value |
| 30 | +from paddlenlp.trainer import ( |
| 31 | + PdArgumentParser, |
| 32 | + TrainingArguments, |
| 33 | + get_last_checkpoint, |
| 34 | + set_seed, |
| 35 | +) |
| 36 | +from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer |
| 37 | +from paddlenlp.utils.log import logger |
| 38 | + |
| 39 | + |
| 40 | +@dataclass |
| 41 | +class DataArgument: |
| 42 | + data_name: str = field(default=None, metadata={"help": "The name of data."}) |
| 43 | + task_name_or_path: str = field(default=None, metadata={"help": "The name of task."}) |
| 44 | + src_length: int = field(default=512, metadata={"help": "The max length of source text."}) |
| 45 | + tgt_length: int = field(default=256, metadata={"help": "The max length of target text."}) |
| 46 | + |
| 47 | + |
| 48 | +@dataclass |
| 49 | +class ModelArgument: |
| 50 | + model_name_or_path: str = field( |
| 51 | + default="facebook/llama-7b", metadata={"help": "Build-in pretrained model name or the path to local model."} |
| 52 | + ) |
| 53 | + label_smoothing: float = field(default=0.1, metadata={"help": "The label smoothing parameter."}) |
| 54 | + lr_decay_ratio: float = field(default=0.1, metadata={"help": "The ratio for learning rate decrease"}) |
| 55 | + use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"}) |
| 56 | + eval_with_do_generation: bool = field( |
| 57 | + default=False, metadata={"help": "Evaluate with generation, instead for calc loss."} |
| 58 | + ) |
| 59 | + profiler_options: str = field( |
| 60 | + default=None, |
| 61 | + metadata={"help": "profiler_options."}, |
| 62 | + ) |
| 63 | + # lora |
| 64 | + lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"}) |
| 65 | + lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."}) |
| 66 | + lora_rank: int = field(default=4, metadata={"help": "Lora attention dimension"}) |
| 67 | + merge_weights: bool = field( |
| 68 | + default=False, metadata={"help": "Merge weights of the original model and the Lora model"} |
| 69 | + ) |
| 70 | + # prefix |
| 71 | + prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"}) |
| 72 | + num_prefix_tokens: int = field(default=10, metadata={"help": "Number of prefix tokens"}) |
| 73 | + prefix_projection: bool = field(default=False, metadata={"help": "Whether to project the prefix tokens"}) |
| 74 | + # qat |
| 75 | + qat: bool = field(default=False, metadata={"help": "Whether to use QAT technique"}) |
| 76 | + qat_type: str = field(default="A8W8", metadata={"help": "Quantization type. Supported values: A8W8, W4,A8W4"}) |
| 77 | + |
| 78 | + |
| 79 | +PROMPT_DICT = { |
| 80 | + "prompt_input": ( |
| 81 | + "Below is an instruction that describes a task, paired with an input that provides further context. " |
| 82 | + "Write a response that appropriately completes the request.\n\n" |
| 83 | + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" |
| 84 | + ), |
| 85 | + "prompt_no_input": ( |
| 86 | + "Below is an instruction that describes a task. " |
| 87 | + "Write a response that appropriately completes the request.\n\n" |
| 88 | + "### Instruction:\n{instruction}\n\n### Response:" |
| 89 | + ), |
| 90 | +} |
| 91 | + |
| 92 | + |
| 93 | +def read_local_dataset(path): |
| 94 | + with open(path, "r", encoding="utf-8") as f: |
| 95 | + for line in f: |
| 96 | + json_line = json.loads(line) |
| 97 | + yield json_line |
| 98 | + |
| 99 | + |
| 100 | +def custom_instruction_convert_example(example, tokenizer, data_args, is_test=False, model_max_length=512): |
| 101 | + """ |
| 102 | + Convert an example into necessary features. |
| 103 | + """ |
| 104 | + |
| 105 | + prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] |
| 106 | + |
| 107 | + if example.get("input", "") != "": |
| 108 | + input_seq = prompt_input.format_map(example) |
| 109 | + else: |
| 110 | + input_seq = prompt_no_input.format_map(example) |
| 111 | + |
| 112 | + output_seq = example["output"] + tokenizer.eos_token |
| 113 | + |
| 114 | + # To compatible with compile training mode in benchmark, input will be pad to fix length |
| 115 | + source_tokenized = tokenizer( |
| 116 | + input_seq, |
| 117 | + return_tensors="pd", |
| 118 | + max_length=model_max_length, |
| 119 | + truncation=True, |
| 120 | + ) |
| 121 | + |
| 122 | + source_input_ids_len = ( |
| 123 | + source_tokenized["input_ids"].not_equal(paddle.to_tensor(tokenizer.pad_token_id)).sum().item() |
| 124 | + ) |
| 125 | + |
| 126 | + example_tokenized = tokenizer( |
| 127 | + input_seq + output_seq, |
| 128 | + return_tensors="pd", |
| 129 | + max_length=model_max_length, |
| 130 | + truncation=True, |
| 131 | + ) |
| 132 | + |
| 133 | + input_ids = example_tokenized["input_ids"][0] |
| 134 | + labels = copy.deepcopy(input_ids) |
| 135 | + labels[:source_input_ids_len] = -100 |
| 136 | + |
| 137 | + if is_test: |
| 138 | + return dict( |
| 139 | + input_ids=source_tokenized["input_ids"][0], |
| 140 | + labels=labels, |
| 141 | + ) |
| 142 | + |
| 143 | + # shift labels |
| 144 | + input_ids, labels = input_ids[:-1], labels[1:] |
| 145 | + |
| 146 | + return dict( |
| 147 | + input_ids=input_ids, |
| 148 | + labels=labels, |
| 149 | + ) |
| 150 | + |
| 151 | + |
| 152 | +def main(): |
| 153 | + parser = PdArgumentParser((ModelArgument, DataArgument, TrainingArguments)) |
| 154 | + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): |
| 155 | + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) |
| 156 | + else: |
| 157 | + model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
| 158 | + |
| 159 | + data_args.always_pad_to_max_length = training_args.pipeline_parallel_degree > 1 |
| 160 | + |
| 161 | + training_args.print_config(model_args, "Model") |
| 162 | + training_args.print_config(data_args, "Data") |
| 163 | + training_args.tgt_length = data_args.tgt_length |
| 164 | + |
| 165 | + training_args.profiler_options = model_args.profiler_options |
| 166 | + setattr(training_args, "label_smoothing", model_args.label_smoothing) |
| 167 | + setattr(training_args, "lr_decay_ratio", model_args.lr_decay_ratio) |
| 168 | + |
| 169 | + paddle.set_device(training_args.device) |
| 170 | + |
| 171 | + set_seed(args=training_args) |
| 172 | + |
| 173 | + # Log on each process the small summary: |
| 174 | + logger.warning( |
| 175 | + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " |
| 176 | + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" |
| 177 | + ) |
| 178 | + |
| 179 | + # Detecting last checkpoint. |
| 180 | + last_checkpoint = None |
| 181 | + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: |
| 182 | + last_checkpoint = get_last_checkpoint(training_args.output_dir) |
| 183 | + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 1: |
| 184 | + raise ValueError( |
| 185 | + f"Output directory ({training_args.output_dir}) already exists and is not empty. " |
| 186 | + "Use --overwrite_output_dir to overcome." |
| 187 | + ) |
| 188 | + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: |
| 189 | + logger.info( |
| 190 | + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " |
| 191 | + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." |
| 192 | + ) |
| 193 | + |
| 194 | + # Set the dtype for loading model |
| 195 | + dtype = "float32" |
| 196 | + if training_args.fp16_opt_level == "O2": |
| 197 | + if training_args.fp16: |
| 198 | + dtype = "float16" |
| 199 | + if training_args.bf16: |
| 200 | + dtype = "bfloat16" |
| 201 | + |
| 202 | + model_class = AutoModelForCausalLM |
| 203 | + if training_args.pipeline_parallel_degree > 1: |
| 204 | + if model_args.eval_with_do_generation and training_args.do_eval: |
| 205 | + raise ValueError("Plese set eval_with_do_generation to false in pipeline parallel mode.") |
| 206 | + model_class = LlamaForCausalLMPipe |
| 207 | + |
| 208 | + # Load the pretrained language model. |
| 209 | + model = model_class.from_pretrained( |
| 210 | + model_args.model_name_or_path, |
| 211 | + tensor_parallel_output=False, |
| 212 | + tensor_parallel_degree=training_args.tensor_parallel_degree, |
| 213 | + tensor_parallel_rank=training_args.tensor_parallel_rank, |
| 214 | + use_flash_attention=model_args.use_flash_attention, |
| 215 | + dtype=dtype, # todo enable set dtype to avoid additional mem usage |
| 216 | + ) |
| 217 | + if model_args.lora: |
| 218 | + if model_args.lora_path is None: |
| 219 | + # Not yet support RowParallelLinear |
| 220 | + target_modules = [ |
| 221 | + ".*q_proj.*", |
| 222 | + ".*v_proj.*", |
| 223 | + ".*k_proj.*", |
| 224 | + ".*gate_proj.*", |
| 225 | + ".*up_proj.*", |
| 226 | + ".*o_proj.*", |
| 227 | + ".*down_proj.*", |
| 228 | + ] |
| 229 | + |
| 230 | + lora_config = LoRAConfig( |
| 231 | + target_modules=target_modules, |
| 232 | + r=model_args.lora_rank, |
| 233 | + lora_alpha=2 * model_args.lora_rank, |
| 234 | + merge_weights=model_args.merge_weights, |
| 235 | + tensor_parallel_degree=training_args.tensor_parallel_degree, |
| 236 | + dtype=dtype, |
| 237 | + ) |
| 238 | + model = LoRAModel(model, lora_config) |
| 239 | + else: |
| 240 | + model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path) |
| 241 | + |
| 242 | + model.mark_only_lora_as_trainable() |
| 243 | + model.print_trainable_parameters() |
| 244 | + |
| 245 | + if model_args.qat: |
| 246 | + from paddle import nn |
| 247 | + from paddle.quantization import QAT, QuantConfig |
| 248 | + |
| 249 | + # FakeQuanterChannelWiseAbsMaxObserver not yet merge in Paddle develop |
| 250 | + from paddle.quantization.quanters import FakeQuanterChannelWiseAbsMaxObserver |
| 251 | + from paddle.quantization.quanters.abs_max import ( |
| 252 | + FakeQuanterWithAbsMaxObserverLayer, |
| 253 | + ) |
| 254 | + from paddleslim.quant.quanters import PACTQuanter |
| 255 | + |
| 256 | + # from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver |
| 257 | + from paddlenlp.peft.lora import LoRALinear |
| 258 | + from paddlenlp.peft.lora.lora_quant_layers import QuantedLoRALinear |
| 259 | + |
| 260 | + q_config = QuantConfig(activation=None, weight=None) |
| 261 | + q_config.add_qat_layer_mapping(LoRALinear, QuantedLoRALinear) |
| 262 | + |
| 263 | + if model_args.qat_type == "A8W8": |
| 264 | + activation = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserverLayer, init_value=20, dtype=dtype) |
| 265 | + # activation = FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=8, dtype=dtype) |
| 266 | + weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype="float32") |
| 267 | + elif model_args.qat_type == "W4": |
| 268 | + activation = None |
| 269 | + weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype="float32") |
| 270 | + elif model_args.qat_type == "A8W4": |
| 271 | + activation = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserverLayer, init_value=20, dtype=dtype) |
| 272 | + # activation = FakeQuanterWithAbsMaxObserver(moving_rate=0.9, bit_length=8, dtype=dtype) |
| 273 | + weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype="float32") |
| 274 | + else: |
| 275 | + raise ValueError("qat_type should be one of ['A8W8', 'W4', 'A8W4']") |
| 276 | + |
| 277 | + q_config.add_type_config(LoRALinear, weight=weight, activation=activation) |
| 278 | + q_config.add_type_config(nn.Linear, weight=weight, activation=activation) |
| 279 | + |
| 280 | + qat = QAT(q_config) |
| 281 | + model = qat.quantize(model, inplace=True) |
| 282 | + |
| 283 | + if model_args.prefix_tuning: |
| 284 | + prefix_config = PrefixConfig( |
| 285 | + num_prefix_tokens=model_args.num_prefix_tokens, |
| 286 | + num_attention_heads=model.config.n_head, |
| 287 | + num_hidden_layers=model.config.n_layer, |
| 288 | + hidden_size=model.config.hidden_size, |
| 289 | + prefix_projection=model_args.prefix_projection, |
| 290 | + prefix_projection_hidden_size=model.config.hidden_size, |
| 291 | + dtype=dtype, |
| 292 | + ) |
| 293 | + model = PrefixModelForCausalLM( |
| 294 | + model=model, |
| 295 | + prefix_config=prefix_config, |
| 296 | + postprocess_past_key_value=llama_postprocess_past_key_value, |
| 297 | + ) |
| 298 | + model.mark_only_prefix_as_trainable() |
| 299 | + model.print_trainable_parameters() |
| 300 | + |
| 301 | + tokenizer = AutoTokenizer.from_pretrained( |
| 302 | + model_args.model_name_or_path, |
| 303 | + padding_side="left", # Allow batch inference |
| 304 | + ) |
| 305 | + tokenizer.pad_token = tokenizer.unk_token |
| 306 | + |
| 307 | + # Load the dataset. |
| 308 | + train_ds = load_dataset(read_local_dataset, path="./data/train.txt", lazy=False) |
| 309 | + training_args.do_eval = False |
| 310 | + data_args.always_pad_to_max_length = True |
| 311 | + trans_func = partial(custom_instruction_convert_example, tokenizer=tokenizer, data_args=data_args) |
| 312 | + |
| 313 | + train_ds = train_ds.map(partial(trans_func)) |
| 314 | + |
| 315 | + model_max_length = 512 |
| 316 | + collate_fn = DataCollatorForSeq2Seq( |
| 317 | + return_tensors="pd", |
| 318 | + tokenizer=tokenizer, |
| 319 | + max_length=model_max_length if data_args.always_pad_to_max_length else -1, |
| 320 | + padding="max_length" if data_args.always_pad_to_max_length else True, |
| 321 | + max_label_length=model_max_length if data_args.always_pad_to_max_length else None, |
| 322 | + return_attention_mask=True, |
| 323 | + ) |
| 324 | + |
| 325 | + def compute_metrics_trainer(eval_preds, tokenizer): |
| 326 | + all_preds = [] |
| 327 | + all_labels = [] |
| 328 | + preds = eval_preds.predictions |
| 329 | + preds = [x[x != -100] for x in preds] |
| 330 | + all_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=False)) |
| 331 | + labels = [x[x != -100] for x in eval_preds.label_ids] |
| 332 | + all_labels.extend(tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=False)) |
| 333 | + |
| 334 | + all_preds = [pred.strip() for pred in all_preds] |
| 335 | + all_labels = [label.strip() for label in all_labels] |
| 336 | + all_preds = [pred.strip("question:") for pred in all_preds] |
| 337 | + all_labels = [label.strip("question:") for label in all_labels] |
| 338 | + |
| 339 | + eval_result = compute_metrics(all_preds, all_labels) |
| 340 | + return eval_result |
| 341 | + |
| 342 | + compute_metrics_func = partial( |
| 343 | + compute_metrics_trainer, |
| 344 | + tokenizer=tokenizer, |
| 345 | + ) |
| 346 | + |
| 347 | + trainer = LlamaTrainer( |
| 348 | + model=model, |
| 349 | + args=training_args, |
| 350 | + train_dataset=train_ds if training_args.do_train else None, |
| 351 | + tokenizer=tokenizer, |
| 352 | + compute_metrics=compute_metrics_func |
| 353 | + if model_args.eval_with_do_generation |
| 354 | + else compute_metrics_not_do_generation, |
| 355 | + do_generation=model_args.eval_with_do_generation, |
| 356 | + data_collator=collate_fn, |
| 357 | + ) |
| 358 | + |
| 359 | + if training_args.do_train: |
| 360 | + train_result = trainer.train(resume_from_checkpoint=last_checkpoint) |
| 361 | + trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1) |
| 362 | + trainer.log_metrics("train", train_result.metrics) |
| 363 | + trainer.save_metrics("train", train_result.metrics) |
| 364 | + trainer.save_state() |
| 365 | + |
| 366 | + if training_args.do_eval: |
| 367 | + eval_result = trainer.evaluate() |
| 368 | + trainer.log_metrics("test", eval_result) |
| 369 | + |
| 370 | + |
| 371 | +if __name__ == "__main__": |
| 372 | + main() |
0 commit comments