Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions model_zoo/electra/deploy/python/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,17 @@ def get_predicted_input(predicted_data, tokenizer, max_seq_length, batch_size):
return sen_ids_batch, sen_words_batch


def predict(args, sentences=[], paths=[]):
def predict():
"""
Args:
sentences (list[str]): each string is a sentence. If sentences not paths
paths (list[str]): The paths of file which contain sentences. If paths not sentences
Returns:
res (list(numpy.ndarray)): The result of sentence, indicate whether each word is replaced, same shape with sentences.
"""

args = parse_args()
sentences = args.predict_sentences
paths = args.predict_file
# initialize data
if sentences != [] and isinstance(sentences, list) and (paths == [] or paths is None):
predicted_data = sentences
Expand Down Expand Up @@ -186,9 +188,7 @@ def predict(args, sentences=[], paths=[]):


if __name__ == "__main__":
args = parse_args()
sentences = args.predict_sentences
paths = args.predict_file

# sentences = ["The quick brown fox see over the lazy dog.", "The quick brown fox jump over tree lazy dog."]
# paths = ["../../debug/test.txt", "../../debug/test.txt.1"]
predict(args, sentences, paths)
predict()
53 changes: 40 additions & 13 deletions model_zoo/electra/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import argparse
import hashlib
import json
import os

import paddle
Expand All @@ -35,12 +36,49 @@ def get_md5sum(file_path):
return md5sum


def parse_args():
    """Parse the command-line arguments of the Electra export script.

    Returns:
        argparse.Namespace: namespace holding ``input_model_dir``,
        ``output_model_dir`` and ``model_name``.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--input_model_dir",
        required=True,
        type=str,
        default=None,
        help="Directory for storing Electra pretraining model",
    )
    parser.add_argument(
        "--output_model_dir",
        required=True,
        default=None,
        type=str,
        help="Directory for output Electra inference model",
    )

    # Optional parameters
    # Not marked as required: a required option never falls back to its
    # default, which would make "electra-deploy" dead configuration.
    parser.add_argument(
        "--model_name",
        default="electra-deploy",
        type=str,
        help="prefix name of output model and parameters",
    )
    args = parser.parse_args()
    return args


def main():
args = parse_args()
# check and load config
with open(os.path.join(args.input_model_dir, "config.json"), "r") as f:
config_dict = json.load(f)
num_choices = config_dict["num_choices"]
if num_choices is None or num_choices <= 0:
print("%s/model_config.json may not be right, please check" % args.input_model_dir)
exit(1)

# check and load model
input_model_file = os.path.join(args.input_model_dir, "model_state.pdparams")
print("load model to get static model : %s \nmodel md5sum : %s" % (input_model_file, get_md5sum(input_model_file)))
model_state_dict = paddle.load(input_model_file)

if all((s.startswith("generator") or s.startswith("discriminator")) for s in model_state_dict.keys()):
print("the model : %s is electra pretrain model, we need fine-tuning model to deploy" % input_model_file)
exit(1)
Expand All @@ -49,7 +87,7 @@ def main():
exit(1)
elif "classifier.dense.weight" in model_state_dict:
print("we are load glue fine-tuning model")
model = ElectraForSequenceClassification.from_pretrained(args.input_model_dir)
model = ElectraForSequenceClassification.from_pretrained(args.input_model_dir, num_classes=num_choices)
print("total model layers : ", len(model_state_dict))
else:
print("the model file : %s may not be fine-tuning model, please check" % input_model_file)
Expand All @@ -65,15 +103,4 @@ def main():


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_model_dir", required=True, default=None, help="Directory for storing Electra pretraining model"
)
parser.add_argument(
"--output_model_dir", required=True, default=None, help="Directory for output Electra inference model"
)
parser.add_argument(
"--model_name", default="electra-deploy", type=str, help="prefix name of output model and parameters"
)
args, unparsed = parser.parse_known_args()
main()
37 changes: 22 additions & 15 deletions model_zoo/electra/get_ft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,27 @@ def get_md5sum(file_path):
return md5sum


def main(args):
def parse_args():
    """Read the command-line options of the fine-tuning model extraction script.

    Returns:
        argparse.Namespace: namespace holding ``model_dir``,
        ``generator_output_file`` and ``discriminator_output_file``.
    """
    # Each entry is (flag, keyword arguments forwarded to add_argument),
    # registered in the order listed here.
    option_specs = (
        (
            "--model_dir",
            {
                "required": True,
                "default": None,
                "help": "Directory of storing ElectraForTotalPreTraining model",
            },
        ),
        (
            "--generator_output_file",
            {
                "default": "generator_for_ft.pdparams",
                "help": "Electra generator model for fine-tuning",
            },
        ),
        (
            "--discriminator_output_file",
            {
                "default": "discriminator_for_ft.pdparams",
                "help": "Electra discriminator model for fine-tuning",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def main():
args = parse_args()
pretraining_model = os.path.join(args.model_dir, "model_state.pdparams")
if os.path.islink(pretraining_model):
print("%s already contain fine-tuning model, pleace check" % args.model_dir)
Expand Down Expand Up @@ -66,17 +86,4 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_dir", required=True, default=None, help="Directory of storing ElectraForTotalPreTraining model"
)
parser.add_argument(
"--generator_output_file", default="generator_for_ft.pdparams", help="Electra generator model for fine-tuning"
)
parser.add_argument(
"--discriminator_output_file",
default="discriminator_for_ft.pdparams",
help="Electra discriminator model for fine-tuning",
)
args, unparsed = parser.parse_known_args()
main(args)
main()
Loading