diff --git a/README.md b/README.md
index 72556b7..0589bb9 100644
--- a/README.md
+++ b/README.md
@@ -102,19 +102,10 @@ from litmodels.demos import BoringModel
# Define the model name - this should be unique to your model
MY_MODEL_NAME = "//"
-
-class LitModel(BoringModel):
- def training_step(self, batch, batch_idx: int):
- loss = self.step(batch)
- # logging the computed loss
- self.log("train_loss", loss)
- return {"loss": loss}
-
-
# Configure Lightning Trainer
trainer = Trainer(max_epochs=2)
# Define the model and train it
-trainer.fit(LitModel())
+trainer.fit(BoringModel())
# Upload the best model to cloud storage
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
@@ -131,67 +122,39 @@ from litmodels.demos import BoringModel
# Define the model name - this should be unique to your model
MY_MODEL_NAME = "//:"
-
-class LitModel(BoringModel):
- def training_step(self, batch, batch_idx: int):
- loss = self.step(batch)
- # logging the computed loss
- self.log("train_loss", loss)
- return {"loss": loss}
-
-
# Load the model from cloud storage
checkpoint_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
print(f"model: {checkpoint_path}")
# Train the model with extended training period
trainer = Trainer(max_epochs=4)
-trainer.fit(LitModel(), ckpt_path=checkpoint_path)
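+# Resume training from the downloaded checkpoint and continue for the remaining epochs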
+trainer.fit(BoringModel(), ckpt_path=checkpoint_path)
```
- Advanced Checkpointing Workflow
+ Checkpointing Workflow with Lightning
-Enhance your training process with an automatic checkpointing callback that uploads the best model at the end of each epoch.
-While the example uses PyTorch Lightning callbacks, similar workflows can be implemented in any training loop that produces checkpoints.
+Enhance your training process with an automatic checkpointing callback that uploads the model to cloud storage at the end of each epoch.
```python
-import os
import torch.utils.data as data
import torchvision as tv
-from lightning import Callback, Trainer
-from litmodels import upload_model
+from lightning import Trainer
+from litmodels.integrations import LightningModelCheckpoint
from litmodels.demos import BoringModel
# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"
-
-class LitModel(BoringModel):
- def training_step(self, batch, batch_idx: int):
- loss = self.step(batch)
- # logging the computed loss
- self.log("train_loss", loss)
- return {"loss": loss}
-
-
-class UploadModelCallback(Callback):
- def on_train_epoch_end(self, trainer, pl_module):
- # Get the best model path from the checkpoint callback
- checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
- if checkpoint_path and os.path.exists(checkpoint_path):
- upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
-
-
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])
trainer = Trainer(
max_epochs=2,
- callbacks=[UploadModelCallback()],
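+ # the callback uploads the checkpoint to cloud storage at the end of each epoch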
+ callbacks=[LightningModelCheckpoint(model_name=MY_MODEL_NAME)],
)
trainer.fit(
- LitModel(),
+ BoringModel(),
data.DataLoader(train, batch_size=256),
data.DataLoader(val, batch_size=256),
)