Skip to content

Commit 4ecb9a7

Browse files
authored
readme: add callback example (#24)
1 parent 9601748 commit 4ecb9a7

File tree

1 file changed

+48
-1
lines changed

1 file changed

+48
-1
lines changed

README.md

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ from lightning import Trainer
6969
from litmodels import download_model
7070
from litmodels.demos import BoringModel
7171

72-
7372
# Define the model name - this should be unique to your model
7473
# The format is <organization>/<teamspace>/<model-name>:<model-version>
7574
MY_MODEL_NAME = "jirka/kaggle/lit-boring-model:latest"
@@ -92,6 +91,54 @@ trainer = Trainer(max_epochs=4)
9291
trainer.fit(LitModel(), ckpt_path=checkpoint_path)
9392
```
9493

94+
You can also enhance your training with a simple Checkpointing callback which would always save the best model to the cloud storage and continue training.
95+
This can would be handy especially with long trainings or using interruptible machines so you would always resume/recover from the best model.
96+
97+
```python
98+
import os
99+
import torch.utils.data as data
100+
import torchvision as tv
101+
from lightning import Callback, Trainer
102+
from litmodels import upload_model
103+
from litmodels.demos import BoringModel
104+
105+
# Define the model name - this should be unique to your model
106+
# The format is <organization>/<teamspace>/<model-name>
107+
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-callback"
108+
109+
110+
class LitModel(BoringModel):
111+
def training_step(self, batch, batch_idx: int):
112+
loss = self.step(batch)
113+
# logging the computed loss
114+
self.log("train_loss", loss)
115+
return {"loss": loss}
116+
117+
118+
class UploadModelCallback(Callback):
119+
def on_train_epoch_end(self, trainer, pl_module):
120+
# Get the best model path from the checkpoint callback
121+
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
122+
if checkpoint_path and os.path.exists(checkpoint_path):
123+
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
124+
125+
126+
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
127+
train, val = data.random_split(dataset, [55000, 5000])
128+
129+
trainer = Trainer(
130+
max_epochs=2,
131+
callbacks=[UploadModelCallback()],
132+
)
133+
trainer.fit(
134+
LitModel(),
135+
data.DataLoader(train, batch_size=256),
136+
data.DataLoader(val, batch_size=256),
137+
)
138+
```
139+
140+
## Logging Models
141+
95142
You can also use model store together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to the cloud storage.
96143

97144
```python

0 commit comments

Comments
 (0)