How to predict on the test dataset using trainer.predict()? #13568
Hi, I have trained a model using Trainer and was trying to use the trainer.predict() method to predict on the datamodule, but it throws the following error:
I have the following dataloader:
I have the following model defined:
Please help me predict on the test data. How can I leverage the trainer to predict on the test data and get a classification report for the predicted output? Thanks
@karndeepsingh To use Trainer.predict(), you must have predict_dataloader() defined in your LightningModule or LightningDataModule, as the error message states.
If you'd like to run inference on your test set, you just need to define predict_dataloader() with your test set:
def predict_dataloader(self):
    # Reuse the test dataset for prediction; shuffle=False keeps sample order fixed.
    return torch.utils.data.DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        num_workers=4,
        shuffle=False,
    )
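The question also asks for a classification report. For a single dataloader, trainer.predict() returns a list with one entry per batch, so the per-batch outputs have to be flattened before any metric is computed. Below is a minimal sketch of that post-processing step. It assumes predict_step has been overridden to return plain Python lists of class ids (for example via logits.argmax(dim=1).tolist()) and that the true labels are collected the same way. The helper names flatten_batches and per_class_report are hypothetical, and the hard-coded batches stand in for real trainer.predict() output. In practice sklearn.metrics.classification_report(y_true, y_pred) is the usual tool; this is only a dependency-free stand-in.

```python
from collections import Counter
from itertools import chain


def flatten_batches(batches):
    """Concatenate per-batch lists into one flat list (hypothetical helper)."""
    return list(chain.from_iterable(batches))


def per_class_report(y_true, y_pred):
    """Return {class_id: {precision, recall, f1, support}} (hypothetical helper)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")

    true_pos = Counter()            # samples predicted as c that truly are c
    pred_count = Counter(y_pred)    # how often each class was predicted
    support = Counter(y_true)       # how often each class truly occurs
    for t, p in zip(y_true, y_pred):
        if t == p:
            true_pos[t] += 1

    report = {}
    for c in sorted(set(y_true) | set(y_pred)):
        precision = true_pos[c] / pred_count[c] if pred_count[c] else 0.0
        recall = true_pos[c] / support[c] if support[c] else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        report[c] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "support": support[c],
        }
    return report


# Stand-in data (assumption): with Lightning this would come from something like
#   pred_batches = trainer.predict(model, datamodule=dm)
pred_batches = [[0, 1, 1], [2, 0]]
label_batches = [[0, 1, 2], [2, 0]]

y_pred = flatten_batches(pred_batches)
y_true = flatten_batches(label_batches)
report = per_class_report(y_true, y_pred)

for class_id, row in report.items():
    print(class_id, row)
```

Because predict_dataloader() uses shuffle=False, the flattened predictions stay in the same order as the test dataset, which is what lets them be paired with the labels position by position.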