Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
uses: Lightning-AI/utilities/.github/workflows/check-package.yml@main
with:
actions-ref: main
import-name: "lit_sandbox"
import-name: "litmodels"
artifact-name: dist-packages-${{ github.sha }}
testing-matrix: |
{
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,9 @@ jobs:

- name: Tests
run: |
coverage run --source lit_sandbox -m pytest src tests -v
coverage run --source litmodels -m pytest src tests -v

- name: Statistics
if: success()
run: |
coverage report
coverage xml
Expand Down
14 changes: 8 additions & 6 deletions .github/workflows/release-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ jobs:

- name: Install dependencies
run: pip install -U build twine
- name: Overview Readme for release
run: echo "# Lit Models" > README.md
- name: Build package
run: python -m build
- name: Check package
Expand All @@ -35,9 +37,9 @@ jobs:
password: ${{ secrets.test_pypi_password }}
repository_url: https://test.pypi.org/legacy/

- name: Publish distribution 📦 to PyPI
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
uses: pypa/[email protected]
with:
user: __token__
password: ${{ secrets.pypi_password }}
#- name: Publish distribution 📦 to PyPI
# if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
# uses: pypa/[email protected]
# with:
# user: __token__
# password: ${{ secrets.pypi_password }}
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ test: clean
pip install -q -r _requirements/test.txt

# use this to run tests
python -m coverage run --source lit_sandbox -m pytest src tests -v --flake8
python -m coverage run --source litmodels -m pytest src tests -v --flake8
python -m coverage report

docs: clean
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ This is starter project template which shall simplify initial steps for each new

Listing the implemented sections:

- sample package named `lit_sandbox`
- sample package named `litmodels`
- setting [CI](https://github.com/Lightning-AI/lightning-sandbox/actions?query=workflow%3A%22CI+testing%22) for package and _tests_ folder
- setup/install package
- setting docs with Sphinx
Expand All @@ -26,7 +26,6 @@ Listing the implemented sections:

You still need to enable some external integrations such as:

- [ ] rename `pl_<sandbox>` to anu other name, simple find-replace shall work well
- [ ] update path used in the badges to the repository
- [ ] lock the main breach in GH setting - no direct push without PR
- [ ] set `gh-pages` as website and _docs_ as source folder in GH setting
Expand Down
7 changes: 4 additions & 3 deletions _requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
coverage>=5.0
pytest>=6.0
coverage >=5.0
pytest >=6.0
pytest-cov
mypy==1.13.0
pytest-mock
mypy ==1.13.0
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

# alternative https://stackoverflow.com/a/67692/4521646
spec = spec_from_file_location(
"lit_sandbox/__about__.py", os.path.join(_PATH_SOURCE, "lit_sandbox", "__about__.py")
"litmodels/__about__.py", os.path.join(_PATH_SOURCE, "litmodels", "__about__.py")
)
about = module_from_spec(spec)
spec.loader.exec_module(about)
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,16 @@ extend-select = [
]
ignore = [
"E731", # Do not assign a lambda expression, use a def
"D100", # Missing docstring in public module
]
# Exclude a variety of commonly ignored directories.
ignore-init-module-imports = true

[tool.ruff.lint.per-file-ignores]
"setup.py" = ["D100", "SIM115"]
"__about__.py" = ["D100"]
"__init__.py" = ["D100"]
"__init__.py" = ["D100", "E402"]
"tests/**" = ["D"]

[tool.ruff.lint.pydocstyle]
# Use Google-style docstrings.
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
torch >=2.0.0
lightning >=2.0.0
lightning-sdk >=0.1.26
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
_PATH_REQUIRES = os.path.join(_PATH_ROOT, "_requirements")


def _load_py_module(fname, pkg="lit_sandbox"):
def _load_py_module(fname, pkg="litmodels"):
spec = spec_from_file_location(os.path.join(pkg, fname), os.path.join(_PATH_SOURCE, pkg, fname))
py = module_from_spec(spec)
spec.loader.exec_module(py)
Expand Down
8 changes: 0 additions & 8 deletions src/lit_sandbox/__init__.py

This file was deleted.

14 changes: 0 additions & 14 deletions src/lit_sandbox/my_module.py

This file was deleted.

File renamed without changes.
12 changes: 12 additions & 0 deletions src/litmodels/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Root package info."""

import os

from litmodels.__about__ import * # noqa: F401, F403

_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from litmodels.cloud_io import download_model, upload_model

__all__ = ["download_model", "upload_model"]
92 changes: 92 additions & 0 deletions src/litmodels/cloud_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Optional, Tuple

from lightning_sdk.api.teamspace_api import UploadedModelInfo
from lightning_sdk.teamspace import Teamspace
from lightning_sdk.utils import resolve as sdk_resolvers


def _parse_name(name: str) -> Tuple[str, str, str]:
"""Parse the name argument into its components."""
try:
org_name, teamspace_name, model_name = name.split("/")
except ValueError as err:
raise ValueError(
f"The name argument must be in the format 'organization/teamspace/model` but you provided '{name}'."
) from err
return org_name, teamspace_name, model_name


def _get_teamspace(name: str, organization: str) -> Teamspace:
"""Get a Teamspace object from the SDK."""
from lightning_sdk.api import OrgApi, UserApi

org_api = OrgApi()
user = sdk_resolvers._get_authed_user()
teamspaces = {}
for ts in UserApi()._get_all_teamspace_memberships(""):
if ts.owner_type == "organization":
org = org_api._get_org_by_id(ts.owner_id)
teamspaces[f"{org.name}/{ts.name}"] = {"name": ts.name, "org": org.name}
elif ts.owner_type == "user": # todo: check also the name
teamspaces[f"{user.name}/{ts.name}"] = {"name": ts.name, "user": user}
else:
raise RuntimeError(f"Unknown organization type {ts.organization_type}")

requested_teamspace = f"{organization}/{name}".lower()
if requested_teamspace not in teamspaces:
options = "\n\t".join(teamspaces.keys())
raise RuntimeError(f"Teamspace `{requested_teamspace}` not found. Available teamspaces: \n\t{options}")
return Teamspace(**teamspaces[requested_teamspace])


def upload_model(
path: str,
name: str,
progress_bar: bool = True,
cluster_id: Optional[str] = None,
) -> UploadedModelInfo:
"""Upload a local checkpoint file to the model store.

Args:
path: Path to the model file to upload.
name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
progress_bar: Whether to show a progress bar for the upload.
cluster_id: The name of the cluster to use. Only required if it can't be determined
automatically.

"""
org_name, teamspace_name, model_name = _parse_name(name)
teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
return teamspace.upload_model(
path=path,
name=model_name,
progress_bar=progress_bar,
cluster_id=cluster_id,
)


def download_model(
name: str,
download_dir: Optional[str] = None,
progress_bar: bool = True,
) -> str:
"""Download a checkpoint from the model store.

Args:
name: Name tag of the model to download. Must be in the format 'organization/teamspace/modelname'
where entity is either your username or the name of an organization you are part of.
download_dir: A path to directory where the model should be downloaded. Defaults
to the current working directory.
progress_bar: Whether to show a progress bar for the download.

Returns:
The absolute path to the downloaded model file or folder.
"""
org_name, teamspace_name, model_name = _parse_name(name)
teamspace = _get_teamspace(name=teamspace_name, organization=org_name)
return teamspace.download_model(
name=model_name,
download_dir=download_dir,
progress_bar=progress_bar,
)
45 changes: 45 additions & 0 deletions tests/test_cloud_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from unittest import mock

import pytest
from litmodels.cloud_io import download_model, upload_model


@pytest.mark.parametrize("name", ["org/model", "model-name", "/too/many/slashes"])
def test_wrong_model_name(name):
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
upload_model(path="path/to/checkpoint", name=name)
with pytest.raises(ValueError, match=r".*organization/teamspace/model.*"):
download_model(name=name)


def test_upload_model(mocker):
# mocking the _get_teamspace to return another mock
ts_mock = mock.MagicMock()
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)

# The lit-logger function is just a wrapper around the SDK function
upload_model(
path="path/to/checkpoint",
name="org-name/teamspace/model-name",
cluster_id="cluster_id",
)
ts_mock.upload_model.assert_called_once_with(
path="path/to/checkpoint",
name="model-name",
cluster_id="cluster_id",
progress_bar=True,
)


def test_download_model(mocker):
# mocking the _get_teamspace to return another mock
ts_mock = mock.MagicMock()
mocker.patch("litmodels.cloud_io._get_teamspace", return_value=ts_mock)
# The lit-logger function is just a wrapper around the SDK function
download_model(
name="org-name/teamspace/model-name",
download_dir="where/to/download",
)
ts_mock.download_model.assert_called_once_with(
name="model-name", download_dir="where/to/download", progress_bar=True
)
16 changes: 0 additions & 16 deletions tests/test_sample_module.py

This file was deleted.

Loading