From 77d3821baf168ce5240941cf8e426bb43d2443d8 Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Mon, 20 Mar 2023 18:23:15 +0800 Subject: [PATCH 01/13] electra_modelzoo --- model_zoo/bert/run_pretrain_trainer.py | 1 - model_zoo/electra/deploy/python/predict.py | 18 +- model_zoo/electra/export_model.py | 62 +-- model_zoo/electra/get_ft_model.py | 43 ++- model_zoo/electra/run_glue.py | 306 +++++---------- model_zoo/electra/run_pretrain.py | 415 +++++++-------------- scripts/regression/ci_case.sh | 5 +- tests/fixtures/model_zoo/electra.yaml | 68 ++++ tests/model_zoo/test_electra.py | 110 ++++++ 9 files changed, 473 insertions(+), 555 deletions(-) create mode 100644 tests/fixtures/model_zoo/electra.yaml create mode 100644 tests/model_zoo/test_electra.py diff --git a/model_zoo/bert/run_pretrain_trainer.py b/model_zoo/bert/run_pretrain_trainer.py index cce5d85ca68d..4d5e99603a5a 100644 --- a/model_zoo/bert/run_pretrain_trainer.py +++ b/model_zoo/bert/run_pretrain_trainer.py @@ -194,7 +194,6 @@ def __getitem__(self, index): # softmax_with_cross_entropy enforce last dim size equal 1 masked_lm_labels = np.expand_dims(masked_lm_labels, axis=-1) next_sentence_labels = np.expand_dims(next_sentence_labels, axis=-1) - return [input_ids, segment_ids, input_mask, masked_lm_positions, masked_lm_labels, next_sentence_labels] diff --git a/model_zoo/electra/deploy/python/predict.py b/model_zoo/electra/deploy/python/predict.py index 160f5a5b52a1..34a739ddd59c 100755 --- a/model_zoo/electra/deploy/python/predict.py +++ b/model_zoo/electra/deploy/python/predict.py @@ -13,12 +13,14 @@ # limitations under the License. import argparse -import time -import numpy as np import os +import time +import numpy as np from paddle import inference + from paddlenlp.transformers import ElectraTokenizer +from paddlenlp.utils.log import logger def parse_args(): @@ -111,7 +113,7 @@ def get_predicted_input(predicted_data, tokenizer, max_seq_length, batch_size): return sen_ids_batch, sen_words_batch -def predict(args, sentences=[], paths=[]): +def predict(): """ Args: sentences (list[str]): each string is a sentence. If sentences not paths @@ -119,7 +121,9 @@ def predict(args, sentences=[], paths=[]): Returns: res (list(numpy.ndarray)): The result of sentence, indicate whether each word is replaced, same shape with sentences. """ - + args = parse_args() + sentences = args.predict_sentences + paths = args.predict_file # initialize data if sentences != [] and isinstance(sentences, list) and (paths == [] or paths is None): predicted_data = sentences @@ -184,9 +188,7 @@ def predict(args, sentences=[], paths=[]): if __name__ == "__main__": - args = parse_args() - sentences = args.predict_sentences - paths = args.predict_file + # sentences = ["The quick brown fox see over the lazy dog.", "The quick brown fox jump over tree lazy dog."] # paths = ["../../debug/test.txt", "../../debug/test.txt.1"] - predict(args, sentences, paths) + predict() diff --git a/model_zoo/electra/export_model.py b/model_zoo/electra/export_model.py index d4ddb8b3bca4..4b570013a20f 100644 --- a/model_zoo/electra/export_model.py +++ b/model_zoo/electra/export_model.py @@ -12,21 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # from collections import namedtuple -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function -import os -import hashlib import argparse +import hashlib import json +import os import paddle -import paddle.nn as nn from paddle.static import InputSpec -from paddlenlp.transformers import ElectraForTotalPretraining, ElectraDiscriminator, ElectraGenerator, ElectraModel -from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer +from paddlenlp.transformers import ElectraForSequenceClassification def get_md5sum(file_path): @@ -40,12 +36,42 @@ def get_md5sum(file_path): return md5sum +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--input_model_dir", + required=True, + type=str, + default=None, + help="Directory for storing Electra pretraining model", + ) + parser.add_argument( + "--output_model_dir", + required=True, + default=None, + type=str, + help="Directory for output Electra inference model", + ) + parser.add_argument( + "--model_name", + default="electra-deploy", + required=True, + type=str, + help="prefix name of output model and parameters", + ) + args = parser.parse_args() + return args + + def main(): + args = parse_args() # check and load config - with open(os.path.join(args.input_model_dir, "model_config.json"), "r") as f: + with open(os.path.join(args.input_model_dir, "config.json"), "r") as f: config_dict = json.load(f) - num_classes = config_dict["num_classes"] - if num_classes is None or num_classes <= 0: + num_choices = config_dict["num_choices"] + if num_choices is None or num_choices <= 0: print("%s/model_config.json may not be right, please check" % args.input_model_dir) exit(1) @@ -53,7 +79,6 @@ def main(): input_model_file = os.path.join(args.input_model_dir, "model_state.pdparams") print("load model to get static model : %s \nmodel md5sum : %s" % (input_model_file, get_md5sum(input_model_file))) model_state_dict = paddle.load(input_model_file) - if all((s.startswith("generator") or s.startswith("discriminator")) for s in model_state_dict.keys()): print("the model : %s is electra pretrain model, we need fine-tuning model to deploy" % input_model_file) exit(1) @@ -62,7 +87,7 @@ def main(): exit(1) elif "classifier.dense.weight" in model_state_dict: print("we are load glue fine-tuning model") - model = ElectraForSequenceClassification.from_pretrained(args.input_model_dir, num_classes=num_classes) + model = ElectraForSequenceClassification.from_pretrained(args.input_model_dir, num_classes=num_choices) print("total model layers : ", len(model_state_dict)) else: print("the model file : %s may not be fine-tuning model, please check" % input_model_file) @@ -78,15 +103,4 @@ def main(): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--input_model_dir", required=True, default=None, help="Directory for storing Electra pretraining model" - ) - parser.add_argument( - "--output_model_dir", required=True, default=None, help="Directory for output Electra inference model" - ) - parser.add_argument( - "--model_name", default="electra-deploy", type=str, help="prefix name of output model and parameters" - ) - args, unparsed = parser.parse_known_args() main() diff --git a/model_zoo/electra/get_ft_model.py b/model_zoo/electra/get_ft_model.py index 7597b147e533..055995fce33b 100644 --- a/model_zoo/electra/get_ft_model.py +++ b/model_zoo/electra/get_ft_model.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # from collections import namedtuple -import os -import hashlib import argparse +import hashlib +import os import paddle -import paddle.nn as nn # from paddlenlp.transformers import ElectraForTotalPretraining, ElectraDiscriminator, ElectraGenerator, ElectraModel # from paddlenlp.transformers import ElectraTokenizer @@ -36,7 +35,27 @@ def get_md5sum(file_path): return md5sum -def main(args): +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_dir", required=True, default=None, help="Directory of storing ElectraForTotalPreTraining model" + ) + parser.add_argument( + "--generator_output_file", default="generator_for_ft.pdparams", help="Electra generator model for fine-tuning" + ) + parser.add_argument( + "--discriminator_output_file", + default="discriminator_for_ft.pdparams", + help="Electra discriminator model for fine-tuning", + ) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() pretraining_model = os.path.join(args.model_dir, "model_state.pdparams") if os.path.islink(pretraining_model): print("%s already contain fine-tuning model, pleace check" % args.model_dir) @@ -49,7 +68,6 @@ def main(args): total_pretraining_model = paddle.load(pretraining_model) generator_state_dict = {} discriminator_state_dict = {} - total_keys = [] num_keys = 0 for key in total_pretraining_model.keys(): new_key = None @@ -73,17 +91,4 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--model_dir", required=True, default=None, help="Directory of storing ElectraForTotalPreTraining model" - ) - parser.add_argument( - "--generator_output_file", default="generator_for_ft.pdparams", help="Electra generator model for fine-tuning" - ) - parser.add_argument( - "--discriminator_output_file", - default="discriminator_for_ft.pdparams", - help="Electra discriminator model for fine-tuning", - ) - args, unparsed = parser.parse_known_args() - main(args) + main() diff --git a/model_zoo/electra/run_glue.py b/model_zoo/electra/run_glue.py index c5a8051a4a02..9a7e1abdade9 100644 --- a/model_zoo/electra/run_glue.py +++ b/model_zoo/electra/run_glue.py @@ -12,11 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse import logging -import os -import random -import time +from dataclasses import dataclass, field from functools import partial import numpy as np @@ -24,9 +21,10 @@ from paddle.io import DataLoader from paddle.metric import Accuracy -from paddlenlp.data import Pad, Stack, Tuple +from paddlenlp.data import DataCollatorWithPadding from paddlenlp.datasets import load_dataset from paddlenlp.metrics import AccuracyAndF1, Mcc, PearsonAndSpearman +from paddlenlp.trainer import PdArgumentParser, Trainer, TrainingArguments from paddlenlp.transformers import ( BertForSequenceClassification, BertTokenizer, @@ -59,100 +57,27 @@ } -def parse_args(): - parser = argparse.ArgumentParser() - - # Required parameters - parser.add_argument( - "--task_name", - default=None, - type=str, - required=True, - help="The name of the task to train selected in the list: " + ", ".join(METRIC_CLASSES.keys()), - ) - parser.add_argument( - "--model_type", +@dataclass +class ModelArguments: + task_name: str = field( default=None, - type=str, - required=True, - help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + metadata={"help": "The namve of the task to train selected in the list: " + ", ".join(METRIC_CLASSES.keys())}, ) - parser.add_argument( - "--model_name_or_path", + model_type: str = field( default=None, - type=str, - required=True, - help="Path to pre-trained model or shortcut name selected in the list: " - + ", ".join( - sum([list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values()], []) - ), + metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}, ) - parser.add_argument( - "--output_dir", + model_name_or_path: str = field( default=None, - type=str, - required=True, - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument( - "--max_seq_length", - default=128, - type=int, - help="The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded.", - ) - parser.add_argument("--learning_rate", default=1e-4, type=float, help="The initial learning rate for Adam.") - parser.add_argument( - "--num_train_epochs", - default=3, - type=int, - help="Total number of training epochs to perform.", - ) - parser.add_argument("--logging_steps", type=int, default=100, help="Log every X updates steps.") - parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every X updates steps.") - parser.add_argument( - "--batch_size", - default=32, - type=int, - help="Batch size per GPU/CPU for training.", - ) - parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") - parser.add_argument( - "--warmup_steps", - default=0, - type=int, - help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion", + metadata={ + "help": "Path to pre-trained model or shortcut name selected in the list: " + + ", ".join(list(ElectraForSequenceClassification.pretrained_init_configuration.keys())) + }, ) - parser.add_argument( - "--warmup_proportion", default=0.1, type=float, help="Linear warmup proportion over total steps." + max_seq_length: int = field( + default=128, metadata={"help": "The maximum total input sequence length after tokenization"} ) - parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.") - parser.add_argument( - "--max_steps", - default=-1, - type=int, - help="If > 0: set total number of training steps to perform. Override num_train_epochs.", - ) - parser.add_argument("--seed", default=42, type=int, help="random seed for initialization") - parser.add_argument( - "--device", - default="gpu", - type=str, - choices=["cpu", "gpu", "npu"], - help="The device to select to train the model, is must be cpu/gpu/npu.", - ) - args = parser.parse_args() - return args - - -def set_seed(args): - # Use the same data seed(for data shuffle) for all procs to guarantee data - # consistency after sharding. - random.seed(args.seed) - np.random.seed(args.seed) - # Maybe different op seeds(for dropout) for different procs is better. By: - # `paddle.seed(args.seed + paddle.distributed.get_rank())` - paddle.seed(args.seed) + warmup_proportion: float = field(default=0.1, metadata={"help": "Linear warmup proportion over total steps."}) @paddle.no_grad() @@ -202,86 +127,60 @@ def convert_example(example, tokenizer, label_list, max_seq_length=512, is_test= label = np.array([label], dtype=label_dtype) # Convert raw text to feature if (int(is_test) + len(example)) == 2: - example = tokenizer(example["sentence"], max_seq_len=max_seq_length) + example = tokenizer(example["sentence"], padding="max_length", max_seq_len=max_seq_length) else: - example = tokenizer(example["sentence1"], text_pair=example["sentence2"], max_seq_len=max_seq_length) + example = tokenizer( + example["sentence1"], text_pair=example["sentence2"], max_seq_len=max_seq_length, padding=True + ) if not is_test: - return example["input_ids"], example["token_type_ids"], label - else: - return example["input_ids"], example["token_type_ids"] + example["labels"] = label + return example -def do_train(args): - paddle.set_device(args.device) - if paddle.distributed.get_world_size() > 1: - paddle.distributed.init_parallel_env() - set_seed(args) +def do_train(): - args.task_name = args.task_name.lower() - metric_class = METRIC_CLASSES[args.task_name] - args.model_type = args.model_type.lower() - model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + training_args, model_args = PdArgumentParser([TrainingArguments, ModelArguments]).parse_args_into_dataclasses() + training_args: TrainingArguments = training_args + model_args: ModelArguments = model_args - train_ds = load_dataset("glue", args.task_name, splits="train") - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + model_args.task_name = model_args.task_name.lower() + + train_ds = load_dataset("glue", model_args.task_name, splits="train") + tokenizer = ElectraTokenizer.from_pretrained(model_args.model_name_or_path) trans_func = partial( - convert_example, tokenizer=tokenizer, label_list=train_ds.label_list, max_seq_length=args.max_seq_length + convert_example, tokenizer=tokenizer, label_list=train_ds.label_list, max_seq_length=model_args.max_seq_length ) train_ds = train_ds.map(trans_func, lazy=True) - train_batch_sampler = paddle.io.DistributedBatchSampler(train_ds, batch_size=args.batch_size, shuffle=True) - batchify_fn = lambda samples, fn=Tuple( - Pad(axis=0, pad_val=tokenizer.pad_token_id), # input - Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment - Stack(dtype="int64" if train_ds.label_list else "float32"), # label - ): fn(samples) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_ds, batch_size=training_args.train_batch_size, shuffle=True + ) train_data_loader = DataLoader( - dataset=train_ds, batch_sampler=train_batch_sampler, collate_fn=batchify_fn, num_workers=0, return_list=True + dataset=train_ds, batch_sampler=train_batch_sampler, num_workers=0, return_list=True ) - if args.task_name == "mnli": - dev_ds_matched, dev_ds_mismatched = load_dataset( - "glue", args.task_name, splits=["dev_matched", "dev_mismatched"] - ) - - dev_ds_matched = dev_ds_matched.map(trans_func, lazy=True) - dev_ds_mismatched = dev_ds_mismatched.map(trans_func, lazy=True) - dev_batch_sampler_matched = paddle.io.BatchSampler(dev_ds_matched, batch_size=args.batch_size, shuffle=False) - dev_data_loader_matched = DataLoader( - dataset=dev_ds_matched, - batch_sampler=dev_batch_sampler_matched, - collate_fn=batchify_fn, - num_workers=0, - return_list=True, - ) - dev_batch_sampler_mismatched = paddle.io.BatchSampler( - dev_ds_mismatched, batch_size=args.batch_size, shuffle=False - ) - dev_data_loader_mismatched = DataLoader( - dataset=dev_ds_mismatched, - batch_sampler=dev_batch_sampler_mismatched, - collate_fn=batchify_fn, - num_workers=0, - return_list=True, - ) + if model_args.task_name == "mnli": + dev_ds = load_dataset("glue", model_args.task_name, splits=["dev_matched"]) else: - dev_ds = load_dataset("glue", args.task_name, splits="dev") - dev_ds = dev_ds.map(trans_func, lazy=True) - dev_batch_sampler = paddle.io.BatchSampler(dev_ds, batch_size=args.batch_size, shuffle=False) - dev_data_loader = DataLoader( - dataset=dev_ds, batch_sampler=dev_batch_sampler, collate_fn=batchify_fn, num_workers=0, return_list=True - ) + dev_ds = load_dataset("glue", model_args.task_name, splits="dev") + + dev_ds = dev_ds.map(trans_func, lazy=True) num_classes = 1 if train_ds.label_list is None else len(train_ds.label_list) - model = model_class.from_pretrained(args.model_name_or_path, num_classes=num_classes) + model = ElectraForSequenceClassification.from_pretrained(model_args.model_name_or_path, num_labels=num_classes) + if paddle.distributed.get_world_size() > 1: model = paddle.DataParallel(model) - num_training_steps = args.max_steps if args.max_steps > 0 else (len(train_data_loader) * args.num_train_epochs) - warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion + num_training_steps = ( + training_args.max_steps + if training_args.max_steps > 0 + else (len(train_data_loader) * training_args.num_train_epochs) + ) + warmup = training_args.warmup_steps if training_args.warmup_steps > 0 else training_args.warmup_ratio - lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, warmup) + lr_scheduler = LinearDecayWithWarmup(training_args.learning_rate, num_training_steps, warmup) # Generate parameter names needed to perform weight decay. # All bias and LayerNorm parameters are excluded. @@ -290,80 +189,49 @@ def do_train(args): learning_rate=lr_scheduler, beta1=0.9, beta2=0.999, - epsilon=args.adam_epsilon, + epsilon=training_args.adam_epsilon, parameters=model.parameters(), - weight_decay=args.weight_decay, + weight_decay=training_args.weight_decay, apply_decay_param_fun=lambda x: x in decay_params, ) loss_fct = paddle.nn.loss.CrossEntropyLoss() if train_ds.label_list else paddle.nn.loss.MSELoss() - metric = metric_class() - - global_step = 0 - tic_train = time.time() - for epoch in range(args.num_train_epochs): - for step, batch in enumerate(train_data_loader): - global_step += 1 - - input_ids, segment_ids, labels = batch - logits = model(input_ids, segment_ids) - loss = loss_fct(logits, labels) - loss.backward() - optimizer.step() - lr_scheduler.step() - optimizer.clear_grad() - if global_step % args.logging_steps == 0 or global_step == num_training_steps: - print( - "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s" - % ( - global_step, - num_training_steps, - epoch, - step, - paddle.distributed.get_rank(), - loss, - optimizer.get_lr(), - args.logging_steps / (time.time() - tic_train), - ) - ) - tic_train = time.time() - if global_step % args.save_steps == 0 or global_step == num_training_steps: - tic_eval = time.time() - if args.task_name == "mnli": - evaluate(model, loss_fct, metric, dev_data_loader_matched) - evaluate(model, loss_fct, metric, dev_data_loader_mismatched) - print("eval done total : %s s" % (time.time() - tic_eval)) - else: - evaluate(model, loss_fct, metric, dev_data_loader) - print("eval done total : %s s" % (time.time() - tic_eval)) - if paddle.distributed.get_rank() == 0: - output_dir = os.path.join( - args.output_dir, "%s_ft_model_%d.pdparams" % (args.task_name, global_step) - ) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - # Need better way to get inner model of DataParallel - model_to_save = model._layers if isinstance(model, paddle.DataParallel) else model - model_to_save.save_pretrained(output_dir) - tokenizer.save_pretrained(output_dir) - if global_step >= num_training_steps: - return - - -def print_arguments(args): - """print arguments""" - print("----------- Configuration Arguments -----------") - for arg, value in sorted(vars(args).items()): - print("%s: %s" % (arg, value)) - print("------------------------------------------------") + def compute_metrics(p): + # Define the metrics of tasks. + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + + preds = paddle.to_tensor(preds) + label = paddle.to_tensor(p.label_ids) + + metric = Accuracy() + metric.reset() + result = metric.compute(preds, label) + metric.update(result) + accu = metric.accumulate() + metric.reset() + return {"accuracy": accu} + + # TODO: use amp + trainer = Trainer( + model=model, + args=training_args, + data_collator=DataCollatorWithPadding(tokenizer=tokenizer, padding=True, max_length=model_args.max_seq_length), + criterion=loss_fct, + train_dataset=train_ds, + eval_dataset=dev_ds, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + optimizers=[optimizer, lr_scheduler], + ) + + if training_args.do_train: + train_result = trainer.train() + metrics = train_result.metrics + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_state() if __name__ == "__main__": - args = parse_args() - print_arguments(args) - n_gpu = len(os.getenv("CUDA_VISIBLE_DEVICES", "").split(",")) - if args.device in "gpu" and n_gpu > 1: - paddle.distributed.spawn(do_train, args=(args,), nprocs=n_gpu) - else: - do_train(args) + do_train() diff --git a/model_zoo/electra/run_pretrain.py b/model_zoo/electra/run_pretrain.py index 02e12eaf88f9..cde45ec4c4e6 100644 --- a/model_zoo/electra/run_pretrain.py +++ b/model_zoo/electra/run_pretrain.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse -import copy import io -import json import logging import os import random import time +from dataclasses import dataclass, field import numpy as np import paddle +from paddlenlp.trainer import PdArgumentParser, Trainer, TrainingArguments from paddlenlp.transformers import ( ElectraForTotalPretraining, ElectraPretrainingCriterion, @@ -40,99 +39,57 @@ } -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--model_type", - default="electra", - type=str, - help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), +@dataclass +class TrainingArguments(TrainingArguments): + + # per_device_train_batch_size + @property + def micro_batch_size(self): + return self.per_device_train_batch_size + + @property + def eval_freq(self): + return self.eval_steps + + +@dataclass +class ModelArguments: + model_type: str = field( + default="electra", metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())} ) - parser.add_argument( - "--model_name_or_path", + model_name_or_path: str = field( default="electra-small", - type=str, - help="Path to pre-trained model or shortcut name selected in the list: " - + ", ".join( - sum([list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values()], []) - ), - ) - parser.add_argument( - "--input_dir", - default=None, - type=str, - required=True, - help="The input directory where the data will be read from.", - ) - parser.add_argument( - "--output_dir", - default=None, - type=str, - required=True, - help="The output directory where the model predictions and checkpoints will be written.", - ) - parser.add_argument("--max_seq_length", default=128, type=int, help="max length of each sequence") - parser.add_argument("--mask_prob", default=0.15, type=float, help="the probability of one word to be mask") - parser.add_argument( - "--train_batch_size", - default=96, - type=int, - help="Batch size per GPU/CPU for training.", - ) - parser.add_argument( - "--eval_batch_size", - default=96, - type=int, - help="Batch size per GPU/CPU for training.", - ) - parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.") - parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.") - parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for Adam optimizer.") - parser.add_argument( - "--num_train_epochs", - default=4, - type=int, - help="Total number of training epochs to perform.", - ) - parser.add_argument( - "--max_steps", - default=-1, - type=int, - help="If > 0: set total number of training steps to perform. Override num_train_epochs.", - ) - parser.add_argument("--warmup_steps", default=10000, type=int, help="Linear warmup over warmup_steps.") - - parser.add_argument("--logging_steps", type=int, default=100, help="Log every X updates steps.") - parser.add_argument("--save_steps", type=int, default=1000, help="Save checkpoint every X updates steps.") - parser.add_argument( - "--init_from_ckpt", - action="store_true", - help="Whether to load model checkpoint. if True, args.model_name_or_path must be dir store ckpt or will train from fresh start", + metadata={ + "help": "Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([list(classes[-1].pretrained_init_configuration.keys()) for classes in MODEL_CLASSES.values()], []) + ) + }, ) - parser.add_argument( - "--use_amp", action="store_true", help="Whether to use float16(Automatic Mixed Precision) to train." + max_seq_length: int = field(default=128, metadata={"help": "max length of each sequence"}) + mask_prob: float = field(default=0.15, metadata={"help": "the probability of one word to be mask"}) + eager_run: bool = field(default=True, metadata={"help": "Use dygraph mode."}) + init_from_ckpt: bool = field( + default=True, + metadata={ + "help": "Whether to load model checkpoint. if True, args.model_name_or_path must be dir store ckpt or will train from fresh start" + }, ) - parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") - parser.add_argument("--eager_run", type=bool, default=True, help="Use dygraph mode.") - parser.add_argument( - "--device", - default="gpu", - type=str, - choices=["cpu", "gpu"], - help="The device to select to train the model, is must be cpu/gpu.", + max_predictions_per_seq: int = field( + default=20, metadata={"help": "The maximum total input sequence length after tokenization"} ) - args = parser.parse_args() - return args -def set_seed(args): - # Use the same data seed(for data shuffle) for all procs to guarantee data - # consistency after sharding. - random.seed(args.seed) - np.random.seed(args.seed) - # Maybe different op seeds(for dropout) for different procs is better. By: - # `paddle.seed(args.seed + paddle.distributed.get_rank())` - paddle.seed(args.seed) +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + Using `PdArgumentParser` we can turn this class into argparse arguments to be able to + specify them on the command line. + """ + + input_dir: str = field(default=None, metadata={"help": "The input directory where the data will be read from."}) + split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."}) class WorkerInitObj(object): @@ -227,7 +184,11 @@ def __init__(self, tokenizer, max_seq_length, mlm=True, mlm_probability=0.15): def __call__(self, examples): if self.mlm: inputs, raw_inputs, labels = self.mask_tokens(examples) - return inputs, raw_inputs, labels + return { + "input_ids": inputs, + "raw_input_ids": raw_inputs, + "generator_labels": labels, + } else: raw_inputs, _ = self.add_special_tokens_and_set_maskprob(examples, True, self.max_seq_length) raw_inputs = self.tensorize_batch(raw_inputs, "int64") @@ -236,7 +197,11 @@ def __call__(self, examples): if self.tokenizer.pad_token is not None: pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) labels[labels == pad_token_id] = -100 - return batch, raw_inputs, labels # noqa:821 + return { + "raw_input_ids": raw_inputs, + "generator_labels": labels, + } + # return batch, raw_inputs, labels # noqa:821 def tensorize_batch(self, examples, dtype): if isinstance(examples[0], (list, tuple)): @@ -303,6 +268,43 @@ def mask_tokens(self, examples): return inputs, raw_inputs, labels +class new_Trainer(Trainer): + def __init__( + self, + model=None, + criterion=None, + args=None, + data_collator=None, + train_dataset=None, + eval_dataset=None, + tokenizer=None, + compute_metrics=None, + callbacks=None, + optimizers=(None, None), + preprocess_logits_for_metrics=None, + ): + super(new_Trainer, self).__init__( + model, + criterion, + args, + data_collator, + train_dataset, + eval_dataset, + tokenizer, + compute_metrics, + callbacks, + optimizers, + preprocess_logits_for_metrics, + ) + + def compute_loss(self, model, inputs, return_outputs=False): + + gen_logits, disc_logits, disc_labels, attention_mask = model(**inputs) + gen_labels = inputs["generator_labels"] + loss = self.criterion(gen_logits, disc_logits, gen_labels, disc_labels, attention_mask) + return loss + + def create_dataloader(dataset, mode="train", batch_size=1, use_gpu=True, data_collator=None): """ Creats dataloader. @@ -336,51 +338,27 @@ def create_dataloader(dataset, mode="train", batch_size=1, use_gpu=True, data_co return dataloader -def do_train(args): - paddle.enable_static() if not args.eager_run else None - paddle.set_device(args.device) +def do_train(): + data_args, training_args, model_args = PdArgumentParser( + [DataArguments, TrainingArguments, ModelArguments] + ).parse_args_into_dataclasses() + training_args: TrainingArguments = training_args + model_args: ModelArguments = model_args + data_args: DataArguments = data_args + + paddle.enable_static() if not model_args.eager_run else None + paddle.set_device(training_args.device) if paddle.distributed.get_world_size() > 1: paddle.distributed.init_parallel_env() - set_seed(args) - # worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank()) + model_args.model_type = model_args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type] - args.model_type = args.model_type.lower() - model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + config = model_class.config_class.from_pretrained(model_args.model_name_or_path) - # Loads or initializes a model. - pretrained_models = list(tokenizer_class.pretrained_init_configuration.keys()) - config = model_class.config_class.from_pretrained(args.model_name_or_path) - - if args.model_name_or_path in pretrained_models: - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - model = model_class(config) - args.init_from_ckpt = False - else: - if os.path.isdir(args.model_name_or_path) and args.init_from_ckpt: - # Load checkpoint - tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - with open(os.path.join(args.model_name_or_path, "run_states.json"), "r") as f: - config_dict = json.load(f) - model_name = config_dict["model_name"] - if model_name in pretrained_models: - model = model_class.from_pretrained(args.model_name_or_path) - model.set_state_dict(paddle.load(os.path.join(args.model_name_or_path, "model_state.pdparams"))) - else: - raise ValueError( - "initialize a model from ckpt need model_name " - "in model_config_file. The supported model_name " - "are as follows: {}".format(tokenizer_class.pretrained_init_configuration.keys()) - ) - else: - raise ValueError( - "initialize a model need identifier or the " - "directory of storing model. if use identifier, the supported model " - "identifiers are as follows: {}, if use directory, " - "make sure set init_from_ckpt as True".format(model_class.pretrained_init_configuration.keys()) - ) + tokenizer = tokenizer_class.from_pretrained(model_args.model_name_or_path) - criterion = ElectraPretrainingCriterion(config) + model = model_class(config) if paddle.distributed.get_world_size() > 1: model = paddle.DataParallel(model) @@ -388,169 +366,40 @@ def do_train(args): tic_load_data = time.time() print("start load data : %s" % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))) train_dataset = BookCorpus( - data_path=args.input_dir, tokenizer=tokenizer, max_seq_length=args.max_seq_length, mode="train" + data_path=data_args.input_dir, tokenizer=tokenizer, max_seq_length=model_args.max_seq_length, mode="train" ) print("load data done, total : %s s" % (time.time() - tic_load_data)) # Reads data and generates mini-batches. data_collator = DataCollatorForElectra( - tokenizer=tokenizer, max_seq_length=args.max_seq_length, mlm=True, mlm_probability=args.mask_prob + tokenizer=tokenizer, max_seq_length=model_args.max_seq_length, mlm=True, mlm_probability=model_args.mask_prob + ) + criterion = ElectraPretrainingCriterion(config) + + lr_scheduler = LinearDecayWithWarmup( + training_args.learning_rate, training_args.max_steps, training_args.warmup_steps ) - train_data_loader = create_dataloader( - train_dataset, - batch_size=args.train_batch_size, - mode="train", - use_gpu=True if args.device in "gpu" else False, + trainer = new_Trainer( + model=model, + args=training_args, data_collator=data_collator, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=None, + tokenizer=tokenizer, + criterion=criterion, + optimizers=(None, lr_scheduler), ) - num_training_steps = args.max_steps if args.max_steps > 0 else (len(train_data_loader) * args.num_train_epochs) - - lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, args.warmup_steps) - - clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) - # Generate parameter names needed to perform weight decay. - # All bias and LayerNorm parameters are excluded. - decay_params = [p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])] - optimizer = paddle.optimizer.AdamW( - learning_rate=lr_scheduler, - epsilon=args.adam_epsilon, - parameters=model.parameters(), - weight_decay=args.weight_decay, - grad_clip=clip, - apply_decay_param_fun=lambda x: x in decay_params, - ) - if args.use_amp: - scaler = paddle.amp.GradScaler(init_loss_scaling=1024) - - print("start train : %s" % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))) - trained_global_step = global_step = 0 - t_loss = paddle.to_tensor([0.0]) - log_loss = paddle.to_tensor([0.0]) - loss_list = [] - log_list = [] - tic_train = time.time() - if os.path.isdir(args.model_name_or_path) and args.init_from_ckpt: - optimizer.set_state_dict(paddle.load(os.path.join(args.model_name_or_path, "model_state.pdopt"))) - trained_global_step = global_step = config_dict["global_step"] - if trained_global_step < num_training_steps: - print( - "[ start train from checkpoint ] we have already trained %s steps, seeking next step : %s" - % (trained_global_step, trained_global_step + 1) - ) - else: - print( - "[ start train from checkpoint ] we have already trained %s steps, but total training steps is %s, please check configuration !" - % (trained_global_step, num_training_steps) - ) - exit(0) - - for epoch in range(args.num_train_epochs): - for step, batch in enumerate(train_data_loader): - if trained_global_step > 0: - trained_global_step -= 1 - continue - global_step += 1 - input_ids, raw_input_ids, gen_labels = batch - if args.use_amp: - with paddle.amp.auto_cast(): - gen_logits, disc_logits, disc_labels, attention_mask = model( - input_ids=input_ids, raw_input_ids=raw_input_ids, generator_labels=gen_labels - ) - loss = criterion(gen_logits, disc_logits, gen_labels, disc_labels, attention_mask) - scaled = scaler.scale(loss) - scaled.backward() - t_loss += loss.detach() - scaler.minimize(optimizer, scaled) - else: - gen_logits, disc_logits, disc_labels, attention_mask = model( - input_ids=input_ids, raw_input_ids=raw_input_ids, generator_labels=gen_labels - ) - loss = criterion(gen_logits, disc_logits, gen_labels, disc_labels, attention_mask) - loss.backward() - t_loss += loss.detach() - optimizer.step() - lr_scheduler.step() - optimizer.clear_grad() - if global_step % args.logging_steps == 0: - local_loss = (t_loss - log_loss) / args.logging_steps - if paddle.distributed.get_world_size() > 1: - paddle.distributed.all_gather(loss_list, local_loss) - if paddle.distributed.get_rank() == 0: - log_str = ( - "global step {0:d}/{1:d}, epoch: {2:d}, batch: {3:d}, " - "avg_loss: {4:.15f}, lr: {5:.10f}, speed: {6:.2f} s/it" - ).format( - global_step, - num_training_steps, - epoch, - step, - float((paddle.stack(loss_list).sum() / len(loss_list)).numpy()), - optimizer.get_lr(), - (time.time() - tic_train) / args.logging_steps, - ) - print(log_str) - log_list.append(log_str) - loss_list = [] - else: - log_str = ( - "global step {0:d}/{1:d}, epoch: {2:d}, batch: {3:d}, " - "loss: {4:.15f}, lr: {5:.10f}, speed: {6:.2f} s/it" - ).format( - global_step, - num_training_steps, - epoch, - step, - float(local_loss.numpy()), - optimizer.get_lr(), - (time.time() - tic_train) / args.logging_steps, - ) - print(log_str) - log_list.append(log_str) - log_loss = t_loss - tic_train = time.time() - if global_step % args.save_steps == 0: - if paddle.distributed.get_rank() == 0: - output_dir = os.path.join(args.output_dir, "model_%d.pdparams" % global_step) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - model_to_save = model._layers if isinstance(model, paddle.DataParallel) else model - config_to_save = copy.deepcopy(model_to_save.discriminator.electra.config) - config_to_save.to_json_file(os.path.join(output_dir, "model_config.json")) - run_states = { - "model_name": model_name if args.init_from_ckpt else args.model_name_or_path, - "global_step": global_step, - "epoch": epoch, - "step": step, - } - with open(os.path.join(output_dir, "run_states.json"), "w") as f: - json.dump(run_states, f) - paddle.save(model.state_dict(), os.path.join(output_dir, "model_state.pdparams")) - tokenizer.save_pretrained(output_dir) - paddle.save(optimizer.state_dict(), os.path.join(output_dir, "model_state.pdopt")) - if len(log_list) > 0: - with open(os.path.join(output_dir, "train.log"), "w") as f: - for log in log_list: - if len(log.strip()) > 0: - f.write(log.strip() + "\n") - if global_step >= num_training_steps: - return - - -def print_arguments(args): - """print arguments""" - print("----------- Configuration Arguments -----------") - for arg, value in sorted(vars(args).items()): - print("%s: %s" % (arg, value)) - print("------------------------------------------------") + # training + if training_args.do_train: + train_result = trainer.train() + metrics = train_result.metrics + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() if __name__ == "__main__": - args = parse_args() - print_arguments(args) - n_gpu = len(os.getenv("CUDA_VISIBLE_DEVICES", "").split(",")) - if args.device in "gpu" and n_gpu > 1: - paddle.distributed.spawn(do_train, args=(args,), nprocs=n_gpu) - else: - do_train(args) + do_train() diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index e958b25b251c..ecd01c93871b 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -214,7 +214,7 @@ export DATA_DIR=./BookCorpus/ wget -q https://paddle-qa.bj.bcebos.com/paddlenlp/BookCorpus.tar.gz && tar -xzvf BookCorpus.tar.gz time (python -u ./run_pretrain.py \ --model_type electra \ - --model_name_or_path electra-small \ + --model_name_or_path chinese-electra-small \ --input_dir ./BookCorpus/ \ --output_dir ./pretrain_model/ \ --train_batch_size 64 \ @@ -227,6 +227,9 @@ time (python -u ./run_pretrain.py \ --logging_steps 1 \ --save_steps 1 \ --max_steps 1 \ + --do_train true \ + --do_train true \ + --fp16 False \ --device gpu >${log_path}/electra_pretrain) >>${log_path}/electra_pretrain 2>&1 print_info $? electra_pretrain } diff --git a/tests/fixtures/model_zoo/electra.yaml b/tests/fixtures/model_zoo/electra.yaml new file mode 100644 index 000000000000..bad98c36cb99 --- /dev/null +++ b/tests/fixtures/model_zoo/electra.yaml @@ -0,0 +1,68 @@ +pretrain: + default: + model_type: electra + model_name_or_path: __internal_testing__/chinese-electra-small + input_dir: model_zoo/electra/data + output_dir: model_zoo/electra/output/pretrained_models + max_predictions_per_seq: 20 + per_device_train_batch_size: 2 + warmup_steps: 2 + num_train_epochs: 0.0005 + logging_steps: 1 + save_steps: 2 + max_steps: 10 + device: gpu + fp16: False + do_train: true + + slow: + model_type: electra + model_name_or_path: chinese-electra-small + input_dir: model_zoo/electra/data + output_dir: model_zoo/electra/output/pretrained_models + per_device_train_batch_size: 32 + learning_rate: 1e-4 + weight_decay: 1e-2 + adam_epsilon: 1e-6 + warmup_steps: 10000 + num_train_epochs: 3 + logging_steps: 1 + save_steps: 20000 + max_steps: 1000000 + device: gpu + fp16: False + do_train: true + + +glue: + default: + model_name_or_path: __internal_testing__/chinese-electra-small + output_dir: model_zoo/electra/tmp + task_name: SST-2 + max_seq_length: 32 + per_device_train_batch_size: 32 + per_device_eval_batch_size: 32 + learning_rate: 2e-5 + num_train_epochs: 0.0001 + logging_steps: 1 + save_steps: 400 + device: cpu + fp16: False + do_train: true + do_eval: true + + slow: + model_name_or_path: chinese-electra-small + output_dir: model_zoo/electra/tmp + task_name: SST-2 + max_seq_length: 128 + per_device_train_batch_size: 32 + per_device_eval_batch_size: 32 + learning_rate: 2e-5 + num_train_epochs: 3 + logging_steps: 1 + save_steps: 500 + device: gpu + fp16: False + do_train: true + do_eval: true \ No newline at end of file diff --git a/tests/model_zoo/test_electra.py b/tests/model_zoo/test_electra.py new file mode 100644 index 000000000000..160b83148dd9 --- /dev/null +++ b/tests/model_zoo/test_electra.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import sys +from unittest import TestCase + +from paddlenlp.utils.downloader import get_path_from_url_with_filelock +from paddlenlp.utils.log import logger +from tests.testing_utils import argv_context_guard, load_test_config + +CUDA_VISIBLE_DEVICES = 1 + + +class ELECTRA_Test(TestCase): + def download_corpus(self, input_dir): + os.makedirs(input_dir, exist_ok=True) + files = [ + "https://paddle-qa.bj.bcebos.com/paddlenlp/BookCorpus.tar.gz", + ] + + for file in files: + file_name = file.split("/")[-1] + file_path = os.path.join(input_dir, file_name) + if not os.path.exists(file_path): + logger.info(f"start to download corpus: <{file_name}> into <{input_dir}>") + get_path_from_url_with_filelock(file, root_dir=input_dir) + + def setUp(self) -> None: + self.path = "./model_zoo/electra" + self.config_path = "./tests/fixtures/model_zoo/electra.yaml" + sys.path.insert(0, self.path) + + def tearDown(self) -> None: + sys.path.remove(self.path) + + def test_pretrain(self): + + # 1. run pretrain + pretrain_config = load_test_config(self.config_path, "pretrain") + self.download_corpus(pretrain_config["input_dir"]) + with argv_context_guard(pretrain_config): + from run_pretrain import do_train + + do_train() + + # 2. get_ft_model + ft_config = { + "model_dir": pretrain_config["output_dir"], + } + with argv_context_guard(ft_config): + from get_ft_model import main + + main() + + # 3. run glue + glue_config = load_test_config(self.config_path, "glue") + glue_config["output_dir"] = "pretrained_model/model" + glue_config["model_name_or_path"] = pretrain_config["output_dir"] + with argv_context_guard(glue_config): + from run_glue import do_train + + do_train() + + # 4. export model + export_config = { + "model_name": pretrain_config["model_name_or_path"], + "output_model_dir": "infer_model/model", + "input_model_dir": glue_config["output_dir"], + } + with argv_context_guard(export_config): + from export_model import main + + main() + + # infer model of samples + infer_config = { + "model_file": "infer_model/model/__internal_testing__/chinese-electra-small.pdmodel", + "params_file": "infer_model/model/__internal_testing__/chinese-electra-small.pdiparams", + "predict_sentences": "uneasy mishmash of styles and genres ." + "director rob marshall went out gunning to make a great one .", + "batch_size": 2, + "max_seq_length": 128, + "model_name": pretrain_config["model_name_or_path"], + } + with argv_context_guard(infer_config): + from deploy.python.predict import predict + + predict() + + def test_glue(self): + + glue_config = load_test_config(self.config_path, "glue") + with argv_context_guard(glue_config): + from run_glue import do_train + + do_train() From 7fe65e8e28fb17ac5868e2aa86ec7d9f236eafa3 Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Mon, 20 Mar 2023 18:31:42 +0800 Subject: [PATCH 02/13] modefied --- model_zoo/bert/run_pretrain_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/model_zoo/bert/run_pretrain_trainer.py b/model_zoo/bert/run_pretrain_trainer.py index 4d5e99603a5a..cce5d85ca68d 100644 --- a/model_zoo/bert/run_pretrain_trainer.py +++ b/model_zoo/bert/run_pretrain_trainer.py @@ -194,6 +194,7 @@ def __getitem__(self, index): # softmax_with_cross_entropy enforce last dim size equal 1 masked_lm_labels = np.expand_dims(masked_lm_labels, axis=-1) next_sentence_labels = np.expand_dims(next_sentence_labels, axis=-1) + return [input_ids, segment_ids, input_mask, masked_lm_positions, masked_lm_labels, next_sentence_labels] From 678eafea10b5ef02f31bb88f592b77ea2ddefe81 Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Mon, 20 Mar 2023 19:44:25 +0800 Subject: [PATCH 03/13] modefied_ci --- scripts/regression/ci_case.sh | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index ecd01c93871b..7e993d67de58 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -217,21 +217,43 @@ time (python -u ./run_pretrain.py \ --model_name_or_path chinese-electra-small \ --input_dir ./BookCorpus/ \ --output_dir ./pretrain_model/ \ - --train_batch_size 64 \ + --max_predictions_per_seq 20 \ + --per_device_train_batch_size 2 \ --learning_rate 5e-4 \ - --max_seq_length 128 \ --weight_decay 1e-2 \ --adam_epsilon 1e-6 \ --warmup_steps 10000 \ - --num_train_epochs 4 \ --logging_steps 1 \ --save_steps 1 \ --max_steps 1 \ - --do_train true \ - --do_train true \ - --fp16 False \ --device gpu >${log_path}/electra_pretrain) >>${log_path}/electra_pretrain 2>&1 print_info $? electra_pretrain +time(python -u ./get_ft_model.py \ + --model_dir ./pretrain_model/ \>${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 +print_info $? electra_get_ft_model +time (python -m paddle.distributed.launch run_glue.py \ + --model_type electra \ + --model_name_or_path chinese-electra-small \ + --task_name SST2 \ + --max_seq_length 128 \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 32 \ + --learning_rate 1e-4 \ + --num_train_epochs 3 \ + --logging_steps 1 \ + --save_steps 1 \ + --max_steps 1 \ + --output_dir ./tmp/ \ + --device gpu \ + --fp16 False\ + --do_train \ + --do_eval >${log_path}/electra_fintune) >>${log_path}/electra_fintune 2>&1 +print_info $? electra_fintune +time (python -u ./export_model.py \ + --model_name chinese-electra-small \ + --input_model_dir ./tmp/ \ + --output_model_dir ./infer_model/model >${log_path}/electra_export) >>${log_path}/electra_export 2>&1 +print_info $? electra_export } fast_gpt(){ # FT From 2be330bbd256718369793d55687e059a2507535c Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Mon, 20 Mar 2023 20:31:04 +0800 Subject: [PATCH 04/13] modefied_ci --- scripts/regression/ci_case.sh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index 7e993d67de58..e746de3d0bd8 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -219,11 +219,10 @@ time (python -u ./run_pretrain.py \ --output_dir ./pretrain_model/ \ --max_predictions_per_seq 20 \ --per_device_train_batch_size 2 \ - --learning_rate 5e-4 \ + --per_device_train_batch_size 2 \ --weight_decay 1e-2 \ --adam_epsilon 1e-6 \ --warmup_steps 10000 \ - --logging_steps 1 \ --save_steps 1 \ --max_steps 1 \ --device gpu >${log_path}/electra_pretrain) >>${log_path}/electra_pretrain 2>&1 @@ -233,8 +232,8 @@ time(python -u ./get_ft_model.py \ print_info $? electra_get_ft_model time (python -m paddle.distributed.launch run_glue.py \ --model_type electra \ - --model_name_or_path chinese-electra-small \ - --task_name SST2 \ + --model_name_or_path ./pretrain_model/ \ + --task_name SST-2 \ --max_seq_length 128 \ --per_device_train_batch_size 32 \ --per_device_eval_batch_size 32 \ From 9b5c6c98a339d36760db37572b33ceab448a0e1d Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Mon, 20 Mar 2023 20:54:47 +0800 Subject: [PATCH 05/13] modefied --- model_zoo/electra/run_glue.py | 1 + scripts/regression/ci_case.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/model_zoo/electra/run_glue.py b/model_zoo/electra/run_glue.py index 9a7e1abdade9..175535f8ba38 100644 --- a/model_zoo/electra/run_glue.py +++ b/model_zoo/electra/run_glue.py @@ -171,6 +171,7 @@ def do_train(): model = ElectraForSequenceClassification.from_pretrained(model_args.model_name_or_path, num_labels=num_classes) if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() model = paddle.DataParallel(model) num_training_steps = ( diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index e746de3d0bd8..2625f3d7301d 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -225,6 +225,7 @@ time (python -u ./run_pretrain.py \ --warmup_steps 10000 \ --save_steps 1 \ --max_steps 1 \ + --do_train true \ --device gpu >${log_path}/electra_pretrain) >>${log_path}/electra_pretrain 2>&1 print_info $? electra_pretrain time(python -u ./get_ft_model.py \ From 1abe71c8a4894a8908c6af69a5cb00c4d75a13aa Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Mon, 20 Mar 2023 21:20:00 +0800 Subject: [PATCH 06/13] modefied_ci --- scripts/regression/ci_case.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index 2625f3d7301d..addd84311c66 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -228,7 +228,7 @@ time (python -u ./run_pretrain.py \ --do_train true \ --device gpu >${log_path}/electra_pretrain) >>${log_path}/electra_pretrain 2>&1 print_info $? electra_pretrain -time(python -u ./get_ft_model.py \ +time (python -u ./get_ft_model.py \ --model_dir ./pretrain_model/ \>${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 print_info $? electra_get_ft_model time (python -m paddle.distributed.launch run_glue.py \ From 3d135efea5d8e484597e1d4c5a01d9770f03ef62 Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Fri, 7 Apr 2023 17:11:58 +0800 Subject: [PATCH 07/13] modefied_ci --- scripts/regression/ci_case.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index addd84311c66..a5f067094f1e 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -210,7 +210,6 @@ time (python -m paddle.distributed.launch --log_dir log run_pretrain.py --mode electra(){ cd ${nlp_dir}/model_zoo/electra/ export CUDA_VISIBLE_DEVICES=${cudaid2} -export DATA_DIR=./BookCorpus/ wget -q https://paddle-qa.bj.bcebos.com/paddlenlp/BookCorpus.tar.gz && tar -xzvf BookCorpus.tar.gz time (python -u ./run_pretrain.py \ --model_type electra \ @@ -231,7 +230,7 @@ print_info $? electra_pretrain time (python -u ./get_ft_model.py \ --model_dir ./pretrain_model/ \>${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 print_info $? electra_get_ft_model -time (python -m paddle.distributed.launch run_glue.py \ +time (python -m paddle.distributed.launch ./run_glue.py \ --model_type electra \ --model_name_or_path ./pretrain_model/ \ --task_name SST-2 \ From 54d5307bf988306cc4405440b9672c8aad7bc318 Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Fri, 7 Apr 2023 17:55:47 +0800 Subject: [PATCH 08/13] modefied --- model_zoo/electra/get_ft_model.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/model_zoo/electra/get_ft_model.py b/model_zoo/electra/get_ft_model.py index 055995fce33b..b15f80a47fb8 100644 --- a/model_zoo/electra/get_ft_model.py +++ b/model_zoo/electra/get_ft_model.py @@ -18,11 +18,6 @@ import paddle -# from paddlenlp.transformers import ElectraForTotalPretraining, ElectraDiscriminator, ElectraGenerator, ElectraModel -# from paddlenlp.transformers import ElectraTokenizer -# -# MODEL_CLASSES = {"electra": (ElectraForTotalPretraining, ElectraTokenizer), } - def get_md5sum(file_path): md5sum = None From 9bcda263de5172f9bae06b2e7799a864a45db567 Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Fri, 7 Apr 2023 18:28:16 +0800 Subject: [PATCH 09/13] modefied --- model_zoo/electra/get_ft_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/model_zoo/electra/get_ft_model.py b/model_zoo/electra/get_ft_model.py index b15f80a47fb8..9dba16d436b3 100644 --- a/model_zoo/electra/get_ft_model.py +++ b/model_zoo/electra/get_ft_model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # from collections import namedtuple + import argparse import hashlib import os From 8ba2fd24aa4a116940f6c7f33ac3dafeb9d20a09 Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Fri, 7 Apr 2023 18:36:38 +0800 Subject: [PATCH 10/13] modefied --- model_zoo/electra/get_ft_model.py | 1 - scripts/regression/ci_case.sh | 12 ++++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/model_zoo/electra/get_ft_model.py b/model_zoo/electra/get_ft_model.py index 9dba16d436b3..b15f80a47fb8 100644 --- a/model_zoo/electra/get_ft_model.py +++ b/model_zoo/electra/get_ft_model.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # from collections import namedtuple - import argparse import hashlib import os diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index 7a9838cbc2c6..fde519f7bf09 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -211,11 +211,11 @@ electra(){ cd ${nlp_dir}/model_zoo/electra/ export CUDA_VISIBLE_DEVICES=${cudaid2} wget -q https://paddle-qa.bj.bcebos.com/paddlenlp/BookCorpus.tar.gz && tar -xzvf BookCorpus.tar.gz -time (python -u ./run_pretrain.py \ +time (python -m paddle.distributed.launch run_pretrain.py \ --model_type electra \ --model_name_or_path chinese-electra-small \ --input_dir ./BookCorpus/ \ - --output_dir ./pretrain_model/ \ + --output_dir pretrain_model/ \ --max_predictions_per_seq 20 \ --per_device_train_batch_size 2 \ --per_device_train_batch_size 2 \ @@ -227,12 +227,12 @@ time (python -u ./run_pretrain.py \ --do_train true \ --device gpu >${log_path}/electra_pretrain) >>${log_path}/electra_pretrain 2>&1 print_info $? electra_pretrain -time (python -u ./get_ft_model.py \ - --model_dir ./pretrain_model/ \>${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 +time (python -m paddle.distributed.launch get_ft_model.py \ + --model_dir pretrain_model/ \>${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 print_info $? electra_get_ft_model -time (python -m paddle.distributed.launch ./run_glue.py \ +time (python -m paddle.distributed.launch run_glue.py \ --model_type electra \ - --model_name_or_path ./pretrain_model/ \ + --model_name_or_path pretrain_model/ \ --task_name SST-2 \ --max_seq_length 128 \ --per_device_train_batch_size 32 \ From 8e7f94889dc59bd15167c7740a605c9986673b55 Mon Sep 17 00:00:00 2001 From: tsinghua-zhang Date: Fri, 7 Apr 2023 18:39:48 +0800 Subject: [PATCH 11/13] modefied --- scripts/regression/ci_case.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index fde519f7bf09..c1adddfb8f65 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -227,7 +227,7 @@ time (python -m paddle.distributed.launch run_pretrain.py \ --do_train true \ --device gpu >${log_path}/electra_pretrain) >>${log_path}/electra_pretrain 2>&1 print_info $? electra_pretrain -time (python -m paddle.distributed.launch get_ft_model.py \ +time (python -u get_ft_model.py \ --model_dir pretrain_model/ \>${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 print_info $? electra_get_ft_model time (python -m paddle.distributed.launch run_glue.py \ From 9f326c4a2ac4b6a8813e2b0ce0dd048c2d607214 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Ctsinghua-zhang=E2=80=9D?= Date: Thu, 20 Apr 2023 10:45:40 +0800 Subject: [PATCH 12/13] modefied_ci --- scripts/regression/ci_case.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index c1adddfb8f65..66459c338e42 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -215,7 +215,7 @@ time (python -m paddle.distributed.launch run_pretrain.py \ --model_type electra \ --model_name_or_path chinese-electra-small \ --input_dir ./BookCorpus/ \ - --output_dir pretrain_model/ \ + --output_dir ./pretrain_model/ \ --max_predictions_per_seq 20 \ --per_device_train_batch_size 2 \ --per_device_train_batch_size 2 \ @@ -228,11 +228,11 @@ time (python -m paddle.distributed.launch run_pretrain.py \ --device gpu >${log_path}/electra_pretrain) >>${log_path}/electra_pretrain 2>&1 print_info $? electra_pretrain time (python -u get_ft_model.py \ - --model_dir pretrain_model/ \>${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 + --model_dir ./pretrain_model/ \>${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 print_info $? electra_get_ft_model time (python -m paddle.distributed.launch run_glue.py \ --model_type electra \ - --model_name_or_path pretrain_model/ \ + --model_name_or_path ./pretrain_model/ \ --task_name SST-2 \ --max_seq_length 128 \ --per_device_train_batch_size 32 \ From 6b7a6368b1c5221bbccb0b7d7a905e26e34ae095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Ctsinghua-zhang=E2=80=9D?= Date: Thu, 20 Apr 2023 12:18:12 +0800 Subject: [PATCH 13/13] modefied_ci --- scripts/regression/ci_case.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/regression/ci_case.sh b/scripts/regression/ci_case.sh index b9ca72d85337..fd4f55f6676a 100644 --- a/scripts/regression/ci_case.sh +++ b/scripts/regression/ci_case.sh @@ -228,7 +228,7 @@ time (python -m paddle.distributed.launch run_pretrain.py \ --device gpu >${log_path}/electra_pretrain) >>${log_path}/electra_pretrain 2>&1 print_info $? electra_pretrain time (python -u get_ft_model.py \ - --model_dir ./pretrain_model/ \>${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 + --model_dir ./pretrain_model >${log_path}/electra_get_ft_model) >>${log_path}/electra_get_ft_model 2>&1 print_info $? electra_get_ft_model time (python -m paddle.distributed.launch run_glue.py \ --model_type electra \