Processing in predict_step() requires access to a DataModule attribute #7884

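Hi, you should be able to access the datamodule inside the LightningModule through its trainer reference, `self.trainer.datamodule`. Try:

from typing import Any

def predict_step(self, batch: Any, batch_idx: int):
    batch["pred"] = self(batch)
    self.trainer.datamodule.post_transforms(batch)  # your DataModule's method, assumed to modify batch in place
    return batch  # whatever predict_step returns is collected by trainer.predict()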
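To see how the pieces connect, here is a dependency-free sketch. `Trainer`, `DataModule` and `Model` below are hypothetical stand-ins, not Lightning's classes: they only mimic the wiring in which the trainer holds a reference to the datamodule and the model holds a reference to the trainer. The `class_names` attribute and the label-mapping `post_transforms` are likewise invented for illustration.

```python
from typing import Any, Dict, List, Optional


# Hypothetical stand-ins (NOT Lightning's API) that mimic how the objects are linked:
# the trainer holds the datamodule, and the model holds the trainer.
class DataModule:
    def __init__(self, class_names: List[str]) -> None:
        # Attribute owned by the datamodule that prediction post-processing needs.
        self.class_names = class_names

    def post_transforms(self, batch: Dict[str, Any]) -> None:
        # Hypothetical post-processing: map predicted indices to label names, in place.
        batch["label"] = [self.class_names[i] for i in batch["pred"]]


class Model:
    def __init__(self) -> None:
        self.trainer: Optional["Trainer"] = None

    def __call__(self, batch: Dict[str, Any]) -> List[int]:
        return self.forward(batch)

    def forward(self, batch: Dict[str, Any]) -> List[int]:
        # Toy "model": predict class 1 for positive inputs, else class 0.
        return [1 if x > 0 else 0 for x in batch["x"]]

    def predict_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        batch["pred"] = self(batch)
        # Reach the datamodule through the trainer reference.
        self.trainer.datamodule.post_transforms(batch)
        return batch  # the return value is what predict() collects


class Trainer:
    def __init__(self) -> None:
        self.datamodule: Optional[DataModule] = None

    def predict(self, model: Model, batches: List[Dict[str, Any]],
                datamodule: DataModule) -> List[Dict[str, Any]]:
        # Link the objects, then collect predict_step's return values.
        self.datamodule = datamodule
        model.trainer = self
        return [model.predict_step(b, i) for i, b in enumerate(batches)]


dm = DataModule(class_names=["negative", "positive"])
results = Trainer().predict(Model(), [{"x": [-1.0, 2.0]}, {"x": [3.0]}], datamodule=dm)
print([r["label"] for r in results])  # [['negative', 'positive'], ['positive']]
```

Note that `predict_step` returns the batch: since `trainer.predict()` collects whatever `predict_step` returns, leaving out the `return` would give you nothing useful to collect.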

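Also, another tip: call self(...) rather than self.forward(...) (this applies to PyTorch modules in general). Calling the module goes through nn.Module.__call__, which runs any registered hooks around forward; calling forward directly skips them.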
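The reason behind the tip can be shown without installing PyTorch. `TinyModule` is a hypothetical, heavily simplified stand-in for `torch.nn.Module`: the real `__call__` also runs forward pre-hooks, passes the hook's inputs as a tuple, and the real `register_forward_hook` returns a removable handle. The one point the sketch keeps is that hooks run only when the module is called, not when `forward` is invoked directly.

```python
from typing import Any, Callable, List

Hook = Callable[["TinyModule", Any, Any], None]


class TinyModule:
    """Simplified stand-in for torch.nn.Module's call path (not the real implementation)."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def __call__(self, x: Any) -> Any:
        # __call__ wraps forward and runs the registered hooks on its output.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out

    def forward(self, x: Any) -> Any:
        return x * 2


calls: List[Any] = []
m = TinyModule()
m.register_forward_hook(lambda mod, inp, out: calls.append(out))

m(3)          # goes through __call__: the hook runs
m.forward(4)  # bypasses __call__: the hook is skipped
print(calls)  # [6]
```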

Answer selected by dianemarquette