
Commit fa18d20

[Trainer] Support Trainer to use dataset dict for evaluation in training (#4778)
* support dataset dict for evaluation during training
1 parent bb76b09

File tree

2 files changed: +20 −7 lines changed


docs/trainer.md

Lines changed: 6 additions & 2 deletions

```diff
@@ -129,11 +129,15 @@ Trainer is a simple but feature-complete Paddle training and evaluation module
         The dataset to use for training. If it is a `datasets.Dataset`, columns not accepted by the
         `model.forward()` method are automatically removed.
-    eval_dataset (`paddle.io.Dataset`, optional):
+    eval_dataset (`paddle.io.Dataset` or `Dict[str, paddle.io.Dataset]`, optional):
         The dataset to use for evaluation. If it is a `datasets.Dataset`, input fields
         not needed by `model.forward()` are automatically removed.
+        If it is a dictionary, each dataset in the dictionary is evaluated,
+        and the dictionary key is prepended to the metric names.

-        The dataset to use for evaluation.
+        The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
+        `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
+        dataset, prepending the dictionary key to the metric name.
     tokenizer ([`PretrainedTokenizer`], optional):
         The tokenizer used for data preprocessing. If provided, it will be used to automatically pad inputs
```
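The naming rule the docstring describes (the dictionary key is prepended to each metric name) can be sketched in isolation. This is a minimal, self-contained illustration with assumed helper names (`prefix_metrics`, `evaluate_all`, `evaluate_fn` are not part of PaddleNLP's API), not the library's implementation:

```python
def prefix_metrics(metrics: dict, metric_key_prefix: str) -> dict:
    # Prepend the prefix to every metric key, mirroring how a
    # metric_key_prefix like "eval_dev" yields keys like "eval_dev_loss".
    return {f"{metric_key_prefix}_{k}": v for k, v in metrics.items()}


def evaluate_all(eval_dataset, evaluate_fn):
    # Evaluate either a single dataset or a dict of named datasets;
    # dict keys end up embedded in the metric names.
    if isinstance(eval_dataset, dict):
        merged = {}
        for name, dataset in eval_dataset.items():
            merged.update(prefix_metrics(evaluate_fn(dataset), f"eval_{name}"))
        return merged
    return prefix_metrics(evaluate_fn(eval_dataset), "eval")
```

So with `eval_dataset={"dev": dev_ds, "test": test_ds}` and an evaluation that reports `loss`, the logged metrics would be named `eval_dev_loss` and `eval_test_loss` rather than a single `eval_loss`.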

paddlenlp/trainer/trainer.py

Lines changed: 14 additions & 5 deletions

```diff
@@ -160,9 +160,10 @@ class Trainer:
        train_dataset (`paddle.io.Dataset` or `paddle.io.IterableDataset`, *optional*):
            The dataset to use for training. If it is a `datasets.Dataset`, columns not accepted by the
            `model.forward()` method are automatically removed.
-        eval_dataset (`paddle.io.Dataset`, *optional*):
-            The dataset to use for evaluation. If it is a `datasets.Dataset`, columns not accepted by the
-            `model.forward()` method are automatically removed.
+        eval_dataset (Union[`paddle.io.Dataset`, Dict[str, `paddle.io.Dataset`]], *optional*):
+            The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
+            `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
+            dataset, prepending the dictionary key to the metric name.
        tokenizer ([`PretrainedTokenizer`], *optional*):
            The tokenizer used to preprocess the data. If provided, it will be used to automatically pad the inputs
            to the maximum length when batching inputs, and it will be saved along with the model to make it easier to rerun an
@@ -201,7 +202,7 @@ def __init__(
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
-        eval_dataset: Optional[Dataset] = None,
+        eval_dataset: Union[Dataset, Dict[str, Dataset]] = None,
        tokenizer: Optional[PretrainedTokenizer] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
@@ -834,7 +835,15 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,

        metrics = None
        if self.control.should_evaluate:
-            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+            if isinstance(self.eval_dataset, dict):
+                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
+                    metrics = self.evaluate(
+                        eval_dataset=eval_dataset,
+                        ignore_keys=ignore_keys_for_eval,
+                        metric_key_prefix=f"eval_{eval_dataset_name}",
+                    )
+            else:
+                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)

        if self.control.should_save:
            self._save_checkpoint(model, metrics=metrics)
```
