Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ print(f"last: {vars(checkpoint_callback)}")
upload_model(model=checkpoint_callback.last_model_path, name=MY_MODEL_NAME)
```

To load the model, use the `load_model` function.
To load the model, use the `download_model` function.

```python
from lightning import Trainer
Expand Down Expand Up @@ -107,3 +107,68 @@ trainer.fit(
ckpt_path=model_path,
)
```

You can also use model store together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to the cloud storage.

```python
import os
import lightning as L
from psutil import cpu_count
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from litlogger import LightningLogger


class LitAutoEncoder(L.LightningModule):

def __init__(self, lr=1e-3, inp_size=28):
super().__init__()

self.encoder = nn.Sequential(
nn.Linear(inp_size * inp_size, 64), nn.ReLU(), nn.Linear(64, 3)
)
self.decoder = nn.Sequential(
nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, inp_size * inp_size)
)
self.lr = lr
self.save_hyperparameters()

def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = nn.functional.mse_loss(x_hat, x)
# log metrics
self.log("train_loss", loss)
return loss

def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.lr)
return optimizer


if __name__ == "__main__":
# init the autoencoder
autoencoder = LitAutoEncoder(lr=1e-3, inp_size=28)

# setup data
train_loader = DataLoader(
dataset=MNIST(os.getcwd(), download=True, transform=ToTensor()),
batch_size=32,
shuffle=True,
num_workers=cpu_count(),
persistent_workers=True,
)

# configure the logger
lit_logger = LightningLogger(log_model=True)

# pass logger to the Trainer
trainer = L.Trainer(max_epochs=5, logger=lit_logger)

# train the model
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
```
13 changes: 5 additions & 8 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,13 +386,10 @@ def find_source():
# only run doctests marked with a ".. doctest::" directive
doctest_test_doctest_blocks = ""
doctest_global_setup = """

import importlib
import os
import torch

import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule

"""
coverage_skip_undoc_in_source = True

linkcheck_ignore = [
# ignore the following URLs
"https://github.com/gridai/lit-logger",
]
Loading