@karndeepsingh To use Trainer.predict(), you must define predict_dataloader() in your LightningModule or LightningDataModule, as the error message states:

MisconfigurationException: No `predict_dataloader()` method defined to run `Trainer.predict`.

If you'd like to run inference on your test set, you just need to define predict_dataloader() with your test set:

    def predict_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=4,
            shuffle=False,
        )

Answer selected by akihironitta