|
1 | 1 | # Lightning Models |
2 | 2 |
|
3 | | -...TBD... |
| 3 | +This package provides utilities for saving and loading machine learning models using PyTorch Lightning. It aims to simplify the process of managing model checkpoints, making it easier to save, load, and share models. |
| 4 | + |
| 5 | +## Features |
| 6 | + |
| 7 | +- **Save Models**: Easily save your trained models to cloud storage. |
| 8 | +- **Load Models**: Load pre-trained models for inference or further training. |
| 9 | +- **Checkpoint Management**: Manage multiple checkpoints with ease. |
| 10 | +- **Cloud Integration**: Support for saving and loading models from cloud storage services. |
4 | 11 |
|
5 | 12 | [](https://lightning.ai/) |
6 | 13 | [](https://github.com/Lightning-AI/models/actions/workflows/ci-testing.yml) |
7 | 14 | [](https://github.com/Lightning-AI/models/actions/workflows/ci-checks.yml) |
8 | 15 | [](https://models.readthedocs.io/en/latest/?badge=latest) |
9 | 16 | [](https://results.pre-commit.ci/latest/github/Lightning-AI/models/main?badge_token=mqheL1-cTn-280Vx4cJUdg) |
| 17 | + |
| 18 | +## Installation |
| 19 | + |
| 20 | +To install the package, you can use `pip` from [Test PyPI](https://test.pypi.org/project/litmodels/): |
| 21 | + |
| 22 | +```bash |
| 23 | +pip install "litmodels==0.X.Y" --extra-index-url="https://test.pypi.org/simple/" |
| 24 | +``` |
| 25 | + |
| 26 | +Or installing from source: |
| 27 | + |
| 28 | +```bash |
| 29 | +pip install https://github.com/Lightning-AI/models/archive/refs/heads/main.zip |
| 30 | +``` |
| 31 | + |
| 32 | +## Usage |
| 33 | + |
| 34 | +Here's a simple example of how to save and load a model using `litmodels`. First, you need to train a model using PyTorch Lightning. Then, you can save the model using the `upload_model` function. |
| 35 | + |
| 36 | +```python |
| 37 | +from lightning import Trainer |
| 38 | +from lightning.pytorch.callbacks import ModelCheckpoint |
| 39 | +from litmodels import upload_model |
| 40 | +from litmodels.demos import BoringModel |
| 41 | + |
| 42 | + |
| 43 | +class LitModel(BoringModel): |
| 44 | + def training_step(self, batch, batch_idx: int): |
| 45 | + loss = self.step(batch) |
| 46 | + # logging the computed loss |
| 47 | + self.log("train_loss", loss) |
| 48 | + return {"loss": loss} |
| 49 | + |
| 50 | + |
| 51 | +# Define the model name - this should be unique to your model |
| 52 | +# The format is <organization>/<teamspace>/<model-name> |
| 53 | +MY_MODEL_NAME = "jirka/kaggle/lit-boring-model" |
| 54 | + |
| 55 | +# Define the model |
| 56 | +model = LitModel() |
| 57 | +# Save the best model based on validation loss |
| 58 | +checkpoint_callback = ModelCheckpoint( |
| 59 | + monitor="train_loss", # Metric to monitor |
| 60 | + save_top_k=1, # Only save the best model (use -1 to save all) |
| 61 | + mode="min", # 'min' for loss, 'max' for accuracy |
| 62 | + save_last=True, # Additionally save the last checkpoint |
| 63 | + dirpath="my_checkpoints/", # Directory to save checkpoints |
| 64 | + filename="{epoch:02d}-{val_loss:.2f}", # Custom checkpoint filename |
| 65 | +) |
| 66 | + |
| 67 | +# Train the model |
| 68 | +trainer = Trainer( |
| 69 | + max_epochs=2, |
| 70 | + callbacks=[checkpoint_callback], |
| 71 | +) |
| 72 | +trainer.fit(model) |
| 73 | + |
| 74 | +# Upload the best model to cloud storage |
| 75 | +print(f"last: {vars(checkpoint_callback)}") |
| 76 | +upload_model(model=checkpoint_callback.last_model_path, name=MY_MODEL_NAME) |
| 77 | +``` |
| 78 | + |
| 79 | +To load the model, use the `load_model` function. |
| 80 | + |
| 81 | +```python |
| 82 | +from lightning import Trainer |
| 83 | +from litmodels import download_model |
| 84 | +from litmodels.demos import BoringModel |
| 85 | + |
| 86 | + |
| 87 | +class LitModel(BoringModel): |
| 88 | + def training_step(self, batch, batch_idx: int): |
| 89 | + loss = self.step(batch) |
| 90 | + # logging the computed loss |
| 91 | + self.log("train_loss", loss) |
| 92 | + return {"loss": loss} |
| 93 | + |
| 94 | + |
| 95 | +# Define the model name - this should be unique to your model |
| 96 | +# The format is <organization>/<teamspace>/<model-name>:<model-version> |
| 97 | +MY_MODEL_NAME = "jirka/kaggle/lit-boring-model:latest" |
| 98 | + |
| 99 | +# Load the model from cloud storage |
| 100 | +model_path = download_model(name=MY_MODEL_NAME, download_dir="my_models") |
| 101 | +print(f"model: {model_path}") |
| 102 | + |
| 103 | +# Train the model with extended training period |
| 104 | +trainer = Trainer(max_epochs=4) |
| 105 | +trainer.fit( |
| 106 | + LitModel(), |
| 107 | + ckpt_path=model_path, |
| 108 | +) |
| 109 | +``` |
0 commit comments