Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ Train your model using your preferred framework (our fist examples show `scikit-
from sklearn import datasets, model_selection, svm
from litmodels import upload_model

# Unique model identifier: <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "your_org/your_team/sklearn-svm-model"

# Load example dataset
iris = datasets.load_iris()
X, y = iris.data, iris.target
Expand All @@ -68,26 +65,43 @@ model = svm.SVC()
model.fit(X_train, y_train)

# Upload the saved model using litmodels
upload_model(model=model, name=MY_MODEL_NAME)
upload_model(model=model, name="your_org/your_team/sklearn-svm-model")
```

### Download and Load the Model for inference

```python
from litmodels import load_model

# Unique model identifier: <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "your_org/your_team/sklearn-svm-model"

# Download and load the model file from cloud storage
model = load_model(name=MY_MODEL_NAME, download_dir="my_models")
model = load_model(
name="your_org/your_team/sklearn-svm-model", download_dir="my_models"
)

# Example: run inference with the loaded model
sample_input = [[5.1, 3.5, 1.4, 0.2]]
prediction = model.predict(sample_input)
print(f"Prediction: {prediction}")
```

## Saving and Loading Models with plain Pytorch

Next examples demonstrate seamless PyTorch integration with Lightning Models.

```python
import torch
from litmodels import load_model, upload_model


class SimpleModel(torch.nn.Module): ...


# First, simply upload the model object to registry
upload_model(model=SimpleModel(), name="your_org/your_team/torch-model")
# Later, you can download the model from the registry
model_ = load_model(name="your_org/your_team/torch-model")
```

## Saving and Loading Models with Pytorch Lightning

Next examples demonstrate seamless PyTorch Lightning integration with Lightning Models.
Expand All @@ -99,17 +113,15 @@ from lightning import Trainer
from litmodels import upload_model
from litmodels.demos import BoringModel

# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"

# Configure Lightning Trainer
trainer = Trainer(max_epochs=2)
# Define the model and train it
trainer.fit(BoringModel())

# Upload the best model to cloud storage
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
# Define the model name - this should be unique to your model
upload_model(model=checkpoint_path, name="<organization>/<teamspace>/<model-name>")
```

### Download and Load the Model for fine-tuning
Expand All @@ -119,11 +131,12 @@ from lightning import Trainer
from litmodels import download_model
from litmodels.demos import BoringModel

# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>:<model-version>"

# Load the model from cloud storage
checkpoint_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
checkpoint_path = download_model(
# Define the model name and version - this needs to be unique to your model
name="<organization>/<teamspace>/<model-name>:<model-version>",
download_dir="my_models",
)
print(f"model: {checkpoint_path}")

# Train the model with extended training period
Expand All @@ -143,15 +156,17 @@ from lightning import Trainer
from litmodels.integrations import LightningModelCheckpoint
from litmodels.demos import BoringModel

# Define the model name - this should be unique to your model
MY_MODEL_NAME = "<organization>/<teamspace>/<model-name>"

dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

trainer = Trainer(
max_epochs=2,
callbacks=[LightningModelCheckpoint(model_name=MY_MODEL_NAME)],
callbacks=[
LightningModelCheckpoint(
# Define the model name - this should be unique to your model
model_name="<organization>/<teamspace>/<model-name>",
)
],
)
trainer.fit(
BoringModel(),
Expand Down
Loading