Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# NOTE: once we add more dependencies, consider update dependabot to check for updates

lightning-sdk >=0.1.35
lightning-sdk >=0.1.40
lightning-utilities
8 changes: 5 additions & 3 deletions src/litmodels/io/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# http://www.apache.org/licenses/LICENSE-2.0
#
from typing import Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Tuple, Union

from lightning_sdk.api.teamspace_api import UploadedModelInfo
from lightning_sdk.lightning_cloud.env import LIGHTNING_CLOUD_URL
from lightning_sdk.teamspace import Teamspace
from lightning_sdk.utils import resolve as sdk_resolvers

if TYPE_CHECKING:
from lightning_sdk.models import UploadedModelInfo

# if module_available("lightning"):
# from lightning import LightningModule
# elif module_available("pytorch_lightning"):
Expand Down Expand Up @@ -81,7 +83,7 @@ def upload_model_files(
progress_bar: bool = True,
cluster_id: Optional[str] = None,
verbose: Union[bool, int] = 1,
) -> UploadedModelInfo:
) -> "UploadedModelInfo":
"""Upload a local checkpoint file to the model store.

Args:
Expand Down
8 changes: 5 additions & 3 deletions src/litmodels/io/gateway.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os
import tempfile
from pathlib import Path
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

from lightning_sdk.api.teamspace_api import UploadedModelInfo
from lightning_utilities import module_available

from litmodels.io.cloud import download_model_files, upload_model_files
Expand All @@ -14,6 +13,9 @@
else:
torch = None

if TYPE_CHECKING:
from lightning_sdk.models import UploadedModelInfo


def upload_model(
name: str,
Expand All @@ -22,7 +24,7 @@ def upload_model(
cluster_id: Optional[str] = None,
staging_dir: Optional[str] = None,
verbose: Union[bool, int] = 1,
) -> UploadedModelInfo:
) -> "UploadedModelInfo":
"""Upload a checkpoint to the model store.

Args:
Expand Down
Loading