Skip to content

Commit b8b2b7c

Browse files
committed
update examoles with BoringModel
1 parent bd7681c commit b8b2b7c

File tree

5 files changed

+16
-109
lines changed

5 files changed

+16
-109
lines changed

examples/resume-lightning-training.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,22 @@
22
This example demonstrates how to resume training of a model using the `download_model` function.
33
"""
44

5-
import torch.utils.data as data
6-
import torchvision as tv
75
from lightning import Trainer
6+
from lightning.pytorch.demos.boring_classes import BoringModel
87
from litmodels import download_model
9-
from sample_model import LitAutoEncoder
108

119
# Define the model name - this should be unique to your model
1210
# The format is <organization>/<teamspace>/<model-name>:<model-version>
1311
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-callback:latest"
1412

1513

1614
if __name__ == "__main__":
17-
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
18-
train, val = data.random_split(dataset, [55000, 5000])
19-
2015
model_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
2116
print(f"model: {model_path}")
2217
# autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path=model_path)
2318

24-
trainer = Trainer(
25-
max_epochs=4,
26-
)
19+
trainer = Trainer(max_epochs=4)
2720
trainer.fit(
28-
LitAutoEncoder(),
29-
data.DataLoader(train, batch_size=256),
30-
data.DataLoader(val, batch_size=256),
21+
BoringModel(),
3122
ckpt_path=model_path,
3223
)

examples/sample_model.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

examples/train-model-and-simple-save.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,18 @@
22
This example demonstrates how to train a model and upload it to the cloud using the `upload_model` function.
33
"""
44

5-
import torch.utils.data as data
6-
import torchvision as tv
75
from lightning import Trainer
6+
from lightning.pytorch.demos.boring_classes import BoringModel
87
from litmodels import upload_model
9-
from sample_model import LitAutoEncoder
108

119
# Define the model name - this should be unique to your model
1210
# The format is <organization>/<teamspace>/<model-name>
1311
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-simple"
1412

1513

1614
if __name__ == "__main__":
17-
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
18-
train, val = data.random_split(dataset, [55000, 5000])
19-
20-
autoencoder = LitAutoEncoder()
21-
2215
trainer = Trainer(max_epochs=2)
23-
trainer.fit(
24-
autoencoder,
25-
data.DataLoader(train, batch_size=256),
26-
data.DataLoader(val, batch_size=256),
27-
)
16+
trainer.fit(BoringModel())
2817
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
2918
print(f"best: {checkpoint_path}")
3019
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)

examples/train-model-with-lightning-callback.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,18 @@
22
Train a model with a Lightning callback that uploads the best model to the cloud after each epoch.
33
"""
44

5-
import torch.utils.data as data
6-
import torchvision as tv
75
from lightning import Trainer
8-
from litmodels.integrations import LitModelCheckpoint
9-
from sample_model import LitAutoEncoder
6+
from lightning.pytorch.demos.boring_classes import BoringModel
7+
from litmodels.integrations import LightningModelCheckpoint
108

119
# Define the model name - this should be unique to your model
1210
# The format is <organization>/<teamspace>/<model-name>
1311
MY_MODEL_NAME = "lightning-ai/jirka/lit-auto-encoder-callback"
1412

1513

1614
if __name__ == "__main__":
17-
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
18-
train, val = data.random_split(dataset, [55000, 5000])
19-
20-
autoencoder = LitAutoEncoder()
21-
2215
trainer = Trainer(
2316
max_epochs=2,
24-
callbacks=LitModelCheckpoint(model_name=MY_MODEL_NAME),
25-
)
26-
trainer.fit(
27-
autoencoder,
28-
data.DataLoader(train, batch_size=256),
29-
data.DataLoader(val, batch_size=256),
17+
callbacks=LightningModelCheckpoint(model_name=MY_MODEL_NAME),
3018
)
19+
trainer.fit(BoringModel())

examples/train-model-with-lightning-logger.py

Lines changed: 7 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,58 +7,24 @@
77
88
"""
99

10-
import os
11-
12-
from lightning import LightningModule, Trainer
10+
from lightning import Trainer
11+
from lightning.pytorch.demos.boring_classes import BoringModel
1312
from litlogger import LightningLogger
14-
from psutil import cpu_count
15-
from torch import nn, optim
16-
from torch.utils.data import DataLoader
17-
from torchvision.datasets import MNIST
18-
from torchvision.transforms import ToTensor
19-
20-
21-
class LitAutoEncoder(LightningModule):
22-
def __init__(self, lr=1e-3, inp_size=28):
23-
super().__init__()
2413

25-
self.encoder = nn.Sequential(nn.Linear(inp_size * inp_size, 64), nn.ReLU(), nn.Linear(64, 3))
26-
self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, inp_size * inp_size))
27-
self.lr = lr
28-
self.save_hyperparameters()
2914

15+
class DemoModel(BoringModel):
3016
def training_step(self, batch, batch_idx):
31-
x, y = batch
32-
x = x.view(x.size(0), -1)
33-
z = self.encoder(x)
34-
x_hat = self.decoder(z)
35-
loss = nn.functional.mse_loss(x_hat, x)
36-
# log metrics
37-
self.log("train_loss", loss)
38-
return loss
39-
40-
def configure_optimizers(self):
41-
return optim.Adam(self.parameters(), lr=self.lr)
17+
output = super().training_step(batch, batch_idx)
18+
self.log("train_loss", output["loss"])
19+
return output
4220

4321

4422
if __name__ == "__main__":
45-
# init the autoencoder
46-
autoencoder = LitAutoEncoder(lr=1e-3, inp_size=28)
47-
48-
# setup data
49-
train_loader = DataLoader(
50-
dataset=MNIST(os.getcwd(), download=True, transform=ToTensor()),
51-
batch_size=32,
52-
shuffle=True,
53-
num_workers=cpu_count(),
54-
persistent_workers=True,
55-
)
56-
5723
# configure the logger
5824
lit_logger = LightningLogger(log_model=True)
5925

6026
# pass logger to the Trainer
6127
trainer = Trainer(max_epochs=5, logger=lit_logger)
6228

6329
# train the model
64-
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
30+
trainer.fit(model=DemoModel())

0 commit comments

Comments
 (0)