How to obtain predictions from LightningCLI using the predict subcommand?
#10509
Replies: 2 comments
-
Create a Callback class that implements on_predict_epoch_end. This hook receives the predicted results as an argument, so you can do whatever is needed with them there. Then add this callback to the CLI, either as non-configurable by using trainer_defaults as …
-
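Looks like the suggested method as of May 2023 is to use a BasePredictionWriter callback. Example from the docs:

```python
import os

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.demos.boring_classes import BoringModel


class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        # One file per batch, in a subdirectory per dataloader
        # (dataloader_idx is an int, so it must be converted for os.path.join)
        batch_dir = os.path.join(self.output_dir, str(dataloader_idx))
        os.makedirs(batch_dir, exist_ok=True)
        torch.save(prediction, os.path.join(batch_dir, f"{batch_idx}.pt"))

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # All predictions from the epoch in a single file
        os.makedirs(self.output_dir, exist_ok=True)
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))


pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)
```

One would replace the last three lines with a …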
-
First of all, I'd like to say it's very refreshing to see the changes to the library since I last used it in v1.1! Kudos to the team for all the great work 👏

pytorch-lightning version: 1.5.0

I came across the LightningCLI class and found that it removes a lot of argument-parsing boilerplate. My script looks like this:
I was able to use it to .fit and .test my model, but I am not sure how to use .predict.

When I pull up the help menu for this script, I see that the predict subcommand has a flag to return predictions from the model:

However, I cannot find any documentation on using predict with the LightningCLI, or on how to obtain a handle to the predictions.

I'd appreciate guidance on how to do so, and thanks in advance! 😊