Skip to content

Commit e30bbc3

Browse files
committed
add logger example
1 parent 20effb2 commit e30bbc3

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

README.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,68 @@ trainer.fit(
107107
ckpt_path=model_path,
108108
)
109109
```
110+
111+
You can also use model store sto together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to the cloud storage.
112+
113+
```python
114+
import os
115+
import lightning as L
116+
from psutil import cpu_count
117+
from torch import optim, nn
118+
from torch.utils.data import DataLoader
119+
from torchvision.datasets import MNIST
120+
from torchvision.transforms import ToTensor
121+
from litlogger import LightningLogger
122+
123+
124+
class LitAutoEncoder(L.LightningModule):
125+
126+
def __init__(self, lr=1e-3, inp_size=28):
127+
super().__init__()
128+
129+
self.encoder = nn.Sequential(
130+
nn.Linear(inp_size * inp_size, 64), nn.ReLU(), nn.Linear(64, 3)
131+
)
132+
self.decoder = nn.Sequential(
133+
nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, inp_size * inp_size)
134+
)
135+
self.lr = lr
136+
self.save_hyperparameters()
137+
138+
def training_step(self, batch, batch_idx):
139+
x, y = batch
140+
x = x.view(x.size(0), -1)
141+
z = self.encoder(x)
142+
x_hat = self.decoder(z)
143+
loss = nn.functional.mse_loss(x_hat, x)
144+
# log metrics
145+
self.log("train_loss", loss)
146+
return loss
147+
148+
def configure_optimizers(self):
149+
optimizer = optim.Adam(self.parameters(), lr=self.lr)
150+
return optimizer
151+
152+
153+
if __name__ == "__main__":
154+
# init the autoencoder
155+
autoencoder = LitAutoEncoder(lr=1e-3, inp_size=28)
156+
157+
# setup data
158+
train_loader = DataLoader(
159+
dataset=MNIST(os.getcwd(), download=True, transform=ToTensor()),
160+
batch_size=32,
161+
shuffle=True,
162+
num_workers=cpu_count(),
163+
persistent_workers=True,
164+
)
165+
166+
# configure the logger
167+
lit_logger = LightningLogger(log_model=True)
168+
169+
# pass logger to the Trainer
170+
trainer = L.Trainer(max_epochs=5, logger=lit_logger)
171+
172+
# train the model
173+
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
174+
```

0 commit comments

Comments
 (0)