Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added examples/__init__.py
Empty file.
28 changes: 28 additions & 0 deletions examples/sample_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
from lightning import LightningModule
from torch.nn.functional import mse_loss


class LitAutoEncoder(LightningModule):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

def forward(self, x):
# in lightning, forward defines the prediction/inference actions
return self.encoder(x)

def training_step(self, batch, batch_idx):
# training_step defines the train loop. It is independent of forward
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = mse_loss(x_hat, x)
self.log("train_loss", loss)
return loss

def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=1e-4)
31 changes: 31 additions & 0 deletions examples/train_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch.utils.data as data
import torchvision as tv
from lightning import Callback, Trainer
from litmodels import upload_model
from sample_model import LitAutoEncoder


class UploadModelCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
# Get the best model path from the checkpoint callback
best_model_path = trainer.checkpoint_callback.best_model_path
if best_model_path:
print(f"Uploading model: {best_model_path}")
upload_model(path=best_model_path, name="jirka/kaggle/lit-auto-encoder-callback")


if __name__ == "__main__":
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()

trainer = Trainer(
max_epochs=2,
callbacks=[UploadModelCallback()],
)
trainer.fit(
autoencoder,
data.DataLoader(train, batch_size=256),
data.DataLoader(val, batch_size=256),
)
23 changes: 23 additions & 0 deletions examples/train_resume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch.utils.data as data
import torchvision as tv
from lightning import Trainer
from litmodels import download_model
from sample_model import LitAutoEncoder

if __name__ == "__main__":
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

model_path = download_model(name="jirka/kaggle/lit-auto-encoder-simple", download_dir="my_models")
print(f"model: {model_path}")
# autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path=model_path)

trainer = Trainer(
max_epochs=4,
)
trainer.fit(
LitAutoEncoder(),
data.DataLoader(train, batch_size=256),
data.DataLoader(val, batch_size=256),
ckpt_path=model_path,
)
33 changes: 33 additions & 0 deletions examples/train_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch.utils.data as data
import torchvision as tv
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from litmodels import upload_model
from sample_model import LitAutoEncoder

if __name__ == "__main__":
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
# Save the best model based on validation loss
checkpoint_callback = ModelCheckpoint(
monitor="val_loss", # Metric to monitor
save_top_k=1, # Only save the best model (use -1 to save all)
mode="min", # 'min' for loss, 'max' for accuracy
save_last=True, # Additionally save the last checkpoint
dirpath="my_checkpoints/", # Directory to save checkpoints
filename="{epoch:02d}-{val_loss:.2f}", # Custom checkpoint filename
)

trainer = Trainer(
max_epochs=2,
callbacks=[checkpoint_callback],
)
trainer.fit(
autoencoder,
data.DataLoader(train, batch_size=256),
data.DataLoader(val, batch_size=256),
)
print(f"last: {vars(checkpoint_callback)}")
upload_model(path=checkpoint_callback.last_model_path, name="jirka/kaggle/lit-auto-encoder-simple")
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ ignore-init-module-imports = true
"setup.py" = ["D100", "SIM115"]
"__about__.py" = ["D100"]
"__init__.py" = ["D100", "E402"]
"examples/**" = ["D"] # todo
"tests/**" = ["D"]

[tool.ruff.lint.pydocstyle]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# NOTE: once we add more dependcies, conside update dependabot to check for updates
# NOTE: once we add more dependencies, consider update dependabot to check for updates

lightning-sdk >=0.1.26
Loading