Skip to content

Commit e182559

Browse files
committed
updated docs
1 parent 9b99a02 commit e182559

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

README.md

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,43 +49,44 @@ from torchvision.datasets import MNIST
4949
class CoolModel(ptl.LightningModule):
5050

5151
def __init(self):
52+
super(CoolModel, self).__init__()
5253
# not the best model...
53-
self.l1 = torch.nn.Linear(28*28, 10)
54-
54+
self.l1 = torch.nn.Linear(28 * 28, 10)
55+
5556
def forward(self, x):
5657
return torch.relu(self.l1(x))
57-
58+
5859
def my_loss(self, y_hat, y):
5960
return F.cross_entropy(y_hat, y)
60-
61+
6162
def training_step(self, batch, batch_nb):
6263
x, y = batch
6364
y_hat = self.forward(x)
6465
return {'tng_loss': self.my_loss(y_hat, y)}
65-
66+
6667
def validation_step(self, batch, batch_nb):
6768
x, y = batch
6869
y_hat = self.forward(x)
6970
return {'val_loss': self.my_loss(y_hat, y)}
70-
71+
7172
def validation_end(self, outputs):
7273
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
7374
return avg_loss
74-
75+
7576
def configure_optimizers(self):
7677
return [torch.optim.Adam(self.parameters(), lr=0.02)]
77-
78+
7879
@ptl.data_loader
7980
def tng_dataloader(self):
8081
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)
8182

8283
@ptl.data_loader
8384
def val_dataloader(self):
8485
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
85-
86+
8687
@ptl.data_loader
8788
def test_dataloader(self):
88-
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
89+
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
8990
```
9091

9192
2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)

docs/LightningModule/RequiredTrainerInterface.md

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,43 +38,44 @@ from torchvision.datasets import MNIST
3838
class CoolModel(ptl.LightningModule):
3939

4040
def __init(self):
41+
super(CoolModel, self).__init__()
4142
# not the best model...
42-
self.l1 = torch.nn.Linear(28*28, 10)
43-
43+
self.l1 = torch.nn.Linear(28 * 28, 10)
44+
4445
def forward(self, x):
4546
return torch.relu(self.l1(x))
46-
47+
4748
def my_loss(self, y_hat, y):
4849
return F.cross_entropy(y_hat, y)
49-
50+
5051
def training_step(self, batch, batch_nb):
5152
x, y = batch
5253
y_hat = self.forward(x)
5354
return {'tng_loss': self.my_loss(y_hat, y)}
54-
55+
5556
def validation_step(self, batch, batch_nb):
5657
x, y = batch
5758
y_hat = self.forward(x)
5859
return {'val_loss': self.my_loss(y_hat, y)}
59-
60+
6061
def validation_end(self, outputs):
6162
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
6263
return avg_loss
63-
64+
6465
def configure_optimizers(self):
6566
return [torch.optim.Adam(self.parameters(), lr=0.02)]
66-
67+
6768
@ptl.data_loader
6869
def tng_dataloader(self):
6970
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)
7071

7172
@ptl.data_loader
7273
def val_dataloader(self):
7374
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
74-
75+
7576
@ptl.data_loader
7677
def test_dataloader(self):
77-
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
78+
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
7879
```
7980

8081
---

0 commit comments

Comments
 (0)