Skip to content

Commit 59ef9b7

Browse files
add logger example (#20)
Co-authored-by: Andrei-Aksionov <[email protected]>
1 parent fc9c21e commit 59ef9b7

File tree

2 files changed

+71
-9
lines changed

2 files changed

+71
-9
lines changed

README.md

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ print(f"last: {vars(checkpoint_callback)}")
7676
upload_model(model=checkpoint_callback.last_model_path, name=MY_MODEL_NAME)
7777
```
7878

79-
To load the model, use the `load_model` function.
79+
To load the model, use the `download_model` function.
8080

8181
```python
8282
from lightning import Trainer
@@ -107,3 +107,68 @@ trainer.fit(
107107
ckpt_path=model_path,
108108
)
109109
```
110+
111+
You can also use model store together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to the cloud storage.
112+
113+
```python
114+
import os
115+
import lightning as L
116+
from psutil import cpu_count
117+
from torch import optim, nn
118+
from torch.utils.data import DataLoader
119+
from torchvision.datasets import MNIST
120+
from torchvision.transforms import ToTensor
121+
from litlogger import LightningLogger
122+
123+
124+
class LitAutoEncoder(L.LightningModule):
125+
126+
def __init__(self, lr=1e-3, inp_size=28):
127+
super().__init__()
128+
129+
self.encoder = nn.Sequential(
130+
nn.Linear(inp_size * inp_size, 64), nn.ReLU(), nn.Linear(64, 3)
131+
)
132+
self.decoder = nn.Sequential(
133+
nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, inp_size * inp_size)
134+
)
135+
self.lr = lr
136+
self.save_hyperparameters()
137+
138+
def training_step(self, batch, batch_idx):
139+
x, y = batch
140+
x = x.view(x.size(0), -1)
141+
z = self.encoder(x)
142+
x_hat = self.decoder(z)
143+
loss = nn.functional.mse_loss(x_hat, x)
144+
# log metrics
145+
self.log("train_loss", loss)
146+
return loss
147+
148+
def configure_optimizers(self):
149+
optimizer = optim.Adam(self.parameters(), lr=self.lr)
150+
return optimizer
151+
152+
153+
if __name__ == "__main__":
154+
# init the autoencoder
155+
autoencoder = LitAutoEncoder(lr=1e-3, inp_size=28)
156+
157+
# setup data
158+
train_loader = DataLoader(
159+
dataset=MNIST(os.getcwd(), download=True, transform=ToTensor()),
160+
batch_size=32,
161+
shuffle=True,
162+
num_workers=cpu_count(),
163+
persistent_workers=True,
164+
)
165+
166+
# configure the logger
167+
lit_logger = LightningLogger(log_model=True)
168+
169+
# pass logger to the Trainer
170+
trainer = L.Trainer(max_epochs=5, logger=lit_logger)
171+
172+
# train the model
173+
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
174+
```

docs/source/conf.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -386,13 +386,10 @@ def find_source():
386386
# only run doctests marked with a ".. doctest::" directive
387387
doctest_test_doctest_blocks = ""
388388
doctest_global_setup = """
389-
390-
import importlib
391-
import os
392-
import torch
393-
394-
import pytorch_lightning as pl
395-
from pytorch_lightning import Trainer, LightningModule
396-
397389
"""
398390
coverage_skip_undoc_in_source = True
391+
392+
linkcheck_ignore = [
393+
# ignore the following URLs
394+
"https://github.com/gridai/lit-logger",
395+
]

0 commit comments

Comments
 (0)