diff --git a/README.md b/README.md
index e0249fb..3614e97 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,6 @@
 from lightning import Trainer
 from litmodels import download_model
 from litmodels.demos import BoringModel
-
 # Define the model name - this should be unique to your model
 # The format is <organization>/<teamspace>/<model-name>:<version>
 MY_MODEL_NAME = "jirka/kaggle/lit-boring-model:latest"
@@ -92,6 +91,54 @@ trainer = Trainer(max_epochs=4)
 trainer.fit(LitModel(), ckpt_path=checkpoint_path)
 ```
 
+You can also enhance your training with a simple checkpointing callback that always saves the best model to cloud storage while training continues.
+This is especially handy for long training runs or interruptible machines, so you can always resume or recover from the best model.
+
+```python
+import os
+import torch.utils.data as data
+import torchvision as tv
+from lightning import Callback, Trainer
+from litmodels import upload_model
+from litmodels.demos import BoringModel
+
+# Define the model name - this should be unique to your model
+# The format is <organization>/<teamspace>/<model-name>
+MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-callback"
+
+
+class LitModel(BoringModel):
+    def training_step(self, batch, batch_idx: int):
+        loss = self.step(batch)
+        # logging the computed loss
+        self.log("train_loss", loss)
+        return {"loss": loss}
+
+
+class UploadModelCallback(Callback):
+    def on_train_epoch_end(self, trainer, pl_module):
+        # Get the best model path from the checkpoint callback
+        checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
+        if checkpoint_path and os.path.exists(checkpoint_path):
+            upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
+
+
+dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
+train, val = data.random_split(dataset, [55000, 5000])
+
+trainer = Trainer(
+    max_epochs=2,
+    callbacks=[UploadModelCallback()],
+)
+trainer.fit(
+    LitModel(),
+    data.DataLoader(train, batch_size=256),
+    data.DataLoader(val, batch_size=256),
+)
+```
+
+## Logging Models
+
 You can also use model store together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to the cloud storage.
 
 ```python