@@ -407,9 +407,14 @@ def test_trainer_music_streaming(task_and_metrics):
407407 assert eval_metrics ["eval_/loss" ] is not None
408408
409409 assert predictions is not None
410+
411+ DEFAULT_PREDICT_TOP_K = 100
412+
410413 # 1000 is the total samples in the testing data
411414 if isinstance (task , tr .NextItemPredictionTask ):
412- assert predictions .predictions .shape == (1000 , task .target_dim )
415+ top_predicted_item_ids , top_prediction_scores = predictions .predictions
416+ assert top_predicted_item_ids .shape == (1000 , DEFAULT_PREDICT_TOP_K )
417+ assert top_prediction_scores .shape == (1000 , DEFAULT_PREDICT_TOP_K )
413418 else :
414419 assert predictions .predictions .shape == (1000 ,)
415420
@@ -573,17 +578,79 @@ def test_trainer_with_multiple_tasks(schema_type):
573578 assert predictions .predictions ["click/binary_classification_task" ].shape == (1000 ,)
574579
575580
576- def test_trainer_trop_k_with_wrong_task ():
581+ @pytest .mark .parametrize ("predict_top_k" , [20 , None , "default" ])
582+ def test_trainer_predict_topk (predict_top_k ):
583+ DEFAULT_PREDICT_TOP_K = 100
584+
577585 data = tr .data .music_streaming_testing_data
578586 schema = data .schema
579587 batch_size = 16
580- predict_top_k = 20
581588
582- task = tr .BinaryClassificationTask ("click" , summary_type = "mean" )
589+ task = tr .NextItemPredictionTask (weight_tying = True )
590+ inputs = tr .TabularSequenceFeatures .from_schema (
591+ schema ,
592+ max_sequence_length = 20 ,
593+ d_output = 64 ,
594+ masking = "clm" ,
595+ )
596+ transformer_config = tconf .XLNetConfig .build (64 , 4 , 2 , 20 )
597+ model = transformer_config .to_torch_model (inputs , task )
598+
599+ additional_args = {}
600+ if not isinstance (predict_top_k , str ):
601+ additional_args ["predict_top_k" ] = predict_top_k
602+
603+ args = trainer .T4RecTrainingArguments (
604+ output_dir = "." ,
605+ num_train_epochs = 1 ,
606+ per_device_train_batch_size = batch_size ,
607+ per_device_eval_batch_size = batch_size // 2 ,
608+ data_loader_engine = "merlin_dataloader" ,
609+ max_sequence_length = 20 ,
610+ report_to = [],
611+ debug = ["r" ],
612+ ** additional_args ,
613+ )
614+
615+ recsys_trainer = tr .Trainer (
616+ model = model ,
617+ args = args ,
618+ schema = schema ,
619+ train_dataset_or_path = data .path ,
620+ eval_dataset_or_path = data .path ,
621+ test_dataset_or_path = data .path ,
622+ compute_metrics = True ,
623+ )
624+
625+ outputs = recsys_trainer .predict (data .path )
626+
627+ if predict_top_k is None :
628+ assert outputs .predictions .shape [1 ] == 10001
629+ else :
630+ if predict_top_k == "default" :
631+ predict_top_k = DEFAULT_PREDICT_TOP_K
632+
633+ pred_item_ids , pred_scores = outputs .predictions
634+ assert len (pred_item_ids .shape ) == 2
635+ assert pred_item_ids .shape [1 ] == predict_top_k
636+ assert len (pred_scores .shape ) == 2
637+ assert pred_scores .shape [1 ] == predict_top_k
638+
639+
640+ @pytest .mark .parametrize ("predict_top_k" , [15 , 20 , 30 , None ])
641+ @pytest .mark .parametrize ("top_k" , [20 , None ])
642+ def test_trainer_predict_top_k_x_top_k (predict_top_k , top_k ):
643+ data = tr .data .music_streaming_testing_data
644+ schema = data .schema
645+ batch_size = 16
646+
647+ task = tr .NextItemPredictionTask (weight_tying = True )
648+
583649 inputs = tr .TabularSequenceFeatures .from_schema (
584650 schema ,
585651 max_sequence_length = 20 ,
586652 d_output = 64 ,
653+ masking = "clm" ,
587654 )
588655 transformer_config = tconf .XLNetConfig .build (64 , 4 , 2 , 20 )
589656 model = transformer_config .to_torch_model (inputs , task )
@@ -609,10 +676,28 @@ def test_trainer_trop_k_with_wrong_task():
609676 test_dataset_or_path = data .path ,
610677 compute_metrics = True ,
611678 )
612- with pytest .raises (AssertionError ) as excinfo :
613- recsys_trainer .predict (data .path )
614679
615- assert "Top-k prediction is specific to NextItemPredictionTask" in str (excinfo .value )
680+ model .top_k = top_k
681+
682+ if predict_top_k and top_k and predict_top_k > top_k :
683+ with pytest .raises (ValueError ) as excinfo :
684+ recsys_trainer .predict (data .path )
685+ assert "The args.predict_top_k should not be larger than model.top_k" in str (excinfo .value )
686+
687+ else :
688+ outputs = recsys_trainer .predict (data .path )
689+
690+ if predict_top_k or top_k :
691+ expected_top_k = predict_top_k or top_k
692+
693+ pred_item_ids , pred_scores = outputs .predictions
694+ assert len (pred_item_ids .shape ) == 2
695+ assert pred_item_ids .shape [1 ] == expected_top_k
696+ assert len (pred_scores .shape ) == 2
697+ assert pred_scores .shape [1 ] == expected_top_k
698+ else :
699+ ITEM_CARDINALITY = 10001
700+ assert outputs .predictions .shape [1 ] == ITEM_CARDINALITY
616701
617702
618703def test_trainer_with_pretrained_embeddings ():
0 commit comments