How to use predict function to return predictions #8038
How to get predictions?

```python
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader

class OurModel(LightningModule):
    def __init__(self):
        super(OurModel, self).__init__()
        self.layer = MyModelV3()

    def forward(self, x):
        return self.layer(x)

    def train_dataloader(self):
        return DataLoader(DataReader(train_df))

    def training_step(self, batch, batch_idx):
        # (loss computation omitted in the original question)
        return loss

    def test_dataloader(self):
        return DataLoader(DataReader(test_df))

    def test_step(self, batch, batch_idx):
        image, label = batch
        out = self(image)
        loss = self.criterion(out, label)
        return loss

    def predict(self, batch):
        return self(batch)
```

I am not sure how to use the predict function, or how to define a data loader for it. I want to get predictions for
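Dear @talhaanwarch,

To create a dataloader, just do `predict_dataloader = DataLoader(predict_dataset)`. Then implement `predict_step` in your module and pass the dataloader to `trainer.predict`:

```python
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader

predict_dataloader = DataLoader(predict_dataset)

class Model(LightningModule):
    def __init__(self, ...):
        super().__init__()
        self.save_hyperparameters()
        self.model = ...  # your network

    def forward(self, batch):
        return self.model(batch)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        return self(batch)

model = Model(...)
trainer = Trainer(...)
predictions = trainer.predict(model, dataloaders=predict_dataloader)
```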
Hi @talhaanwarch, in order to get predictions from a data loader you need to implement `predict_step` in your `LightningModule` (docs here: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#predict-step). You would then be able to call `Trainer.predict` with the dataloader you want to use, following the API here: https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.html#pytorch_lightning.trainer.trainer.Trainer.predict

Hope that helps 😃