diff --git a/README.md b/README.md index 72556b7..0589bb9 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -
+from litmodels.integrations import LightningModelCheckpoint
# Effortless Model Management for Your Development ⚡ @@ -102,19 +102,10 @@ from litmodels.demos import BoringModel # Define the model name - this should be unique to your model MY_MODEL_NAME = "//" - -class LitModel(BoringModel): - def training_step(self, batch, batch_idx: int): - loss = self.step(batch) - # logging the computed loss - self.log("train_loss", loss) - return {"loss": loss} - - # Configure Lightning Trainer trainer = Trainer(max_epochs=2) # Define the model and train it -trainer.fit(LitModel()) +trainer.fit(BoringModel()) # Upload the best model to cloud storage checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path") @@ -131,67 +122,39 @@ from litmodels.demos import BoringModel # Define the model name - this should be unique to your model MY_MODEL_NAME = "//:" - -class LitModel(BoringModel): - def training_step(self, batch, batch_idx: int): - loss = self.step(batch) - # logging the computed loss - self.log("train_loss", loss) - return {"loss": loss} - - # Load the model from cloud storage checkpoint_path = download_model(name=MY_MODEL_NAME, download_dir="my_models") print(f"model: {checkpoint_path}") # Train the model with extended training period trainer = Trainer(max_epochs=4) -trainer.fit(LitModel(), ckpt_path=checkpoint_path) +trainer.fit(BoringModel(), ckpt_path=checkpoint_path) ```
- Advanced Checkpointing Workflow + Checkpointing Workflow with Lightning -Enhance your training process with an automatic checkpointing callback that uploads the best model at the end of each epoch. -While the example uses PyTorch Lightning callbacks, similar workflows can be implemented in any training loop that produces checkpoints. +Enhance your training process with an automatic checkpointing callback that uploads the model at the end of each epoch. ```python -import os import torch.utils.data as data import torchvision as tv -from lightning import Callback, Trainer -from litmodels import upload_model +from lightning import Trainer +from litmodels.integrations import LightningModelCheckpoint from litmodels.demos import BoringModel # Define the model name - this should be unique to your model MY_MODEL_NAME = "//" - -class LitModel(BoringModel): - def training_step(self, batch, batch_idx: int): - loss = self.step(batch) - # logging the computed loss - self.log("train_loss", loss) - return {"loss": loss} - - -class UploadModelCallback(Callback): - def on_train_epoch_end(self, trainer, pl_module): - # Get the best model path from the checkpoint callback - checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path") - if checkpoint_path and os.path.exists(checkpoint_path): - upload_model(model=checkpoint_path, name=MY_MODEL_NAME) - - dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor()) train, val = data.random_split(dataset, [55000, 5000]) trainer = Trainer( max_epochs=2, - callbacks=[UploadModelCallback()], + callbacks=[LightningModelCheckpoint(model_name=MY_MODEL_NAME)], ) trainer.fit( - LitModel(), + BoringModel(), data.DataLoader(train, batch_size=256), data.DataLoader(val, batch_size=256), )