Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 9 additions & 46 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<div align='center'>
from litmodels.integrations import LightningModelCheckpoint<div align='center'>

# Effortless Model Management for Your Development ⚡

Expand Down Expand Up @@ -102,19 +102,10 @@ from litmodels.demos import BoringModel
# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"


class LitModel(BoringModel):
def training_step(self, batch, batch_idx: int):
loss = self.step(batch)
# logging the computed loss
self.log("train_loss", loss)
return {"loss": loss}


# Configure Lightning Trainer
trainer = Trainer(max_epochs=2)
# Define the model and train it
trainer.fit(LitModel())
trainer.fit(BoringModel())

# Upload the best model to cloud storage
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
Expand All @@ -131,67 +122,39 @@ from litmodels.demos import BoringModel
# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>:<model-version>"


class LitModel(BoringModel):
def training_step(self, batch, batch_idx: int):
loss = self.step(batch)
# logging the computed loss
self.log("train_loss", loss)
return {"loss": loss}


# Load the model from cloud storage
checkpoint_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
print(f"model: {checkpoint_path}")

# Train the model with extended training period
trainer = Trainer(max_epochs=4)
trainer.fit(LitModel(), ckpt_path=checkpoint_path)
trainer.fit(BoringModel(), ckpt_path=checkpoint_path)
```

<details>
<summary>Advanced Checkpointing Workflow</summary>
<summary>Checkpointing Workflow with Lightning</summary>

Enhance your training process with an automatic checkpointing callback that uploads the best model at the end of each epoch.
While the example uses PyTorch Lightning callbacks, similar workflows can be implemented in any training loop that produces checkpoints.
Enhance your training process with an automatic checkpointing callback that uploads the model at the end of each epoch.

```python
import os
import torch.utils.data as data
import torchvision as tv
from lightning import Callback, Trainer
from litmodels import upload_model
from lightning import Trainer
from litmodels.integrations import LightningModelCheckpoint
from litmodels.demos import BoringModel

# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"


class LitModel(BoringModel):
def training_step(self, batch, batch_idx: int):
loss = self.step(batch)
# logging the computed loss
self.log("train_loss", loss)
return {"loss": loss}


class UploadModelCallback(Callback):
def on_train_epoch_end(self, trainer, pl_module):
# Get the best model path from the checkpoint callback
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
if checkpoint_path and os.path.exists(checkpoint_path):
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)


dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

trainer = Trainer(
max_epochs=2,
callbacks=[UploadModelCallback()],
callbacks=[LightningModelCheckpoint(model_name=MY_MODEL_NAME)],
)
trainer.fit(
LitModel(),
BoringModel(),
data.DataLoader(train, batch_size=256),
data.DataLoader(val, batch_size=256),
)
Expand Down
Loading