Processing in predict_step() requires access to a DataModule attribute #7884
-
Hi! In my LightningDataModule, I apply preprocessing transforms to my input data before feeding it to my dataloader. In the same datamodule, I also defined the postprocessing transforms to apply after the inference.
I want to apply these post_transforms to my inference outputs in
Thanks in advance for your suggestions :)
Replies: 1 comment 2 replies
-
Hi, you should be able to access the datamodule inside the LightningModule. Try

```python
def predict_step(self, batch: Any, batch_idx: int):
    batch["pred"] = self(batch)
    self.datamodule.post_transforms(batch)
```

Also, another tip: better use self() instead of self.forward() (generally in PyTorch).
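To make the pattern concrete, here is a minimal, dependency-free sketch. It does not use Lightning itself: `SketchDataModule`, `SketchTrainer`, `SketchModule` and `round_predictions` are hypothetical stand-ins invented for illustration, and the `__call__` method only loosely imitates what `nn.Module.__call__` does (run `forward()` plus any registered hooks). It also assumes the datamodule is reached through the trainer as `self.trainer.datamodule`; as far as I know, that is the route to use on recent Lightning releases, while a direct `self.datamodule` attribute may not exist depending on your version.

```python
from typing import Any, Callable, Dict, List


class SketchDataModule:
    """Hypothetical stand-in for a LightningDataModule that owns the transforms."""

    def __init__(self, post_transforms: Callable[[Dict[str, Any]], Dict[str, Any]]):
        self.post_transforms = post_transforms


class SketchTrainer:
    """Hypothetical stand-in for the Trainer, which keeps a reference to the datamodule."""

    def __init__(self, datamodule: SketchDataModule):
        self.datamodule = datamodule


class SketchModule:
    """Hypothetical stand-in for a LightningModule."""

    def __init__(self) -> None:
        self.trainer = None  # assigned once a trainer takes over the module
        self.forward_hooks: List[Callable[[Any], None]] = []

    def __call__(self, batch: Dict[str, Any]) -> List[float]:
        # Loosely imitates nn.Module.__call__: run forward, then registered hooks.
        out = self.forward(batch)
        for hook in self.forward_hooks:
            hook(out)
        return out

    def forward(self, batch: Dict[str, Any]) -> List[float]:
        # Toy "model": double every input value.
        return [2 * x for x in batch["image"]]

    def predict_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        # self(...) rather than self.forward(...), so the hooks are not skipped.
        batch["pred"] = self(batch)
        # Assumption: the datamodule is reachable through the trainer.
        return self.trainer.datamodule.post_transforms(batch)


def round_predictions(batch: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical post-transform: round each prediction to the nearest integer.
    batch["pred"] = [round(p) for p in batch["pred"]]
    return batch


model = SketchModule()
model.trainer = SketchTrainer(SketchDataModule(post_transforms=round_predictions))

seen: List[Any] = []
model.forward_hooks.append(seen.append)  # records every output produced via self(...)

result = model.predict_step({"image": [0.6, 1.2]}, batch_idx=0)
print(result["pred"])
```

Note that this `predict_step` returns the post-processed batch, so the caller actually receives the transformed predictions. The hook list is only there to show why calling `self()` matters: a direct `self.forward()` call would bypass it.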