
Commit bb76b09

[autonlp] add text classification taskflow (#4815)

* from train to infer
* add taskflow arguments
* update
* fix
* fix
* fix
* fix
1 parent 2180cae commit bb76b09

4 files changed: +156 -100 lines changed


paddlenlp/experimental/autonlp/auto_trainer_base.py

Lines changed: 5 additions & 7 deletions
@@ -14,7 +14,6 @@
 import copy
 import datetime
 import os
-import shutil
 from abc import ABCMeta, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Union

@@ -49,8 +48,9 @@ class AutoTrainerBase(metaclass=ABCMeta):
         use verbosity > 0 to set stop the workers from logging to the driver.
     """

-    training_path = "training"
-    export_path = "exported_model"
+    training_path = "training_checkpoints"  # filepath for Trainer's training checkpoints
+    save_path = "trained_model"  # filepath for the trained dygraph model
+    export_path = "exported_model"  # filepath for the exported static model
     results_filename = "experiment_results.csv"

     def __init__(
@@ -143,10 +143,8 @@ def export(self, export_path, trial_id=None):
             export_path (str, required): the filepath to export to
             trial_id (int, required): use the `trial_id` to select the model to export. Defaults to the best model selected by `metric_for_best_model`
         """
-        model_result = self._get_model_result(trial_id=trial_id)
-        exported_model_path = os.path.join(model_result.log_dir, self.export_path)
-        shutil.copytree(exported_model_path, export_path)
-        logger.info(f"Exported to {export_path}")
+
+        raise NotImplementedError

     @abstractmethod
     def to_taskflow(self, trial_id=None):
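The renamed constants make the on-disk layout of a trial explicit: transient training checkpoints, the trained dygraph model, and the exported static model now live under separate directories. A minimal sketch of how those paths compose under a trial's log directory, assuming the `log_dir` attribute that `model_result` exposes elsewhere in this diff (the helper function itself is hypothetical, not part of the commit):

    import os

    # Constants mirroring AutoTrainerBase above; trial_artifact_paths is hypothetical.
    TRAINING_PATH = "training_checkpoints"  # transient Trainer checkpoints, deleted after training
    SAVE_PATH = "trained_model"             # trained dygraph model weights
    EXPORT_PATH = "exported_model"          # exported static model for inference

    def trial_artifact_paths(log_dir: str) -> dict:
        """Resolve the per-trial artifact directories implied by the constants above."""
        return {
            "checkpoints": os.path.join(log_dir, TRAINING_PATH),
            "dygraph_model": os.path.join(log_dir, SAVE_PATH),
            "static_model": os.path.join(log_dir, EXPORT_PATH),
        }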

paddlenlp/experimental/autonlp/text_classification.py

Lines changed: 101 additions & 39 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 import copy
 import functools
+import json
 import os
 import shutil
 from typing import Any, Callable, Dict, List

@@ -40,7 +41,9 @@
     AutoModelForSequenceClassification,
     AutoTokenizer,
     PretrainedTokenizer,
+    export_model,
 )
+from paddlenlp.utils.log import logger

 from .auto_trainer_base import AutoTrainerBase

@@ -61,9 +64,8 @@ class AutoTrainerForTextClassification(AutoTrainerBase):
         language (string, required): language of the text.
         output_dir (str, optional): Output directory for the experiments, defaults to "autpnlp_results".
         id2label(dict(int,string)): The dictionary to map the predictions from class ids to class names.
-        verbosity: (int, optional): controls the verbosity of the run. Defaults to 1, which let the workers log to the driver.To reduce the amount of logs,
-            use verbosity > 0 to set stop the workers from logging to the driver.
-
+        multilabel_threshold (float): The probability threshold used for the multi_label setup. Only effective if model = "multi_label". Defaults to 0.5.
+        verbosity: (int, optional): controls the verbosity of the run. Defaults to 1, which let the workers log to the driver.To reduce the amount of logs, use verbosity > 0 to set stop the workers from logging to the driver.
     """

     def __init__(
@@ -88,6 +90,7 @@ def __init__(
         self.text_column = text_column
         self.label_column = label_column
         self.id2label = self.kwargs.get("id2label", None)
+        self.multilabel_threshold = self.kwargs.get("multilabel_threshold", 0.5)
         if problem_type in ["multi_label", "multi_class"]:
             self.problem_type = problem_type
         else:
@@ -244,8 +247,12 @@ def _data_checks_and_inference(self):
                 raise ValueError(
                     f"Label {label} is not found in the user-provided id2label argument: {self.id2label}"
                 )
+        id2label_path = os.path.join(self.output_dir, "id2label.json")
+        with open(id2label_path, "w", encoding="utf-8") as f:
+            json.dump(self.id2label, f, ensure_ascii=False)
+        logger.info(f"Exported id2label to {id2label_path}")

-    def _construct_trainer(self, config) -> Trainer:
+    def _construct_trainer(self, config, eval_dataset=None) -> Trainer:
         if "EarlyStoppingCallback.early_stopping_patience" in config:
             callbacks = [
                 EarlyStoppingCallback(early_stopping_patience=config["EarlyStoppingCallback.early_stopping_patience"])
@@ -262,7 +269,10 @@ def _construct_trainer(self, config) -> Trainer:
                 max_length=model.config.max_position_embeddings,  # truncate to the max length allowed by the model
             )
             processed_train_dataset = copy.deepcopy(self.train_dataset).map(trans_func, lazy=False)
-            processed_eval_dataset = copy.deepcopy(self.eval_dataset).map(trans_func, lazy=False)
+            if eval_dataset is None:
+                processed_eval_dataset = copy.deepcopy(self.eval_dataset).map(trans_func, lazy=False)
+            else:
+                processed_eval_dataset = copy.deepcopy(eval_dataset).map(trans_func, lazy=False)
             training_args = self._override_hp(config, self._default_training_argument)
             trainer = Trainer(
                 model=model,
@@ -279,7 +289,11 @@ def _construct_trainer(self, config) -> Trainer:
             max_length = config.get("PreprocessArguments.max_length", 128)
             tokenizer = AutoTokenizer.from_pretrained(model_path)
             processed_train_dataset = copy.deepcopy(self.train_dataset).map(self._preprocess_labels, lazy=False)
-            processed_eval_dataset = copy.deepcopy(self.eval_dataset).map(self._preprocess_labels, lazy=False)
+            if eval_dataset is None:
+                processed_eval_dataset = copy.deepcopy(self.eval_dataset).map(self._preprocess_labels, lazy=False)
+            else:
+                processed_eval_dataset = copy.deepcopy(eval_dataset).map(self._preprocess_labels, lazy=False)
+
             model = AutoModelForMaskedLM.from_pretrained(model_path)
             template = AutoTemplate.create_from(
                 prompt=config["template.prompt"],
@@ -323,12 +337,8 @@ def trainable(config):
             trainer = self._construct_trainer(config)
             trainer.train()
             eval_metrics = trainer.evaluate()
-            if config["trainer_type"] == "PromptTrainer":
-                # It's difficult to load back the prompt model as a dynamic model due to lack of AutoModel support now
-                # We directly export a static model instead of a dynamic model
-                trainer.export_model(self.export_path)
-            else:
-                trainer.save_model(self.export_path)
+            trainer.save_model(self.save_path)
+
             if os.path.exists(self.training_path):
                 logger.info("Removing training checkpoints to conserve disk space")
                 shutil.rmtree(self.training_path)
@@ -347,22 +357,13 @@ def evaluate(self, trial_id=None, eval_dataset=None):
         """
         model_result = self._get_model_result(trial_id=trial_id)
         model_config = model_result.metrics["config"]["candidates"]
-        if model_config["trainer_type"] == "PromptTrainer":
-            raise NotImplementedError(
-                "'PromptTrainer' models do not support 'evaluate' yet because dygraph save model has not been implemented."
-            )
-        model_config["TrainingArguments.model_name_or_path"] = os.path.join(model_result.log_dir, self.export_path)
-        trainer = self._construct_trainer(model_config)
-        if eval_dataset is not None:
-            trans_func = functools.partial(
-                self._preprocess_fn,
-                tokenizer=trainer.tokenizer,
-                max_length=trainer.model.config.max_position_embeddings,  # truncate to the max length allowed by the model
-            )
-            processed_eval_dataset = copy.deepcopy(eval_dataset).map(trans_func, lazy=False)
-            eval_metrics = trainer.evaluate(processed_eval_dataset)
-        else:
-            eval_metrics = trainer.evaluate()
+
+        trainer = self._construct_trainer(model_config, eval_dataset)
+        trainer.load_state_dict_from_checkpoint(
+            resume_from_checkpoint=os.path.join(model_result.log_dir, self.save_path)
+        )
+
+        eval_metrics = trainer.evaluate()
         return eval_metrics

     def _compute_metrics(self, eval_preds: EvalPrediction) -> Dict[str, float]:
@@ -386,7 +387,7 @@ def _compute_multi_class_metrics(self, eval_preds: EvalPrediction) -> Dict[str,

     def _compute_multi_label_metrics(self, eval_preds: EvalPrediction) -> Dict[str, float]:
         pred_probs = sigmoid(eval_preds.predictions)
-        pred_ids = pred_probs > 0.5
+        pred_ids = pred_probs > self.multilabel_threshold
         metrics = {}
         # In multilabel classification, this function computes subset accuracy:
         # the set of labels predicted for a sample must exactly match the corresponding set of labels in y_true.
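With `multilabel_threshold` now read from the constructor kwargs, the 0.5 cutoff above is just the default. A small NumPy sketch (illustrative logits, not from this commit) of how the threshold turns per-label probabilities into predictions:

    import numpy as np

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    logits = np.array([[2.0, -1.0, 0.3]])  # raw multi-label outputs for one sample
    pred_probs = sigmoid(logits)           # approx. [[0.88, 0.27, 0.57]]

    # The default threshold of 0.5 fires labels 0 and 2; raising it to 0.6
    # keeps only label 0, trading recall for precision.
    print(pred_probs > 0.5)  # [[ True False  True]]
    print(pred_probs > 0.6)  # [[ True False False]]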
@@ -425,23 +426,84 @@ def _preprocess_fn(
             result["labels"] = example_with_labels["labels"]
         return result

-    def to_taskflow(self, trial_id=None):
+    def to_taskflow(self, trial_id=None, export_path=None, batch_size=1, max_length=512, precision="fp32"):
         """
         Convert the model from a certain `trial_id` to a Taskflow for model inference

         Args:
-            trial_id (int, required): use the `trial_id` to select the model to export. Defaults to the best model selected by `metric_for_best_model`
+            trial_id (int): use the `trial_id` to select the model to export. Defaults to the best model selected by `metric_for_best_model`
+            export_path (str): the filepath to export to
+            max_length (int): Maximum number of tokens for the model. Defaults to 512.
+            batch_size(int): The sample number of a mini-batch. Defaults to 1.
+            precision (str): Select among ["fp32", "fp16"]. Default to "fp32".
         """
         model_result = self._get_model_result(trial_id=trial_id)
         model_config = model_result.metrics["config"]["candidates"]
+        trial_id = model_result.metrics["trial_id"]
+
+        if export_path is None:
+            export_path = os.path.join(self.export_path, trial_id)
+
+        self.export(export_path=export_path, trial_id=trial_id)
+
         if model_config["trainer_type"] == "PromptTrainer":
-            raise NotImplementedError("'Taskflow' inference does not support models trained with PromptTrainer yet.")
+            mode = "prompt"
         else:
-            exported_model_path = os.path.join(model_result.log_dir, self.export_path)
-            return Taskflow(
-                "text_classification",
-                mode="finetune",
-                problem_type=self.problem_type,
-                task_path=exported_model_path,
-                id2label=self.id2label,
+            mode = "finetune"
+
+        return Taskflow(
+            "text_classification",
+            mode=mode,
+            is_static_model=True,
+            problem_type=self.problem_type,
+            task_path=export_path,
+            multilabel_threshold=self.multilabel_threshold,
+            batch_size=batch_size,
+            max_length=max_length,
+            precision=precision,
+        )
+
+    def export(self, export_path, trial_id=None):
+        """
+        Export the model from a certain `trial_id` to the given file path.
+
+        Args:
+            export_path (str, required): the filepath to export to
+            trial_id (int, required): use the `trial_id` to select the model to export. Defaults to the best model selected by `metric_for_best_model`
+        """
+
+        model_result = self._get_model_result(trial_id=trial_id)
+        model_config = model_result.metrics["config"]["candidates"]
+        trial_id = model_result.metrics["trial_id"]
+
+        if os.path.exists(export_path):
+            logger.info(
+                f"Export path for {trial_id} already exists: ({export_path}). The model parameter files will be overwritten."
             )
+
+        # construct trainer
+        trainer = self._construct_trainer(model_config)
+        trainer.load_state_dict_from_checkpoint(
+            resume_from_checkpoint=os.path.join(model_result.log_dir, self.save_path)
+        )
+
+        # save static model
+        if model_config["trainer_type"] == "PromptTrainer":
+            trainer.export_model(export_path)
+            trainer.model.plm.save_pretrained(os.path.join(export_path, "plm"))
+        else:
+            if trainer.model.init_config["init_class"] in ["ErnieMForSequenceClassification"]:
+                input_spec = [paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids")]
+            else:
+                input_spec = [
+                    paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
+                    paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
+                ]
+            export_model(model=trainer.model, input_spec=input_spec, path=export_path)
+        # save tokenizer
+        trainer.tokenizer.save_pretrained(export_path)
+
+        # save id2label
+        shutil.copyfile(os.path.join(self.output_dir, "id2label.json"), os.path.join(export_path, "id2label.json"))
+
+        logger.info(f"Exported {trial_id} to {export_path}")

paddlenlp/taskflow/text_classification.py

Lines changed: 5 additions & 1 deletion
@@ -266,7 +266,11 @@ def _preprocess(self, inputs: Union[str, List[str]]) -> Dict[str, Any]:
             collator = PromptDataCollatorWithPadding(
                 self._tokenizer, padding=True, return_tensors="np", return_attention_mask=True
             )
-            template_inputs = [self._template({"text_a": x}) for x in inputs]
+            part_text = "text"
+            for part in self._template.prompt:
+                if "text" in part:
+                    part_text = part["text"]
+            template_inputs = [self._template({part_text: x}) for x in inputs]
             batches = [template_inputs[idx : idx + batch_size] for idx in range(0, len(template_inputs), batch_size)]
         else:
             raise NotImplementedError(
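The loop above generalizes the previously hard-coded `"text_a"` key: a template's prompt is a list of part dicts, and the part that consumes the raw input names its field under `"text"`. A standalone sketch of that lookup with illustrative prompt parts (the example data is hypothetical):

    # Illustrative prompt parts; the part carrying the raw input names its
    # field under "text", e.g. {"text": "text_a"}.
    prompt_parts = [
        {"text": "text_a"},
        {"hard": "This comment is"},
        {"mask": None},
    ]

    def find_text_field(prompt_parts, default="text"):
        """Mirror the lookup in _preprocess: the last part with a "text" key wins."""
        part_text = default
        for part in prompt_parts:
            if "text" in part:
                part_text = part["text"]
        return part_text

    assert find_text_field(prompt_parts) == "text_a"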
