diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/sample_model.py b/examples/sample_model.py new file mode 100644 index 0000000..664a284 --- /dev/null +++ b/examples/sample_model.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn +from lightning import LightningModule +from torch.nn.functional import mse_loss + + +class LitAutoEncoder(LightningModule): + def __init__(self): + super().__init__() + self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)) + self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)) + + def forward(self, x): + # in lightning, forward defines the prediction/inference actions + return self.encoder(x) + + def training_step(self, batch, batch_idx): + # training_step defines the train loop. It is independent of forward + x, _ = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + x_hat = self.decoder(z) + loss = mse_loss(x_hat, x) + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.AdamW(self.parameters(), lr=1e-4) diff --git a/examples/train_callback.py b/examples/train_callback.py new file mode 100644 index 0000000..fe8fae9 --- /dev/null +++ b/examples/train_callback.py @@ -0,0 +1,31 @@ +import torch.utils.data as data +import torchvision as tv +from lightning import Callback, Trainer +from litmodels import upload_model +from sample_model import LitAutoEncoder + + +class UploadModelCallback(Callback): + def on_train_epoch_end(self, trainer, pl_module): + # Get the best model path from the checkpoint callback + best_model_path = trainer.checkpoint_callback.best_model_path + if best_model_path: + print(f"Uploading model: {best_model_path}") + upload_model(path=best_model_path, name="jirka/kaggle/lit-auto-encoder-callback") + + +if __name__ == "__main__": + dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor()) + train, val = 
data.random_split(dataset, [55000, 5000]) + + autoencoder = LitAutoEncoder() + + trainer = Trainer( + max_epochs=2, + callbacks=[UploadModelCallback()], + ) + trainer.fit( + autoencoder, + data.DataLoader(train, batch_size=256), + data.DataLoader(val, batch_size=256), + ) diff --git a/examples/train_resume.py b/examples/train_resume.py new file mode 100644 index 0000000..0ab1176 --- /dev/null +++ b/examples/train_resume.py @@ -0,0 +1,23 @@ +import torch.utils.data as data +import torchvision as tv +from lightning import Trainer +from litmodels import download_model +from sample_model import LitAutoEncoder + +if __name__ == "__main__": + dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor()) + train, val = data.random_split(dataset, [55000, 5000]) + + model_path = download_model(name="jirka/kaggle/lit-auto-encoder-simple", download_dir="my_models") + print(f"model: {model_path}") + # autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path=model_path) + + trainer = Trainer( + max_epochs=4, + ) + trainer.fit( + LitAutoEncoder(), + data.DataLoader(train, batch_size=256), + data.DataLoader(val, batch_size=256), + ckpt_path=model_path, + ) diff --git a/examples/train_simple.py b/examples/train_simple.py new file mode 100644 index 0000000..489d832 --- /dev/null +++ b/examples/train_simple.py @@ -0,0 +1,33 @@ +import torch.utils.data as data +import torchvision as tv +from lightning import Trainer +from lightning.pytorch.callbacks import ModelCheckpoint +from litmodels import upload_model +from sample_model import LitAutoEncoder + +if __name__ == "__main__": + dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor()) + train, val = data.random_split(dataset, [55000, 5000]) + + autoencoder = LitAutoEncoder() + # Save the best model based on validation loss + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", # Metric to monitor + save_top_k=1, # Only save the best model (use -1 to save all) + 
mode="min",  # 'min' for loss, 'max' for accuracy +        save_last=True,  # Additionally save the last checkpoint +        dirpath="my_checkpoints/",  # Directory to save checkpoints +        filename="{epoch:02d}-{val_loss:.2f}",  # Custom checkpoint filename +    ) + +    trainer = Trainer( +        max_epochs=2, +        callbacks=[checkpoint_callback], +    ) +    trainer.fit( +        autoencoder, +        data.DataLoader(train, batch_size=256), +        data.DataLoader(val, batch_size=256), +    ) +    print(f"last: {vars(checkpoint_callback)}") +    upload_model(path=checkpoint_callback.last_model_path, name="jirka/kaggle/lit-auto-encoder-simple") diff --git a/pyproject.toml b/pyproject.toml index 7011c49..601c859 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,7 @@ ignore-init-module-imports = true "setup.py" = ["D100", "SIM115"] "__about__.py" = ["D100"] "__init__.py" = ["D100", "E402"] +"examples/**" = ["D"] # todo "tests/**" = ["D"] [tool.ruff.lint.pydocstyle] diff --git a/requirements.txt b/requirements.txt index 3f174f7..a137bd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -# NOTE: once we add more dependcies, conside update dependabot to check for updates +# NOTE: once we add more dependencies, consider updating dependabot to check for updates lightning-sdk >=0.1.26