
Commit 2f4797e

adding simple PL examples (#6)
1 parent faca4ae commit 2f4797e

7 files changed: +117 -1 lines changed


examples/__init__.py

Whitespace-only changes.

examples/sample_model.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+from lightning import LightningModule
+from torch.nn.functional import mse_loss
+
+
+class LitAutoEncoder(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
+        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))
+
+    def forward(self, x):
+        # in lightning, forward defines the prediction/inference actions
+        return self.encoder(x)
+
+    def training_step(self, batch, batch_idx):
+        # training_step defines the train loop. It is independent of forward
+        x, _ = batch
+        x = x.view(x.size(0), -1)
+        z = self.encoder(x)
+        x_hat = self.decoder(z)
+        loss = mse_loss(x_hat, x)
+        self.log("train_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        return torch.optim.AdamW(self.parameters(), lr=1e-4)
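Note: the ModelCheckpoint in examples/train_simple.py monitors "val_loss", but this module only logs "train_loss" in training_step. A minimal validation_step along these lines (a sketch of a possible follow-up, not part of this commit) would make that metric available:

    def validation_step(self, batch, batch_idx):
        # mirror training_step so "val_loss" is logged for ModelCheckpoint to monitor
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        loss = mse_loss(x_hat, x)
        self.log("val_loss", loss)
        return loss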

examples/train_callback.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import torch.utils.data as data
+import torchvision as tv
+from lightning import Callback, Trainer
+from litmodels import upload_model
+from sample_model import LitAutoEncoder
+
+
+class UploadModelCallback(Callback):
+    def on_train_epoch_end(self, trainer, pl_module):
+        # Get the best model path from the checkpoint callback
+        best_model_path = trainer.checkpoint_callback.best_model_path
+        if best_model_path:
+            print(f"Uploading model: {best_model_path}")
+            upload_model(path=best_model_path, name="jirka/kaggle/lit-auto-encoder-callback")
+
+
+if __name__ == "__main__":
+    dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
+    train, val = data.random_split(dataset, [55000, 5000])
+
+    autoencoder = LitAutoEncoder()
+
+    trainer = Trainer(
+        max_epochs=2,
+        callbacks=[UploadModelCallback()],
+    )
+    trainer.fit(
+        autoencoder,
+        data.DataLoader(train, batch_size=256),
+        data.DataLoader(val, batch_size=256),
+    )
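The callback reads trainer.checkpoint_callback.best_model_path, so it assumes some ModelCheckpoint is attached; the script passes none explicitly and relies on Lightning's default one. A sketch (not part of this commit, parameter choices illustrative) that pairs the upload callback with an explicit ModelCheckpoint tracking the logged "train_loss" metric:

    from lightning.pytorch.callbacks import ModelCheckpoint

    # hypothetical variant: give UploadModelCallback an explicit checkpoint callback to read from
    checkpoint = ModelCheckpoint(monitor="train_loss", mode="min", save_top_k=1)
    trainer = Trainer(
        max_epochs=2,
        callbacks=[checkpoint, UploadModelCallback()],
    )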

examples/train_resume.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+import torch.utils.data as data
+import torchvision as tv
+from lightning import Trainer
+from litmodels import download_model
+from sample_model import LitAutoEncoder
+
+if __name__ == "__main__":
+    dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
+    train, val = data.random_split(dataset, [55000, 5000])
+
+    model_path = download_model(name="jirka/kaggle/lit-auto-encoder-simple", download_dir="my_models")
+    print(f"model: {model_path}")
+    # autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path=model_path)
+
+    trainer = Trainer(
+        max_epochs=4,
+    )
+    trainer.fit(
+        LitAutoEncoder(),
+        data.DataLoader(train, batch_size=256),
+        data.DataLoader(val, batch_size=256),
+        ckpt_path=model_path,
+    )
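Passing ckpt_path=model_path to trainer.fit resumes the full training state (weights, optimizer state, epoch counter) from the downloaded checkpoint, which is presumably why max_epochs is raised to 4 after train_simple.py stopped at 2. The commented-out load_from_checkpoint line is the weights-only alternative; a sketch of that path, assuming model_path points at a Lightning checkpoint file:

    # hypothetical weights-only variant: restore parameters but start a fresh fit
    autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path=model_path)
    trainer = Trainer(max_epochs=4)
    trainer.fit(
        autoencoder,
        data.DataLoader(train, batch_size=256),
        data.DataLoader(val, batch_size=256),
    )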

examples/train_simple.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import torch.utils.data as data
+import torchvision as tv
+from lightning import Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
+from litmodels import upload_model
+from sample_model import LitAutoEncoder
+
+if __name__ == "__main__":
+    dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
+    train, val = data.random_split(dataset, [55000, 5000])
+
+    autoencoder = LitAutoEncoder()
+    # Save the best model based on validation loss
+    checkpoint_callback = ModelCheckpoint(
+        monitor="val_loss",  # Metric to monitor
+        save_top_k=1,  # Only save the best model (use -1 to save all)
+        mode="min",  # 'min' for loss, 'max' for accuracy
+        save_last=True,  # Additionally save the last checkpoint
+        dirpath="my_checkpoints/",  # Directory to save checkpoints
+        filename="{epoch:02d}-{val_loss:.2f}",  # Custom checkpoint filename
+    )
+
+    trainer = Trainer(
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+    )
+    trainer.fit(
+        autoencoder,
+        data.DataLoader(train, batch_size=256),
+        data.DataLoader(val, batch_size=256),
+    )
+    print(f"last: {vars(checkpoint_callback)}")
+    upload_model(path=checkpoint_callback.last_model_path, name="jirka/kaggle/lit-auto-encoder-simple")
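Although the comment describes saving the best model by validation loss, the upload uses checkpoint_callback.last_model_path, i.e. the save_last=True checkpoint, and that registry name is what train_resume.py later downloads. If the best-scoring checkpoint were wanted instead, a hypothetical variant (assuming "val_loss" is actually logged, see the note under sample_model.py) could read:

    # hypothetical variant: upload the best-scoring checkpoint instead of the last one
    if checkpoint_callback.best_model_path:
        upload_model(path=checkpoint_callback.best_model_path, name="jirka/kaggle/lit-auto-encoder-simple")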

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -119,6 +119,7 @@ ignore-init-module-imports = true
 "setup.py" = ["D100", "SIM115"]
 "__about__.py" = ["D100"]
 "__init__.py" = ["D100", "E402"]
+"examples/**" = ["D"] # todo
 "tests/**" = ["D"]

 [tool.ruff.lint.pydocstyle]

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-# NOTE: once we add more dependcies, conside update dependabot to check for updates
+# NOTE: once we add more dependencies, consider update dependabot to check for updates

 lightning-sdk >=0.1.26
