Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ from lightning import Trainer
from litmodels import download_model
from litmodels.demos import BoringModel


# Define the model name - this should be unique to your model
# The format is <organization>/<teamspace>/<model-name>:<model-version>
MY_MODEL_NAME = "jirka/kaggle/lit-boring-model:latest"
Expand All @@ -92,6 +91,54 @@ trainer = Trainer(max_epochs=4)
trainer.fit(LitModel(), ckpt_path=checkpoint_path)
```

You can also enhance your training with a simple checkpointing callback that always saves the best model to cloud storage and then continues training.
This is especially handy for long training runs or when using interruptible machines, because you can always resume/recover from the best model.

```python
import os
import torch.utils.data as data
import torchvision as tv
from lightning import Callback, Trainer
from litmodels import upload_model
from litmodels.demos import BoringModel

# Name under which the checkpoints get uploaded - this should be unique to your model
# The format is <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-callback"


class LitModel(BoringModel):
    """Demo model whose training step logs the loss it computes."""

    def training_step(self, batch, batch_idx: int):
        """Run one training step on ``batch`` and report the loss as ``train_loss``."""
        output = {"loss": self.step(batch)}
        # make the loss visible to the logger under the "train_loss" key
        self.log("train_loss", output["loss"])
        return output


class UploadModelCallback(Callback):
    """Upload the trainer's best checkpoint at the end of every training epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        """Upload the current best checkpoint if one is available on disk.

        Args:
            trainer: The running trainer; its ``checkpoint_callback`` is asked
                for ``best_model_path``.
            pl_module: The module being trained (unused here).
        """
        # `trainer.checkpoint_callback` can be None (e.g. checkpointing disabled);
        # default to None so the guard below skips the upload instead of the
        # attribute lookup raising AttributeError in the middle of training.
        checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path", None)
        # Skip when no best checkpoint has been recorded or the file is missing.
        if checkpoint_path and os.path.exists(checkpoint_path):
            upload_model(model=checkpoint_path, name=MY_MODEL_NAME)


# MNIST images as tensors, stored in (and downloaded to) the current directory
mnist = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
# 55,000 samples for training, 5,000 for validation
train_split, val_split = data.random_split(mnist, [55000, 5000])

# Two epochs with the upload callback attached
trainer = Trainer(max_epochs=2, callbacks=[UploadModelCallback()])

model = LitModel()
train_loader = data.DataLoader(train_split, batch_size=256)
val_loader = data.DataLoader(val_split, batch_size=256)
trainer.fit(model, train_loader, val_loader)
```

## Logging Models

You can also use the model store together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to cloud storage.

```python
Expand Down
Loading