Error with predict() #7747
-
I am trying to predict on my test data, but it throws an error. To give some context, here is my code:

```python
import pytorch_lightning as pl
import torch
import torchvision

# test dataloader
test_dataset = torchvision.datasets.ImageFolder(
    'path/to/test_data_labelled_folders',
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
)
testloader = torch.utils.data.DataLoader(
    test_dataset, num_workers=4, batch_size=4)

# loading previously trained model
model = LitClassifier.load_from_checkpoint('path/to/checkpoint.ckpt')  # LitClassifier is a LightningModule

# calling predict function
trainer = pl.Trainer()
trainer.predict(model=model, dataloaders=[testloader], return_predictions=True)
```
-
Did you override `predict_step`? By default it just feeds the whole batch through `forward`, and with `ImageFolder` each batch also includes the labels, so the batch is a list `[images, labels]` rather than a tensor. So you have two choices: remove the labels from your predict data, or override `predict_step` to ignore them :)