@@ -76,7 +76,7 @@ print(f"last: {vars(checkpoint_callback)}")
7676upload_model(model = checkpoint_callback.last_model_path, name = MY_MODEL_NAME )
7777```
7878
79- To load the model, use the ` load_model ` function.
79+ To load the model, use the ` download_model ` function.
8080
8181``` python
8282from lightning import Trainer
@@ -107,3 +107,68 @@ trainer.fit(
107107 ckpt_path = model_path,
108108)
109109```
110+
111+ You can also use model store together with [ LitLogger] ( https://github.com/gridai/lit-logger ) to log your model to the cloud storage.
112+
113+ ``` python
114+ import os
115+ import lightning as L
116+ from psutil import cpu_count
117+ from torch import optim, nn
118+ from torch.utils.data import DataLoader
119+ from torchvision.datasets import MNIST
120+ from torchvision.transforms import ToTensor
121+ from litlogger import LightningLogger
122+
123+
124+ class LitAutoEncoder (L .LightningModule ):
125+
126+ def __init__ (self , lr = 1e-3 , inp_size = 28 ):
127+ super ().__init__ ()
128+
129+ self .encoder = nn.Sequential(
130+ nn.Linear(inp_size * inp_size, 64 ), nn.ReLU(), nn.Linear(64 , 3 )
131+ )
132+ self .decoder = nn.Sequential(
133+ nn.Linear(3 , 64 ), nn.ReLU(), nn.Linear(64 , inp_size * inp_size)
134+ )
135+ self .lr = lr
136+ self .save_hyperparameters()
137+
138+ def training_step (self , batch , batch_idx ):
139+ x, y = batch
140+ x = x.view(x.size(0 ), - 1 )
141+ z = self .encoder(x)
142+ x_hat = self .decoder(z)
143+ loss = nn.functional.mse_loss(x_hat, x)
144+ # log metrics
145+ self .log(" train_loss" , loss)
146+ return loss
147+
148+ def configure_optimizers (self ):
149+ optimizer = optim.Adam(self .parameters(), lr = self .lr)
150+ return optimizer
151+
152+
153+ if __name__ == " __main__" :
154+ # init the autoencoder
155+ autoencoder = LitAutoEncoder(lr = 1e-3 , inp_size = 28 )
156+
157+ # setup data
158+ train_loader = DataLoader(
159+ dataset = MNIST(os.getcwd(), download = True , transform = ToTensor()),
160+ batch_size = 32 ,
161+ shuffle = True ,
162+ num_workers = cpu_count(),
163+ persistent_workers = True ,
164+ )
165+
166+ # configure the logger
167+ lit_logger = LightningLogger(log_model = True )
168+
169+ # pass logger to the Trainer
170+ trainer = L.Trainer(max_epochs = 5 , logger = lit_logger)
171+
172+ # train the model
173+ trainer.fit(model = autoencoder, train_dataloaders = train_loader)
174+ ```
0 commit comments