Dear @ItamarKanter,

I believe this is great and pretty Pythonic!

You could do this to make it slightly cleaner.

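from pytorch_lightning import LightningModule

class Model(LightningModule):

    def common_step(self, batch, batch_idx, stage):
        # Shared logic for every stage; `stage` is "train", "val", or "test".
        logits = self(batch[0])
        # compute_loss is assumed to be defined elsewhere on the model.
        loss = self.compute_loss(logits, batch[1])
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        # The training loss must be returned so it can be backpropagated.
        return self.common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        self.common_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        self.common_step(batch, batch_idx, "test")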
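
If you want to sanity-check the pattern without installing anything, here is a rough sketch. Note that `FakeModule` is a hypothetical stand-in I made up for `LightningModule` (it only records what `self.log` receives and is not Lightning's API), and the forward pass and loss are toy placeholders rather than your real model.

```python
# Hypothetical stand-in for LightningModule so the sketch runs without
# PyTorch Lightning installed; it only records what self.log receives.
class FakeModule:
    def __init__(self):
        self.logged = {}

    def log(self, name, value):
        self.logged[name] = value


class Model(FakeModule):
    def __call__(self, x):
        # Toy "forward pass": double every input.
        return [2 * v for v in x]

    def compute_loss(self, preds, targets):
        # Mean squared error, written out in plain Python.
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)

    def common_step(self, batch, batch_idx, stage):
        # One shared implementation; only the logged name depends on the stage.
        preds = self(batch[0])
        loss = self.compute_loss(preds, batch[1])
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self.common_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        self.common_step(batch, batch_idx, "val")


model = Model()
batch = ([1.0, 2.0], [2.0, 5.0])  # (inputs, targets)
train_loss = model.training_step(batch, 0)
model.validation_step(batch, 0)
print(sorted(model.logged))  # names logged by the two stages
```

Each stage ends up logging under its own name (`train_loss`, `val_loss`) while the actual step logic lives in one place.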

Best,
T.C

Answer selected by ItamarKanter