Skip to content

Commit 2d01486

Browse files
authored
add simple Readme (#19)
1 parent 2b6faa7 commit 2d01486

File tree

4 files changed

+103
-6
lines changed

4 files changed

+103
-6
lines changed

.github/workflows/docs-build.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ name: "Build (& deploy) Docs"
22
on:
33
push:
44
branches: [main]
5-
pull_request:
6-
branches: [main]
75
workflow_dispatch:
86

97
jobs:
@@ -15,7 +13,6 @@ jobs:
1513
# https://github.com/marketplace/actions/deploy-to-github-pages
1614
docs-deploy:
1715
needs: build-docs
18-
if: github.event_name != 'pull_request'
1916
runs-on: ubuntu-latest
2017
steps:
2118
- uses: actions/checkout@v4 # deploy needs git credentials

README.md

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,109 @@
11
# Lightning Models
22

3-
...TBD...
3+
This package provides utilities for saving and loading machine learning models using PyTorch Lightning. It aims to simplify the process of managing model checkpoints, making it easier to save, load, and share models.
4+
5+
## Features
6+
7+
- **Save Models**: Easily save your trained models to cloud storage.
8+
- **Load Models**: Load pre-trained models for inference or further training.
9+
- **Checkpoint Management**: Manage multiple checkpoints with ease.
10+
- **Cloud Integration**: Support for saving and loading models from cloud storage services.
411

512
[![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://lightning.ai/)
613
[![CI testing](https://github.com/Lightning-AI/models/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/models/actions/workflows/ci-testing.yml)
714
[![General checks](https://github.com/Lightning-AI/models/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/models/actions/workflows/ci-checks.yml)
815
[![Documentation Status](https://readthedocs.org/projects/models/badge/?version=latest)](https://models.readthedocs.io/en/latest/?badge=latest)
916
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/models/main.svg?badge_token=mqheL1-cTn-280Vx4cJUdg)](https://results.pre-commit.ci/latest/github/Lightning-AI/models/main?badge_token=mqheL1-cTn-280Vx4cJUdg)
17+
18+
## Installation
19+
20+
To install the package, you can use `pip` from [Test PyPI](https://test.pypi.org/project/litmodels/):
21+
22+
```bash
23+
pip install "litmodels==0.X.Y" --extra-index-url="https://test.pypi.org/simple/"
24+
```
25+
26+
Or installing from source:
27+
28+
```bash
29+
pip install https://github.com/Lightning-AI/models/archive/refs/heads/main.zip
30+
```
31+
32+
## Usage
33+
34+
Here's a simple example of how to save and load a model using `litmodels`. First, you need to train a model using PyTorch Lightning. Then, you can save the model using the `upload_model` function.
35+
36+
```python
37+
from lightning import Trainer
38+
from lightning.pytorch.callbacks import ModelCheckpoint
39+
from litmodels import upload_model
40+
from litmodels.demos import BoringModel
41+
42+
43+
class LitModel(BoringModel):
44+
def training_step(self, batch, batch_idx: int):
45+
loss = self.step(batch)
46+
# logging the computed loss
47+
self.log("train_loss", loss)
48+
return {"loss": loss}
49+
50+
51+
# Define the model name - this should be unique to your model
52+
# The format is <organization>/<teamspace>/<model-name>
53+
MY_MODEL_NAME = "jirka/kaggle/lit-boring-model"
54+
55+
# Define the model
56+
model = LitModel()
57+
# Save the best model based on validation loss
58+
checkpoint_callback = ModelCheckpoint(
59+
monitor="train_loss", # Metric to monitor
60+
save_top_k=1, # Only save the best model (use -1 to save all)
61+
mode="min", # 'min' for loss, 'max' for accuracy
62+
save_last=True, # Additionally save the last checkpoint
63+
dirpath="my_checkpoints/", # Directory to save checkpoints
64+
filename="{epoch:02d}-{val_loss:.2f}", # Custom checkpoint filename
65+
)
66+
67+
# Train the model
68+
trainer = Trainer(
69+
max_epochs=2,
70+
callbacks=[checkpoint_callback],
71+
)
72+
trainer.fit(model)
73+
74+
# Upload the best model to cloud storage
75+
print(f"last: {vars(checkpoint_callback)}")
76+
upload_model(model=checkpoint_callback.last_model_path, name=MY_MODEL_NAME)
77+
```
78+
79+
To load the model, use the `load_model` function.
80+
81+
```python
82+
from lightning import Trainer
83+
from litmodels import download_model
84+
from litmodels.demos import BoringModel
85+
86+
87+
class LitModel(BoringModel):
88+
def training_step(self, batch, batch_idx: int):
89+
loss = self.step(batch)
90+
# logging the computed loss
91+
self.log("train_loss", loss)
92+
return {"loss": loss}
93+
94+
95+
# Define the model name - this should be unique to your model
96+
# The format is <organization>/<teamspace>/<model-name>:<model-version>
97+
MY_MODEL_NAME = "jirka/kaggle/lit-boring-model:latest"
98+
99+
# Load the model from cloud storage
100+
model_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
101+
print(f"model: {model_path}")
102+
103+
# Train the model with extended training period
104+
trainer = Trainer(max_epochs=4)
105+
trainer.fit(
106+
LitModel(),
107+
ckpt_path=model_path,
108+
)
109+
```

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _convert_markdown(path_in: str, path_out: str) -> None:
8484

8585

8686
# export the READme
87-
_convert_markdown(os.path.join(_PATH_ROOT, "README.md"), "readme.md")
87+
_convert_markdown(os.path.join(_PATH_ROOT, "README.md"), "readme.rst")
8888

8989
# -- General configuration ---------------------------------------------------
9090

examples/train-simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
# Define the model name - this should be unique to your model
99
# The format is <organization>/<teamspace>/<model-name>
10-
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-callback"
10+
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-simple"
1111

1212

1313
if __name__ == "__main__":

0 commit comments

Comments
 (0)