Skip to content

Commit 36d8953

Browse files
committed
simplify examples in readme
1 parent 2a29b74 commit 36d8953

File tree

1 file changed

+9
-46
lines changed

1 file changed

+9
-46
lines changed

README.md

Lines changed: 9 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<div align='center'>
1+
from litmodels.integrations import LightningModelCheckpoint<div align='center'>
22

33
# Effortless Model Management for Your Development ⚡
44

@@ -102,19 +102,10 @@ from litmodels.demos import BoringModel
102102
# Define the model name - this should be unique to your model
103103
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"
104104

105-
106-
class LitModel(BoringModel):
107-
def training_step(self, batch, batch_idx: int):
108-
loss = self.step(batch)
109-
# logging the computed loss
110-
self.log("train_loss", loss)
111-
return {"loss": loss}
112-
113-
114105
# Configure Lightning Trainer
115106
trainer = Trainer(max_epochs=2)
116107
# Define the model and train it
117-
trainer.fit(LitModel())
108+
trainer.fit(BoringModel())
118109

119110
# Upload the best model to cloud storage
120111
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
@@ -131,67 +122,39 @@ from litmodels.demos import BoringModel
131122
# Define the model name - this should be unique to your model
132123
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>:<model-version>"
133124

134-
135-
class LitModel(BoringModel):
136-
def training_step(self, batch, batch_idx: int):
137-
loss = self.step(batch)
138-
# logging the computed loss
139-
self.log("train_loss", loss)
140-
return {"loss": loss}
141-
142-
143125
# Load the model from cloud storage
144126
checkpoint_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
145127
print(f"model: {checkpoint_path}")
146128

147129
# Train the model with extended training period
148130
trainer = Trainer(max_epochs=4)
149-
trainer.fit(LitModel(), ckpt_path=checkpoint_path)
131+
trainer.fit(BoringModel(), ckpt_path=checkpoint_path)
150132
```
151133

152134
<details>
153-
<summary>Advanced Checkpointing Workflow</summary>
135+
<summary>Checkpointing Workflow with Lightning</summary>
154136

155-
Enhance your training process with an automatic checkpointing callback that uploads the best model at the end of each epoch.
156-
While the example uses PyTorch Lightning callbacks, similar workflows can be implemented in any training loop that produces checkpoints.
137+
Enhance your training process with an automatic checkpointing callback that uploads the model at the end of each epoch.
157138

158139
```python
159-
import os
160140
import torch.utils.data as data
161141
import torchvision as tv
162-
from lightning import Callback, Trainer
163-
from litmodels import upload_model
142+
from lightning import Trainer
143+
from litmodels.integrations import LightningModelCheckpoint
164144
from litmodels.demos import BoringModel
165145

166146
# Define the model name - this should be unique to your model
167147
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"
168148

169-
170-
class LitModel(BoringModel):
171-
def training_step(self, batch, batch_idx: int):
172-
loss = self.step(batch)
173-
# logging the computed loss
174-
self.log("train_loss", loss)
175-
return {"loss": loss}
176-
177-
178-
class UploadModelCallback(Callback):
179-
def on_train_epoch_end(self, trainer, pl_module):
180-
# Get the best model path from the checkpoint callback
181-
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
182-
if checkpoint_path and os.path.exists(checkpoint_path):
183-
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
184-
185-
186149
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
187150
train, val = data.random_split(dataset, [55000, 5000])
188151

189152
trainer = Trainer(
190153
max_epochs=2,
191-
callbacks=[UploadModelCallback()],
154+
callbacks=[LightningModelCheckpoint(model_name=MY_MODEL_NAME)],
192155
)
193156
trainer.fit(
194-
LitModel(),
157+
BoringModel(),
195158
data.DataLoader(train, batch_size=256),
196159
data.DataLoader(val, batch_size=256),
197160
)

0 commit comments

Comments
 (0)