@@ -160,9 +160,10 @@ class Trainer:
         train_dataset (`paddle.io.Dataset` or `paddle.io.IterableDataset`, *optional*):
             The dataset to use for training. If it is a `datasets.Dataset`, columns not accepted by the
             `model.forward()` method are automatically removed.
-        eval_dataset (`paddle.io.Dataset`, *optional*):
-            The dataset to use for evaluation. If it is an `datasets.Dataset`, columns not accepted by the
-            `model.forward()` method are automatically removed.
+        eval_dataset (Union[`paddle.io.Dataset`, Dict[str, `paddle.io.Dataset`]], *optional*):
+            The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
+            `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
+            dataset prepending the dictionary key to the metric name.
         tokenizer ([`PretrainedTokenizer`], *optional*):
             The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
             maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
@@ -201,7 +202,7 @@ def __init__(
         args: TrainingArguments = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
-        eval_dataset: Optional[Dataset] = None,
+        eval_dataset: Union[Dataset, Dict[str, Dataset]] = None,
         tokenizer: Optional[PretrainedTokenizer] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
         callbacks: Optional[List[TrainerCallback]] = None,
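
A minimal usage sketch of the widened signature (`Trainer` comes from the diff; `model`, `training_args`, `tokenizer`, and the dataset variables and dict keys are illustrative assumptions, not part of the change):

```python
# Hypothetical usage sketch: every variable besides Trainer is a placeholder
# assumed to be constructed elsewhere with PaddleNLP / paddle.io APIs.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    # New in this change: a dict of named eval datasets instead of a single one.
    eval_dataset={"validation": valid_ds, "test": test_ds},
    tokenizer=tokenizer,
)
```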
@@ -834,7 +835,15 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
 
         metrics = None
         if self.control.should_evaluate:
-            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+            if isinstance(self.eval_dataset, dict):
+                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
+                    metrics = self.evaluate(
+                        eval_dataset=eval_dataset,
+                        ignore_keys=ignore_keys_for_eval,
+                        metric_key_prefix=f"eval_{eval_dataset_name}",
+                    )
+            else:
+                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
 
         if self.control.should_save:
             self._save_checkpoint(model, metrics=metrics)
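
Assuming `evaluate()` prepends its `metric_key_prefix` argument to every metric name (as the single-dataset path does with the default `"eval"` prefix), the dict keys from the earlier sketch would produce metrics such as `eval_validation_loss` and `eval_test_loss` rather than a single `eval_loss`. Note that `metrics` is reassigned on each loop iteration, so only the last dictionary entry's metrics reach `_save_checkpoint`.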