
Commit a70865a

readme: update description (#36)

* readme: update description
* examples
* update
* link

1 parent 4cc9d8a commit a70865a

File tree

6 files changed: +163 −86 lines

README.md

Lines changed: 86 additions & 85 deletions
@@ -1,39 +1,102 @@
-# Lightning Models
+# Effortless Model Management for Your Development ⚡

-This package provides utilities for saving and loading machine learning models using PyTorch Lightning. It aims to simplify the process of managing model checkpoints, making it easier to save, load, and share models.
+__Effortless management for your ML models.__

-## Features
+<div align="center">

-- **Save Models**: Easily save your trained models to cloud storage.
-- **Load Models**: Load pre-trained models for inference or further training.
-- **Checkpoint Management**: Manage multiple checkpoints with ease.
-- **Cloud Integration**: Support for saving and loading models from cloud storage services.
+🚀 [Quick start](#quick-start)
+📦 [Examples](#saving-and-loading-models)
+📚 [Documentation](https://lightning.ai/docs/overview/model-registry)
+💬 [Get help on Discord](https://discord.com/invite/XncpTy7DSt)
+📋 License: Apache 2.0

-[![CI testing](https://github.com/Lightning-AI/models/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/models/actions/workflows/ci-testing.yml)
-[![General checks](https://github.com/Lightning-AI/models/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/models/actions/workflows/ci-checks.yml)
-[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/models/main.svg?badge_token=mqheL1-cTn-280Vx4cJUdg)](https://results.pre-commit.ci/latest/github/Lightning-AI/models/main?badge_token=mqheL1-cTn-280Vx4cJUdg)
+</div>

-## Installation
+**Lightning Models** is a streamlined toolkit for effortlessly saving, loading, and managing your model checkpoints. Designed to simplify the entire model lifecycle—from training and inference to sharing, deployment, and cloud integration—Lightning Models supports any framework that produces model checkpoints, including but not limited to PyTorch Lightning.

-To install the package, you can use `pip`:
+<pre>
+✅ Seamless Model Saving & Loading
+✅ Robust Checkpoint Management
+✅ Cloud Integration Out of the Box
+✅ Versatile Across Frameworks
+</pre>
+
+# Quick start
+
+Install Lightning Models via pip (more installation options below):

 ```bash
 pip install -U litmodels
 ```

-Or installing from source:
+Or install directly from source:

 ```bash
 pip install https://github.com/Lightning-AI/models/archive/refs/heads/main.zip
 ```

-## Usage
+## Saving and Loading Models
+
+Lightning Models offers a simple API to manage your model checkpoints.
+Train your model with your preferred framework (our first examples use `scikit-learn`), then save your best checkpoint with a single function call.
+
+### Train a scikit-learn model and save it
+
+```python
+from sklearn import datasets, model_selection, svm
+from litmodels import upload_model
+
+# Unique model identifier: <organization>/<teamspace>/<model-name>
+MY_MODEL_NAME = "your_org/your_team/sklearn-svm-model"
+
+# Load example dataset
+iris = datasets.load_iris()
+X, y = iris.data, iris.target
+
+# Split dataset into training and test sets
+X_train, X_test, y_train, y_test = model_selection.train_test_split(
+    X, y, test_size=0.2, random_state=42
+)
+
+# Train a simple SVC model
+model = svm.SVC()
+model.fit(X_train, y_train)
+
+# Upload the trained model with litmodels
+upload_model(model=model, name=MY_MODEL_NAME)
+```
+
+### Download and Load the Model for inference
+
+```python
+import os
+import joblib
+from litmodels import download_model
+
+# Unique model identifier: <organization>/<teamspace>/<model-name>
+MY_MODEL_NAME = "your_org/your_team/sklearn-svm-model"
+
+# Download the model file from cloud storage
+model_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
+
+# Load the model for inference using joblib
+model = joblib.load(os.path.join("my_models", model_path[0]))
+
+# Example: run inference with the loaded model
+sample_input = [[5.1, 3.5, 1.4, 0.2]]
+prediction = model.predict(sample_input)
+print(f"Prediction: {prediction}")
+```
+
+## Saving and Loading Models with PyTorch Lightning

-Here's a simple example of how to save and load a model using `litmodels`. First, you need to train a model using PyTorch Lightning. Then, you can save the model using the `upload_model` function.
+The next examples demonstrate seamless PyTorch Lightning integration with Lightning Models.
+
+### Train a simple Lightning model and save it

 ```python
 from lightning import Trainer
-from lightning.pytorch.callbacks import ModelCheckpoint
 from litmodels import upload_model
 from litmodels.demos import BoringModel

@@ -59,7 +122,7 @@ checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
 upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
 ```

-To load the model, use the `download_model` function.
+### Download and Load the Model for fine-tuning

 ```python
 from lightning import Trainer
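
The hunk above elides the middle of the train-and-save example. As a rough, self-contained sketch of the same flow, assuming `BoringModel` bundles its own dataloaders and that the Trainer's default checkpointing records a `best_model_path` (the model name is a placeholder, not the repository's exact code):

```python
from lightning import Trainer
from litmodels import upload_model
from litmodels.demos import BoringModel

# Placeholder identifier: <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "your_org/your_team/lit-boring-model"

# Train briefly; the default checkpoint callback records the best checkpoint path
trainer = Trainer(max_epochs=2)
trainer.fit(BoringModel())

# Upload the best checkpoint under the chosen model name
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
```
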
@@ -87,8 +150,11 @@ trainer = Trainer(max_epochs=4)
 trainer.fit(LitModel(), ckpt_path=checkpoint_path)
 ```

-You can also enhance your training with a simple Checkpointing callback which would always save the best model to the cloud storage and continue training.
-This can would be handy especially with long trainings or using interruptible machines so you would always resume/recover from the best model.
+<details>
+<summary>Advanced Checkpointing Workflow</summary>
+
+Enhance your training process with an automatic checkpointing callback that uploads the best model at the end of each epoch.
+While the example uses PyTorch Lightning callbacks, similar workflows can be implemented in any training loop that produces checkpoints.

 ```python
 import os
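
The surrounding hunks show only fragments of this fine-tuning example. A minimal sketch of the download-and-resume pattern, assuming `download_model` returns the downloaded file names as in the scikit-learn example above (`BoringModel` stands in for the `LitModel` referenced in the hunk; the model name and directory are placeholders):

```python
import os

from lightning import Trainer
from litmodels import download_model
from litmodels.demos import BoringModel

# Placeholder identifier: <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "your_org/your_team/lit-boring-model"

# Fetch the checkpoint; download_model returns the downloaded file names
model_files = download_model(name=MY_MODEL_NAME, download_dir="my_models")
checkpoint_path = os.path.join("my_models", model_files[0])

# Fine-tune: resume training from the downloaded checkpoint
trainer = Trainer(max_epochs=4)
trainer.fit(BoringModel(), ckpt_path=checkpoint_path)
```
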
@@ -132,69 +198,4 @@ trainer.fit(
 )
 ```

-## Logging Models
-
-You can also use model store together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to the cloud storage.
-
-```python
-import os
-import lightning as L
-from psutil import cpu_count
-from torch import optim, nn
-from torch.utils.data import DataLoader
-from torchvision.datasets import MNIST
-from torchvision.transforms import ToTensor
-from litlogger import LightningLogger
-
-
-class LitAutoEncoder(L.LightningModule):
-
-    def __init__(self, lr=1e-3, inp_size=28):
-        super().__init__()
-
-        self.encoder = nn.Sequential(
-            nn.Linear(inp_size * inp_size, 64), nn.ReLU(), nn.Linear(64, 3)
-        )
-        self.decoder = nn.Sequential(
-            nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, inp_size * inp_size)
-        )
-        self.lr = lr
-        self.save_hyperparameters()
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        x = x.view(x.size(0), -1)
-        z = self.encoder(x)
-        x_hat = self.decoder(z)
-        loss = nn.functional.mse_loss(x_hat, x)
-        # log metrics
-        self.log("train_loss", loss)
-        return loss
-
-    def configure_optimizers(self):
-        optimizer = optim.Adam(self.parameters(), lr=self.lr)
-        return optimizer
-
-
-if __name__ == "__main__":
-    # init the autoencoder
-    autoencoder = LitAutoEncoder(lr=1e-3, inp_size=28)
-
-    # setup data
-    train_loader = DataLoader(
-        dataset=MNIST(os.getcwd(), download=True, transform=ToTensor()),
-        batch_size=32,
-        shuffle=True,
-        num_workers=cpu_count(),
-        persistent_workers=True,
-    )
-
-    # configure the logger
-    lit_logger = LightningLogger(log_model=True)
-
-    # pass logger to the Trainer
-    trainer = L.Trainer(max_epochs=5, logger=lit_logger)
-
-    # train the model
-    trainer.fit(model=autoencoder, train_dataloaders=train_loader)
-```
+</details>
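
The body of the auto-upload callback described in the Advanced Checkpointing Workflow section is elided in the hunks above. A minimal sketch of the idea, grounded in the `Callback` import and `upload_model` call visible in this commit (the class name, hook choice, and guard are illustrative assumptions, not the repository's code):

```python
import os

from lightning import Callback, Trainer
from litmodels import upload_model
from litmodels.demos import BoringModel

# Placeholder identifier: <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "your_org/your_team/lit-boring-model"


class UploadBestModel(Callback):
    """Upload the best checkpoint so far at the end of every training epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path", "")
        # Skip epochs where no checkpoint has been written yet
        if checkpoint_path and os.path.isfile(checkpoint_path):
            upload_model(model=checkpoint_path, name=MY_MODEL_NAME)


trainer = Trainer(max_epochs=4, callbacks=[UploadBestModel()])
trainer.fit(BoringModel())
```

On interruptible machines, the most recently uploaded checkpoint can then be fetched again with `download_model` to resume from the best model.
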
Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,7 @@
+"""
+This example demonstrates how to resume training of a model using the `download_model` function.
+"""
+
 import torch.utils.data as data
 import torchvision as tv
 from lightning import Trainer
Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,7 @@
+"""
+This example demonstrates how to train a model and upload it to the cloud using the `upload_model` function.
+"""
+
 import torch.utils.data as data
 import torchvision as tv
 from lightning import Trainer

examples/train-callback.py renamed to examples/train-model-with-lightning-callback.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,7 @@
+"""
+Train a model with a Lightning callback that uploads the best model to the cloud after each epoch.
+"""
+
 import torch.utils.data as data
 import torchvision as tv
 from lightning import Callback, Trainer
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+"""
+# Enhanced Logging with LightningLogger
+
+Integrate with [LitLogger](https://github.com/gridai/lit-logger) to automatically log your model checkpoints
+and training metrics to cloud storage.
+Though the example utilizes PyTorch Lightning, this integration concept works across various model training frameworks.
+
+"""
+
+import os
+
+from lightning import LightningModule, Trainer
+from litlogger import LightningLogger
+from psutil import cpu_count
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from torchvision.datasets import MNIST
+from torchvision.transforms import ToTensor
+
+
+class LitAutoEncoder(LightningModule):
+    def __init__(self, lr=1e-3, inp_size=28):
+        super().__init__()
+
+        self.encoder = nn.Sequential(nn.Linear(inp_size * inp_size, 64), nn.ReLU(), nn.Linear(64, 3))
+        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, inp_size * inp_size))
+        self.lr = lr
+        self.save_hyperparameters()
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        z = self.encoder(x)
+        x_hat = self.decoder(z)
+        loss = nn.functional.mse_loss(x_hat, x)
+        # log metrics
+        self.log("train_loss", loss)
+        return loss
+
+    def configure_optimizers(self):
+        return optim.Adam(self.parameters(), lr=self.lr)
+
+
+if __name__ == "__main__":
+    # init the autoencoder
+    autoencoder = LitAutoEncoder(lr=1e-3, inp_size=28)
+
+    # setup data
+    train_loader = DataLoader(
+        dataset=MNIST(os.getcwd(), download=True, transform=ToTensor()),
+        batch_size=32,
+        shuffle=True,
+        num_workers=cpu_count(),
+        persistent_workers=True,
+    )
+
+    # configure the logger
+    lit_logger = LightningLogger(log_model=True)
+
+    # pass logger to the Trainer
+    trainer = Trainer(max_epochs=5, logger=lit_logger)
+
+    # train the model
+    trainer.fit(model=autoencoder, train_dataloaders=train_loader)

src/litmodels/__about__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = "0.0.9rc"
+__version__ = "0.1.0"
 __author__ = "Lightning-AI et al."
 __author_email__ = "[email protected]"
 __license__ = "Apache-2.0"
