Skip to content

Commit 12d0ef7

Browse files
committed
add: pytorch-lightning examples
1 parent a06b864 commit 12d0ef7

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import os
2+
3+
import pytorch_lightning as pl
4+
import torch
5+
from torch import nn
6+
from torch.optim import AdamW
7+
from torch.utils.data import DataLoader
8+
from torchvision.datasets import MNIST
9+
from torchvision.transforms import ToTensor
10+
11+
from pytorch_optimizer import Lookahead
12+
13+
14+
class LitAutoEncoder(pl.LightningModule):
15+
def __init__(self):
16+
super().__init__()
17+
18+
self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
19+
self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
20+
21+
def training_step(self, batch, batch_idx):
22+
x, y = batch
23+
x = x.view(x.size(0), -1)
24+
25+
z = self.encoder(x)
26+
x_hat = self.decoder(z)
27+
28+
loss = nn.functional.mse_loss(x_hat, x)
29+
30+
self.log('train_loss', loss)
31+
32+
return loss
33+
34+
def configure_optimizers(self):
35+
return Lookahead(AdamW(self.parameters(), lr=1e-3), k=5, alpha=0.5)
36+
37+
38+
def main():
39+
train_dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
40+
train_loader = DataLoader(train_dataset)
41+
42+
autoencoder = LitAutoEncoder()
43+
autoencoder.train()
44+
45+
if torch.cuda.is_available():
46+
autoencoder.cuda()
47+
48+
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
49+
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
50+
51+
52+
if __name__ == '__main__':
53+
main()
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
3+
import pytorch_lightning as pl
4+
import torch
5+
from torch import nn
6+
from torch.utils.data import DataLoader
7+
from torchvision.datasets import MNIST
8+
from torchvision.transforms import ToTensor
9+
10+
from pytorch_optimizer import SophiaH
11+
12+
13+
class LitAutoEncoder(pl.LightningModule):
14+
def __init__(self):
15+
super().__init__()
16+
17+
self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
18+
self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
19+
20+
self.automatic_optimization = False
21+
22+
def training_step(self, batch, batch_idx):
23+
opt = self.optimizers()
24+
opt.zero_grad()
25+
26+
x, y = batch
27+
x = x.view(x.size(0), -1)
28+
29+
z = self.encoder(x)
30+
x_hat = self.decoder(z)
31+
32+
loss = nn.functional.mse_loss(x_hat, x)
33+
34+
self.manual_backward(loss, create_graph=True)
35+
opt.step()
36+
37+
self.log('train_loss', loss)
38+
39+
return loss
40+
41+
def configure_optimizers(self):
42+
return SophiaH(self.parameters())
43+
44+
45+
def main():
46+
train_dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
47+
train_loader = DataLoader(train_dataset)
48+
49+
autoencoder = LitAutoEncoder()
50+
autoencoder.train()
51+
52+
if torch.cuda.is_available():
53+
autoencoder.cuda()
54+
55+
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
56+
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
57+
58+
59+
if __name__ == '__main__':
60+
main()

0 commit comments

Comments
 (0)