Skip to content

Commit d3404da

Browse files
authored
update examples in Readme (#79)
1 parent 2270db2 commit d3404da

File tree

1 file changed

+35
-20
lines changed

1 file changed

+35
-20
lines changed

README.md

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ Train your model using your preferred framework (our fist examples show `scikit-
5151
from sklearn import datasets, model_selection, svm
5252
from litmodels import upload_model
5353

54-
# Unique model identifier: <organization>/<teamspace>/<model-name>
55-
MY_MODEL_NAME = "your_org/your_team/sklearn-svm-model"
56-
5754
# Load example dataset
5855
iris = datasets.load_iris()
5956
X, y = iris.data, iris.target
@@ -68,26 +65,43 @@ model = svm.SVC()
6865
model.fit(X_train, y_train)
6966

7067
# Upload the saved model using litmodels
71-
upload_model(model=model, name=MY_MODEL_NAME)
68+
upload_model(model=model, name="your_org/your_team/sklearn-svm-model")
7269
```
7370

7471
### Download and Load the Model for inference
7572

7673
```python
7774
from litmodels import load_model
7875

79-
# Unique model identifier: <organization>/<teamspace>/<model-name>
80-
MY_MODEL_NAME = "your_org/your_team/sklearn-svm-model"
81-
8276
# Download and load the model file from cloud storage
83-
model = load_model(name=MY_MODEL_NAME, download_dir="my_models")
77+
model = load_model(
78+
name="your_org/your_team/sklearn-svm-model", download_dir="my_models"
79+
)
8480

8581
# Example: run inference with the loaded model
8682
sample_input = [[5.1, 3.5, 1.4, 0.2]]
8783
prediction = model.predict(sample_input)
8884
print(f"Prediction: {prediction}")
8985
```
9086

87+
## Saving and Loading Models with plain Pytorch
88+
89+
Next examples demonstrate seamless PyTorch integration with Lightning Models.
90+
91+
```python
92+
import torch
93+
from litmodels import load_model, upload_model
94+
95+
96+
class SimpleModel(torch.nn.Module): ...
97+
98+
99+
# First, simply upload the model object to registry
100+
upload_model(model=SimpleModel(), name="your_org/your_team/torch-model")
101+
# Later, you can download the model from the registry
102+
model_ = load_model(name="your_org/your_team/torch-model")
103+
```
104+
91105
## Saving and Loading Models with Pytorch Lightning
92106

93107
Next examples demonstrate seamless PyTorch Lightning integration with Lightning Models.
@@ -99,17 +113,15 @@ from lightning import Trainer
99113
from litmodels import upload_model
100114
from litmodels.demos import BoringModel
101115

102-
# Define the model name - this should be unique to your model
103-
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"
104-
105116
# Configure Lightning Trainer
106117
trainer = Trainer(max_epochs=2)
107118
# Define the model and train it
108119
trainer.fit(BoringModel())
109120

110121
# Upload the best model to cloud storage
111122
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
112-
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
123+
# Define the model name - this should be unique to your model
124+
upload_model(model=checkpoint_path, name="<organization>/<teamspace>/<model-name>")
113125
```
114126

115127
### Download and Load the Model for fine-tuning
@@ -119,11 +131,12 @@ from lightning import Trainer
119131
from litmodels import download_model
120132
from litmodels.demos import BoringModel
121133

122-
# Define the model name - this should be unique to your model
123-
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>:<model-version>"
124-
125134
# Load the model from cloud storage
126-
checkpoint_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
135+
checkpoint_path = download_model(
136+
# Define the model name and version - this needs to be unique to your model
137+
name="<organization>/<teamspace>/<model-name>:<model-version>",
138+
download_dir="my_models",
139+
)
127140
print(f"model: {checkpoint_path}")
128141

129142
# Train the model with extended training period
@@ -143,15 +156,17 @@ from lightning import Trainer
143156
from litmodels.integrations import LightningModelCheckpoint
144157
from litmodels.demos import BoringModel
145158

146-
# Define the model name - this should be unique to your model
147-
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"
148-
149159
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
150160
train, val = data.random_split(dataset, [55000, 5000])
151161

152162
trainer = Trainer(
153163
max_epochs=2,
154-
callbacks=[LightningModelCheckpoint(model_name=MY_MODEL_NAME)],
164+
callbacks=[
165+
LightningModelCheckpoint(
166+
# Define the model name - this should be unique to your model
167+
model_name="<organization>/<teamspace>/<model-name>",
168+
)
169+
],
155170
)
156171
trainer.fit(
157172
BoringModel(),

0 commit comments

Comments
 (0)