Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/litmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from litmodels.cloud_io import download_model, upload_model, upload_model_files
from litmodels.io import download_model, upload_model

__all__ = ["download_model", "upload_model", "upload_model_files"]
__all__ = ["download_model", "upload_model"]
6 changes: 6 additions & 0 deletions src/litmodels/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Root package for Input/output."""

from litmodels.io.cloud import download_model_files, upload_model_files
from litmodels.io.gateway import download_model, upload_model

__all__ = ["download_model", "upload_model", "download_model_files", "upload_model_files"]
67 changes: 6 additions & 61 deletions src/litmodels/cloud_io.py → src/litmodels/io/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,11 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple, Union
from typing import Optional, Tuple

from lightning_sdk.api.teamspace_api import UploadedModelInfo
from lightning_sdk.teamspace import Teamspace
from lightning_sdk.utils import resolve as sdk_resolvers
from lightning_utilities import module_available

if TYPE_CHECKING:
from torch.nn import Module

if module_available("torch"):
import torch
from torch.nn import Module
else:
torch = None

# if module_available("lightning"):
# from lightning import LightningModule
Expand Down Expand Up @@ -63,60 +50,18 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
return Teamspace(**teamspaces[requested_teamspace])


def upload_model(
model: Union[str, Path, "Module"],
name: str,
progress_bar: bool = True,
cluster_id: Optional[str] = None,
staging_dir: Optional[str] = None,
) -> UploadedModelInfo:
"""Upload a checkpoint to the model store.

Args:
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
progress_bar: Whether to show a progress bar for the upload.
cluster_id: The name of the cluster to use. Only required if it can't be determined
automatically.
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
be created and used.

"""
if not staging_dir:
staging_dir = tempfile.mkdtemp()
# if LightningModule and isinstance(model, LightningModule):
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
# model.save_checkpoint(path)
if torch and isinstance(model, Module):
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
torch.save(model.state_dict(), path)
elif isinstance(model, str):
path = model
elif isinstance(model, Path):
path = str(model)
else:
raise ValueError(f"Unsupported model type {type(model)}")
return upload_model_files(
path=path,
name=name,
progress_bar=progress_bar,
cluster_id=cluster_id,
)


def upload_model_files(
path: str,
name: str,
path: str,
progress_bar: bool = True,
cluster_id: Optional[str] = None,
) -> UploadedModelInfo:
"""Upload a local checkpoint file to the model store.

Args:
path: Path to the model file to upload.
name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
path: Path to the model file to upload.
progress_bar: Whether to show a progress bar for the upload.
cluster_id: The name of the cluster to use. Only required if it can't be determined
automatically.
Expand All @@ -132,15 +77,15 @@ def upload_model_files(
)


def download_model(
def download_model_files(
name: str,
download_dir: str = ".",
progress_bar: bool = True,
) -> str:
"""Download a checkpoint from the model store.

Args:
name: Name tag of the model to download. Must be in the format 'organization/teamspace/modelname'
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
download_dir: A path to directory where the model should be downloaded. Defaults
to the current working directory.
Expand Down
81 changes: 81 additions & 0 deletions src/litmodels/io/gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

from lightning_sdk.api.teamspace_api import UploadedModelInfo
from lightning_utilities import module_available

from litmodels.io.cloud import download_model_files, upload_model_files

if module_available("torch"):
import torch
from torch.nn import Module
else:
torch = None


def upload_model(
name: str,
model: Union[str, Path, "Module"],
progress_bar: bool = True,
cluster_id: Optional[str] = None,
staging_dir: Optional[str] = None,
) -> UploadedModelInfo:
"""Upload a checkpoint to the model store.

Args:
name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
progress_bar: Whether to show a progress bar for the upload.
cluster_id: The name of the cluster to use. Only required if it can't be determined
automatically.
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
be created and used.

"""
if not staging_dir:
staging_dir = tempfile.mkdtemp()
# if LightningModule and isinstance(model, LightningModule):
# path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
# model.save_checkpoint(path)
if torch and isinstance(model, Module):
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth")
torch.save(model.state_dict(), path)
elif isinstance(model, str):
path = model
elif isinstance(model, Path):
path = str(model)
else:
raise ValueError(f"Unsupported model type {type(model)}")
return upload_model_files(
path=path,
name=name,
progress_bar=progress_bar,
cluster_id=cluster_id,
)


def download_model(
name: str,
download_dir: str = ".",
progress_bar: bool = True,
) -> str:
"""Download a checkpoint from the model store.

Args:
name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
download_dir: A path to directory where the model should be downloaded. Defaults
to the current working directory.
progress_bar: Whether to show a progress bar for the download.

Returns:
The absolute path to the downloaded model file or folder.
"""
return download_model_files(
name=name,
download_dir=download_dir,
progress_bar=progress_bar,
)
9 changes: 5 additions & 4 deletions tests/test_cloud_io.py → tests/test_io_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from unittest import mock

import pytest
from litmodels.cloud_io import download_model, upload_model, upload_model_files
from litmodels import download_model, upload_model
from litmodels.io import upload_model_files
from torch.nn import Module


Expand All @@ -25,11 +26,11 @@ def test_wrong_model_name(name):
def test_upload_model(mocker, tmpdir, model, model_path):
# mocking the _get_teamspace to return another mock
ts_mock = mock.MagicMock()
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock)

# The lit-logger function is just a wrapper around the SDK function
upload_model(
model,
model=model,
name="org-name/teamspace/model-name",
cluster_id="cluster_id",
staging_dir=tmpdir,
Expand All @@ -46,7 +47,7 @@ def test_upload_model(mocker, tmpdir, model, model_path):
def test_download_model(mocker):
# mocking the _get_teamspace to return another mock
ts_mock = mock.MagicMock()
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
mocker.patch("litmodels.io.cloud._get_teamspace", return_value=ts_mock)
# The lit-logger function is just a wrapper around the SDK function
download_model(
name="org-name/teamspace/model-name",
Expand Down
Loading