@@ -51,9 +51,6 @@ Train your model using your preferred framework (our fist examples show `scikit-
5151from sklearn import datasets, model_selection, svm
5252from litmodels import upload_model
5353
54- # Unique model identifier: <organization>/<teamspace>/<model-name>
55- MY_MODEL_NAME = " your_org/your_team/sklearn-svm-model"
56-
5754# Load example dataset
5855iris = datasets.load_iris()
5956X, y = iris.data, iris.target
@@ -68,26 +65,43 @@ model = svm.SVC()
6865model.fit(X_train, y_train)
6966
7067# Upload the saved model using litmodels
71- upload_model(model = model, name = MY_MODEL_NAME )
68+ upload_model(model = model, name = " your_org/your_team/sklearn-svm-model " )
7269```
7370
7471### Download and Load the Model for inference
7572
7673``` python
7774from litmodels import load_model
7875
79- # Unique model identifier: <organization>/<teamspace>/<model-name>
80- MY_MODEL_NAME = " your_org/your_team/sklearn-svm-model"
81-
8276# Download and load the model file from cloud storage
83- model = load_model(name = MY_MODEL_NAME , download_dir = " my_models" )
77+ model = load_model(
78+ name = " your_org/your_team/sklearn-svm-model" , download_dir = " my_models"
79+ )
8480
8581# Example: run inference with the loaded model
8682sample_input = [[5.1 , 3.5 , 1.4 , 0.2 ]]
8783prediction = model.predict(sample_input)
8884print (f " Prediction: { prediction} " )
8985```
9086
87+ ## Saving and Loading Models with plain Pytorch
88+
89+ Next examples demonstrate seamless PyTorch integration with Lightning Models.
90+
91+ ``` python
92+ import torch
93+ from litmodels import load_model, upload_model
94+
95+
96+ class SimpleModel (torch .nn .Module ): ...
97+
98+
99+ # First, simply upload the model object to registry
100+ upload_model(model = SimpleModel(), name = " your_org/your_team/torch-model" )
101+ # Later, you can download the model from the registry
102+ model_ = load_model(name = " your_org/your_team/torch-model" )
103+ ```
104+
91105## Saving and Loading Models with Pytorch Lightning
92106
93107Next examples demonstrate seamless PyTorch Lightning integration with Lightning Models.
@@ -99,17 +113,15 @@ from lightning import Trainer
99113from litmodels import upload_model
100114from litmodels.demos import BoringModel
101115
102- # Define the model name - this should be unique to your model
103- MY_MODEL_NAME = " <organization>/<teamspace>/<model-name>"
104-
105116# Configure Lightning Trainer
106117trainer = Trainer(max_epochs = 2 )
107118# Define the model and train it
108119trainer.fit(BoringModel())
109120
110121# Upload the best model to cloud storage
111122checkpoint_path = getattr (trainer.checkpoint_callback, " best_model_path" )
112- upload_model(model = checkpoint_path, name = MY_MODEL_NAME )
123+ # Define the model name - this should be unique to your model
124+ upload_model(model = checkpoint_path, name = " <organization>/<teamspace>/<model-name>" )
113125```
114126
115127### Download and Load the Model for fine-tuning
@@ -119,11 +131,12 @@ from lightning import Trainer
119131from litmodels import download_model
120132from litmodels.demos import BoringModel
121133
122- # Define the model name - this should be unique to your model
123- MY_MODEL_NAME = " <organization>/<teamspace>/<model-name>:<model-version>"
124-
125134# Load the model from cloud storage
126- checkpoint_path = download_model(name = MY_MODEL_NAME , download_dir = " my_models" )
135+ checkpoint_path = download_model(
136+ # Define the model name and version - this needs to be unique to your model
137+ name = " <organization>/<teamspace>/<model-name>:<model-version>" ,
138+ download_dir = " my_models" ,
139+ )
127140print (f " model: { checkpoint_path} " )
128141
129142# Train the model with extended training period
@@ -143,15 +156,17 @@ from lightning import Trainer
143156from litmodels.integrations import LightningModelCheckpoint
144157from litmodels.demos import BoringModel
145158
146- # Define the model name - this should be unique to your model
147- MY_MODEL_NAME = " <organization>/<teamspace>/<model-name>"
148-
149159dataset = tv.datasets.MNIST(" ." , download = True , transform = tv.transforms.ToTensor())
150160train, val = data.random_split(dataset, [55000 , 5000 ])
151161
152162trainer = Trainer(
153163 max_epochs = 2 ,
154- callbacks = [LightningModelCheckpoint(model_name = MY_MODEL_NAME )],
164+ callbacks = [
165+ LightningModelCheckpoint(
166+ # Define the model name - this should be unique to your model
167+ model_name = " <organization>/<teamspace>/<model-name>" ,
168+ )
169+ ],
155170)
156171trainer.fit(
157172 BoringModel(),
0 commit comments