diff --git a/README.md b/README.md index cc21f41..0cf48b0 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,7 @@ print(f"last: {vars(checkpoint_callback)}") upload_model(model=checkpoint_callback.last_model_path, name=MY_MODEL_NAME) ``` -To load the model, use the `load_model` function. +To load the model, use the `download_model` function. ```python from lightning import Trainer @@ -107,3 +107,68 @@ trainer.fit( ckpt_path=model_path, ) ``` + +You can also use model store together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to the cloud storage. + +```python +import os +import lightning as L +from psutil import cpu_count +from torch import optim, nn +from torch.utils.data import DataLoader +from torchvision.datasets import MNIST +from torchvision.transforms import ToTensor +from litlogger import LightningLogger + + +class LitAutoEncoder(L.LightningModule): + + def __init__(self, lr=1e-3, inp_size=28): + super().__init__() + + self.encoder = nn.Sequential( + nn.Linear(inp_size * inp_size, 64), nn.ReLU(), nn.Linear(64, 3) + ) + self.decoder = nn.Sequential( + nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, inp_size * inp_size) + ) + self.lr = lr + self.save_hyperparameters() + + def training_step(self, batch, batch_idx): + x, y = batch + x = x.view(x.size(0), -1) + z = self.encoder(x) + x_hat = self.decoder(z) + loss = nn.functional.mse_loss(x_hat, x) + # log metrics + self.log("train_loss", loss) + return loss + + def configure_optimizers(self): + optimizer = optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + +if __name__ == "__main__": + # init the autoencoder + autoencoder = LitAutoEncoder(lr=1e-3, inp_size=28) + + # setup data + train_loader = DataLoader( + dataset=MNIST(os.getcwd(), download=True, transform=ToTensor()), + batch_size=32, + shuffle=True, + num_workers=cpu_count(), + persistent_workers=True, + ) + + # configure the logger + lit_logger = LightningLogger(log_model=True) + + # pass logger to the Trainer + trainer = L.Trainer(max_epochs=5, logger=lit_logger) + + # train the model + trainer.fit(model=autoencoder, train_dataloaders=train_loader) +``` diff --git a/docs/source/conf.py b/docs/source/conf.py index 32d72e9..d3b6c4c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -386,13 +386,10 @@ def find_source(): # only run doctests marked with a ".. doctest::" directive doctest_test_doctest_blocks = "" doctest_global_setup = """ - -import importlib -import os -import torch - -import pytorch_lightning as pl -from pytorch_lightning import Trainer, LightningModule - """ coverage_skip_undoc_in_source = True + +linkcheck_ignore = [ + # ignore the following URLs + "https://github.com/gridai/lit-logger", +]