1- <div align =' center ' >
1+ from litmodels.integrations import LightningModelCheckpoint <div align =' center ' >
22
33# Effortless Model Management for Your Development ⚡
44
@@ -102,19 +102,10 @@ from litmodels.demos import BoringModel
102102# Define the model name - this should be unique to your model
103103MY_MODEL_NAME = " <organization>/<teamspace>/<model-name>"
104104
105-
106- class LitModel (BoringModel ):
107- def training_step (self , batch , batch_idx : int ):
108- loss = self .step(batch)
109- # logging the computed loss
110- self .log(" train_loss" , loss)
111- return {" loss" : loss}
112-
113-
114105# Configure Lightning Trainer
115106trainer = Trainer(max_epochs = 2 )
116107# Define the model and train it
117- trainer.fit(LitModel ())
108+ trainer.fit(BoringModel ())
118109
119110# Upload the best model to cloud storage
120111checkpoint_path = getattr (trainer.checkpoint_callback, " best_model_path" )
@@ -131,67 +122,39 @@ from litmodels.demos import BoringModel
131122# Define the model name - this should be unique to your model
132123MY_MODEL_NAME = " <organization>/<teamspace>/<model-name>:<model-version>"
133124
134-
135- class LitModel (BoringModel ):
136- def training_step (self , batch , batch_idx : int ):
137- loss = self .step(batch)
138- # logging the computed loss
139- self .log(" train_loss" , loss)
140- return {" loss" : loss}
141-
142-
143125# Load the model from cloud storage
144126checkpoint_path = download_model(name = MY_MODEL_NAME , download_dir = " my_models" )
145127print (f " model: { checkpoint_path} " )
146128
147129# Train the model with extended training period
148130trainer = Trainer(max_epochs = 4 )
149- trainer.fit(LitModel (), ckpt_path = checkpoint_path)
131+ trainer.fit(BoringModel (), ckpt_path = checkpoint_path)
150132```
151133
152134<details >
153- <summary>Advanced Checkpointing Workflow</summary>
153- <summary>Checkpointing Workflow with Lightning</summary>
154136
155- Enhance your training process with an automatic checkpointing callback that uploads the best model at the end of each epoch.
156- While the example uses PyTorch Lightning callbacks, similar workflows can be implemented in any training loop that produces checkpoints.
137+ Enhance your training process with an automatic checkpointing callback that uploads the model at the end of each epoch.
157138
158139``` python
159- import os
160140import torch.utils.data as data
161141import torchvision as tv
162- from lightning import Callback, Trainer
163- from litmodels import upload_model
142+ from lightning import Trainer
143+ from litmodels.integrations import LightningModelCheckpoint
164144from litmodels.demos import BoringModel
165145
166146# Define the model name - this should be unique to your model
167147MY_MODEL_NAME = " <organization>/<teamspace>/<model-name>"
168148
169-
170- class LitModel (BoringModel ):
171- def training_step (self , batch , batch_idx : int ):
172- loss = self .step(batch)
173- # logging the computed loss
174- self .log(" train_loss" , loss)
175- return {" loss" : loss}
176-
177-
178- class UploadModelCallback (Callback ):
179- def on_train_epoch_end (self , trainer , pl_module ):
180- # Get the best model path from the checkpoint callback
181- checkpoint_path = getattr (trainer.checkpoint_callback, " best_model_path" )
182- if checkpoint_path and os.path.exists(checkpoint_path):
183- upload_model(model = checkpoint_path, name = MY_MODEL_NAME )
184-
185-
186149dataset = tv.datasets.MNIST(" ." , download = True , transform = tv.transforms.ToTensor())
187150train, val = data.random_split(dataset, [55000 , 5000 ])
188151
189152trainer = Trainer(
190153 max_epochs = 2 ,
191- callbacks = [UploadModelCallback( )],
154+ callbacks = [LightningModelCheckpoint( model_name = MY_MODEL_NAME )],
192155)
193156trainer.fit(
194- LitModel (),
157+ BoringModel (),
195158 data.DataLoader(train, batch_size = 256 ),
196159 data.DataLoader(val, batch_size = 256 ),
197160)