Commit d9f67f2

Fixed OOM issues when evaluating/predicting (#721)
* Fixed OOM issues when evaluating/predicting
* Fixed test
* Improved error message
* Fixed test
1 parent 5fa1259 commit d9f67f2
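The OOM arose because `Trainer.evaluation_loop` accumulated a dense next-item score matrix of shape (num_rows, item_cardinality) across batches; this commit flips `predict_top_k` from 0 to a default of 100 so only the top-k item ids and scores are kept. A back-of-the-envelope sketch of the memory involved (not part of the commit; the 1000-row/10001-item figures come from the test data below, the 10M-row figure is a hypothetical illustration):

# Rough memory math behind the fix (illustrative sketch, not repo code).
BYTES_PER_FLOAT32 = 4

def dense_scores_gib(num_rows: int, item_cardinality: int) -> float:
    """Size in GiB of a dense float32 score matrix of shape (num_rows, item_cardinality)."""
    return num_rows * item_cardinality * BYTES_PER_FLOAT32 / 2**30

print(f"{dense_scores_gib(1_000, 10_001):.3f} GiB")       # test data: ~0.037 GiB, harmless
print(f"{dense_scores_gib(10_000_000, 10_001):.0f} GiB")  # hypothetical 10M rows: ~373 GiB, an easy OOM
# With predict_top_k=100, only (ids, scores) pairs of width 100 survive:
# 10M rows * 100 * 4 bytes * 2 tensors is roughly 7.5 GiB.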

File tree

3 files changed: +149 -45 lines changed

tests/unit/torch/test_trainer.py
transformers4rec/config/trainer.py
transformers4rec/torch/trainer.py

tests/unit/torch/test_trainer.py

Lines changed: 92 additions & 7 deletions
@@ -407,9 +407,14 @@ def test_trainer_music_streaming(task_and_metrics):
     assert eval_metrics["eval_/loss"] is not None
 
     assert predictions is not None
+
+    DEFAULT_PREDICT_TOP_K = 100
+
     # 1000 is the total samples in the testing data
     if isinstance(task, tr.NextItemPredictionTask):
-        assert predictions.predictions.shape == (1000, task.target_dim)
+        top_predicted_item_ids, top_prediction_scores = predictions.predictions
+        assert top_predicted_item_ids.shape == (1000, DEFAULT_PREDICT_TOP_K)
+        assert top_prediction_scores.shape == (1000, DEFAULT_PREDICT_TOP_K)
     else:
         assert predictions.predictions.shape == (1000,)
 

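The updated assertions double as usage documentation: for NextItemPredictionTask, `predictions.predictions` is now an (item_ids, scores) pair instead of a dense score matrix. A hypothetical snippet in the spirit of the test above (`recsys_trainer` and `data` stand in for the fixtures used throughout this file):

# Hypothetical usage mirroring the assertions above; assumes a fitted
# tr.Trainer named `recsys_trainer` and the music-streaming test data `data`.
outputs = recsys_trainer.predict(data.path)
top_item_ids, top_scores = outputs.predictions  # each shaped (num_rows, predict_top_k)
print(top_item_ids[0][:5])  # ids of the 5 best next-item candidates for the first row
print(top_scores[0][:5])    # their scores, sorted in descending order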
@@ -573,17 +578,79 @@ def test_trainer_with_multiple_tasks(schema_type):
     assert predictions.predictions["click/binary_classification_task"].shape == (1000,)
 
 
-def test_trainer_trop_k_with_wrong_task():
+@pytest.mark.parametrize("predict_top_k", [20, None, "default"])
+def test_trainer_predict_topk(predict_top_k):
+    DEFAULT_PREDICT_TOP_K = 100
+
     data = tr.data.music_streaming_testing_data
     schema = data.schema
     batch_size = 16
-    predict_top_k = 20
 
-    task = tr.BinaryClassificationTask("click", summary_type="mean")
+    task = tr.NextItemPredictionTask(weight_tying=True)
+    inputs = tr.TabularSequenceFeatures.from_schema(
+        schema,
+        max_sequence_length=20,
+        d_output=64,
+        masking="clm",
+    )
+    transformer_config = tconf.XLNetConfig.build(64, 4, 2, 20)
+    model = transformer_config.to_torch_model(inputs, task)
+
+    additional_args = {}
+    if not isinstance(predict_top_k, str):
+        additional_args["predict_top_k"] = predict_top_k
+
+    args = trainer.T4RecTrainingArguments(
+        output_dir=".",
+        num_train_epochs=1,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size // 2,
+        data_loader_engine="merlin_dataloader",
+        max_sequence_length=20,
+        report_to=[],
+        debug=["r"],
+        **additional_args,
+    )
+
+    recsys_trainer = tr.Trainer(
+        model=model,
+        args=args,
+        schema=schema,
+        train_dataset_or_path=data.path,
+        eval_dataset_or_path=data.path,
+        test_dataset_or_path=data.path,
+        compute_metrics=True,
+    )
+
+    outputs = recsys_trainer.predict(data.path)
+
+    if predict_top_k is None:
+        assert outputs.predictions.shape[1] == 10001
+    else:
+        if predict_top_k == "default":
+            predict_top_k = DEFAULT_PREDICT_TOP_K
+
+        pred_item_ids, pred_scores = outputs.predictions
+        assert len(pred_item_ids.shape) == 2
+        assert pred_item_ids.shape[1] == predict_top_k
+        assert len(pred_scores.shape) == 2
+        assert pred_scores.shape[1] == predict_top_k
+
+
+@pytest.mark.parametrize("predict_top_k", [15, 20, 30, None])
+@pytest.mark.parametrize("top_k", [20, None])
+def test_trainer_predict_top_k_x_top_k(predict_top_k, top_k):
+    data = tr.data.music_streaming_testing_data
+    schema = data.schema
+    batch_size = 16
+
+    task = tr.NextItemPredictionTask(weight_tying=True)
     inputs = tr.TabularSequenceFeatures.from_schema(
         schema,
         max_sequence_length=20,
         d_output=64,
+        masking="clm",
     )
     transformer_config = tconf.XLNetConfig.build(64, 4, 2, 20)
     model = transformer_config.to_torch_model(inputs, task)
@@ -609,10 +676,28 @@ def test_trainer_trop_k_with_wrong_task():
         test_dataset_or_path=data.path,
         compute_metrics=True,
     )
-    with pytest.raises(AssertionError) as excinfo:
-        recsys_trainer.predict(data.path)
 
-    assert "Top-k prediction is specific to NextItemPredictionTask" in str(excinfo.value)
+    model.top_k = top_k
+
+    if predict_top_k and top_k and predict_top_k > top_k:
+        with pytest.raises(ValueError) as excinfo:
+            recsys_trainer.predict(data.path)
+        assert "The args.predict_top_k should not be larger than model.top_k" in str(excinfo.value)
+
+    else:
+        outputs = recsys_trainer.predict(data.path)
+
+        if predict_top_k or top_k:
+            expected_top_k = predict_top_k or top_k
+
+            pred_item_ids, pred_scores = outputs.predictions
+            assert len(pred_item_ids.shape) == 2
+            assert pred_item_ids.shape[1] == expected_top_k
+            assert len(pred_scores.shape) == 2
+            assert pred_scores.shape[1] == expected_top_k
+        else:
+            ITEM_CARDINALITY = 10001
+            assert outputs.predictions.shape[1] == ITEM_CARDINALITY
 
 
 def test_trainer_with_pretrained_embeddings():

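Taken together, the two parametrized tests pin down how args.predict_top_k and model.top_k interact. A condensed restatement of the rule they encode (plain-Python sketch of the tests' expectations, not library code):

# Sketch of the output-width rule encoded by the tests above (not repo code).
ITEM_CARDINALITY = 10001  # item cardinality of the music-streaming test data

def expected_prediction_width(predict_top_k, top_k):
    """Second dimension of trainer.predict() outputs under the new behavior."""
    if predict_top_k and top_k and predict_top_k > top_k:
        # mirrors the ValueError raised in Trainer.evaluation_loop
        raise ValueError("The args.predict_top_k should not be larger than model.top_k.")
    if predict_top_k or top_k:
        return predict_top_k or top_k  # truncated (item_ids, scores) tuple
    return ITEM_CARDINALITY            # full dense score matrix, as before

assert expected_prediction_width(15, 20) == 15
assert expected_prediction_width(None, 20) == 20
assert expected_prediction_width(None, None) == ITEM_CARDINALITY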
transformers4rec/config/trainer.py

Lines changed: 5 additions & 4 deletions
@@ -38,8 +38,9 @@ class T4RecTrainingArguments(TrainingArguments):
     predict_top_k: Option[int], int
         Truncate recommendation list to the highest top-K predicted items,
         (do not affect evaluation metrics computation),
-        this parameter is specific to NextItemPredictionTask.
-        by default 0
+        This parameter is specific to NextItemPredictionTask and only affects
+        model.predict() and model.evaluate(), which both call `Trainer.evaluation_loop`.
+        By default 100.
     log_predictions : Optional[bool], bool
         log predictions, labels and metadata features each --compute_metrics_each_n_steps
         (for test set).
@@ -90,10 +91,10 @@ class T4RecTrainingArguments(TrainingArguments):
     )
 
     predict_top_k: int = field(
-        default=0,
+        default=100,
         metadata={
             "help": "Truncate recommendation list to the highest top-K predicted items (do not affect evaluation metrics computation), "
-            "this parameter is specific to NextItemPredictionTask."
+            "this parameter is specific to NextItemPredictionTask. Default is 100."
         },
     )
 
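With the default flipped from 0 to 100, truncation is now opt-out rather than opt-in. A hypothetical construction showing the new knob (argument values are illustrative; the import path follows the tests in this commit):

# Hypothetical usage of the changed default (values are illustrative).
from transformers4rec.config import trainer

args = trainer.T4RecTrainingArguments(
    output_dir=".",
    per_device_eval_batch_size=8,
    max_sequence_length=20,
    # predict_top_k now defaults to 100, so trainer.predict() keeps only the
    # 100 best item ids/scores per row. Pass a smaller value to truncate
    # harder, or None to recover the full (num_rows, item_cardinality)
    # score matrix, as before this commit.
    predict_top_k=20,
)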
transformers4rec/torch/trainer.py

Lines changed: 52 additions & 34 deletions
@@ -528,50 +528,68 @@ def evaluation_loop(
                     if labels_host is None
                     else nested_concat(labels_host, labels, padding_index=0)
                 )
-            if preds is not None and self.args.predict_top_k > 0:
-                if self.model.top_k:
-                    raise ValueError(
-                        "you cannot set top_k argument in the model class and the, "
-                        "predict_top_k in the trainer at the same time. Please ensure setting "
-                        "only predict_top_k"
-                    )
+
+            if (
+                preds is not None
+                and any(isinstance(x, NextItemPredictionTask) for x in model.prediction_tasks)
+                and (self.args.predict_top_k or self.model.top_k)
+            ):
                 # get outputs of next-item scores
                 if isinstance(preds, dict):
-                    assert any(
-                        isinstance(x, NextItemPredictionTask) for x in model.prediction_tasks
-                    ), "Top-k prediction is specific to NextItemPredictionTask, "
-                    "Please ensure `self.args.predict_top_k == 0` "
                     pred_next_item = preds["next-item"]
                 else:
-                    assert isinstance(
-                        model.prediction_tasks[0], NextItemPredictionTask
-                    ), "Top-k prediction is specific to NextItemPredictionTask, "
-                    "Please ensure `self.args.predict_top_k == 0` "
                     pred_next_item = preds
 
-                preds_sorted_item_scores, preds_sorted_item_ids = torch.topk(
-                    pred_next_item, k=self.args.predict_top_k, dim=-1
-                )
-                self._maybe_log_predictions(
-                    labels,
-                    preds_sorted_item_ids,
-                    preds_sorted_item_scores,
-                    # outputs["pred_metadata"],
-                    metrics_results_detailed,
-                    metric_key_prefix,
-                )
-                # The output predictions will be a tuple with the ranked top-n item ids,
-                # and item recommendation scores
-                if isinstance(preds, dict):
-                    preds["next-item"] = (
-                        preds_sorted_item_ids,
-                        preds_sorted_item_scores,
+                preds_sorted_item_scores = None
+                preds_sorted_item_ids = None
+
+                if self.model.top_k is not None and isinstance(pred_next_item, (list, tuple)):
+                    preds_sorted_item_scores, preds_sorted_item_ids = pred_next_item
+
+                    if self.args.predict_top_k:
+                        if self.args.predict_top_k > self.model.top_k:
+                            raise ValueError(
+                                "The args.predict_top_k should not be larger than model.top_k. "
+                                "The model.top_k is available to support inference (e.g. when "
+                                "serving with Triton Inference Server) to return only the top-k "
+                                "predicted items ids and their scores."
+                                "When doing offline predictions with `trainer.predict(), "
+                                "if you set model.top_k, the model will also limit the number of "
+                                "predictions output from trainer.predict(). "
+                                "In that case, you want either to reduce args.predict_top_k or "
+                                "increase model.top_k, so that args.predict_top_k is "
+                                "not larger than model.top_k."
+                            )
+                        preds_sorted_item_scores = preds_sorted_item_scores[
+                            :, : self.args.predict_top_k
+                        ]
+                        preds_sorted_item_ids = preds_sorted_item_ids[:, : self.args.predict_top_k]
+                elif self.args.predict_top_k:
+                    preds_sorted_item_scores, preds_sorted_item_ids = torch.topk(
+                        pred_next_item, k=self.args.predict_top_k, dim=-1
                     )
-                else:
-                    preds = (
+
+                if preds_sorted_item_scores is not None:
+                    self._maybe_log_predictions(
+                        labels,
                         preds_sorted_item_ids,
                         preds_sorted_item_scores,
+                        # outputs["pred_metadata"],
+                        metrics_results_detailed,
+                        metric_key_prefix,
                     )
+                    # The output predictions will be a tuple with the ranked top-n item ids,
+                    # and item recommendation scores
+                    if isinstance(preds, dict):
+                        preds["next-item"] = (
+                            preds_sorted_item_ids,
+                            preds_sorted_item_scores,
+                        )
+                    else:
+                        preds = (
+                            preds_sorted_item_ids,
+                            preds_sorted_item_scores,
+                        )
 
             preds_host = (
                 preds

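The heart of the new code path is plain torch.topk truncation applied before predictions are accumulated across batches, so peak host memory scales with k rather than with the item cardinality. A self-contained sketch of that trick (hypothetical shapes, not lifted from the repo):

# Standalone sketch of the truncation performed per batch (hypothetical shapes).
import torch

batch_scores = torch.randn(32, 10001)  # dense next-item scores: (batch, item cardinality)
k = 100                                # the new args.predict_top_k default
top_scores, top_ids = torch.topk(batch_scores, k=k, dim=-1)
assert top_scores.shape == (32, 100) and top_ids.shape == (32, 100)
# Only this (ids, scores) pair is concatenated across batches, which is why
# trainer.predict() no longer materializes the full score matrix.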