Skip to content

Commit 8e902df

Browse files
Add support for resolving directories in /teamspace/lightning_storage (Lightning-AI#695)
* feat(resolver): add support for resolving directories in lightning storage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 42cbc98 commit 8e902df

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

src/litdata/streaming/resolver.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
8989
if dir_path_absolute.startswith("/teamspace/gcs_folders") and len(dir_path_absolute.split("/")) > 3:
9090
return _resolve_gcs_folders(dir_path_absolute)
9191

92+
if dir_path_absolute.startswith("/teamspace/lightning_storage") and len(dir_path_absolute.split("/")) > 3:
93+
return _resolve_lightning_storage(dir_path_absolute)
94+
9295
if dir_path_absolute.startswith("/teamspace/datasets") and len(dir_path_absolute.split("/")) > 3:
9396
return _resolve_datasets(dir_path_absolute)
9497

@@ -246,6 +249,28 @@ def _resolve_gcs_folders(dir_path: str) -> Dir:
246249
return Dir(path=dir_path, url=os.path.join(data_connection[0].gcs_folder.source, *dir_path.split("/")[4:]))
247250

248251

252+
def _resolve_lightning_storage(dir_path: str) -> Dir:
253+
from lightning_sdk.lightning_cloud.rest_client import LightningClient
254+
255+
client = LightningClient(max_tries=2)
256+
257+
# Get the ids from env variables
258+
project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)
259+
if project_id is None:
260+
raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.")
261+
262+
target_name = dir_path.split("/")[3]
263+
264+
data_connections = client.data_connection_service_list_data_connections(project_id).data_connections
265+
266+
data_connection = [dc for dc in data_connections if dc.name == target_name]
267+
268+
if not data_connection:
269+
raise ValueError(f"We didn't find any matching data connection with the provided name `{target_name}`.")
270+
271+
return Dir(path=dir_path, url=os.path.join(data_connection[0].r2.source, *dir_path.split("/")[4:]))
272+
273+
249274
def _resolve_datasets(dir_path: str) -> Dir:
250275
from lightning_sdk.lightning_cloud.rest_client import LightningClient
251276

tests/streaming/test_resolver.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,3 +465,45 @@ def test_src_resolver_gcs_folders(monkeypatch, lightning_cloud_mock):
465465
assert resolver._resolve_dir("/teamspace/gcs_folders/debug_folder/a/b/c").url == expected + "/a/b/c"
466466

467467
auth.clear()
468+
469+
470+
@pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported")
471+
def test_src_resolver_lightning_storage(monkeypatch, lightning_cloud_mock):
472+
"""Test lightning_storage resolver with r2 source."""
473+
auth = login.Auth()
474+
auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e")
475+
476+
with pytest.raises(
477+
RuntimeError, match="`LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables."
478+
):
479+
resolver._resolve_dir("/teamspace/lightning_storage/my_dataset")
480+
481+
monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id")
482+
483+
client_mock = mock.MagicMock()
484+
client_mock.data_connection_service_list_data_connections.return_value = V1ListDataConnectionsResponse(
485+
data_connections=[V1DataConnection(name="my_dataset", r2=mock.MagicMock(source="r2://my-r2-bucket"))],
486+
)
487+
488+
client_cls_mock = mock.MagicMock()
489+
client_cls_mock.return_value = client_mock
490+
lightning_cloud_mock.rest_client.LightningClient = client_cls_mock
491+
492+
expected = "r2://my-r2-bucket"
493+
assert resolver._resolve_dir("/teamspace/lightning_storage/my_dataset").url == expected
494+
assert resolver._resolve_dir("/teamspace/lightning_storage/my_dataset/train").url == expected + "/train"
495+
496+
# Test missing data connection
497+
client_mock = mock.MagicMock()
498+
client_mock.data_connection_service_list_data_connections.return_value = V1ListDataConnectionsResponse(
499+
data_connections=[],
500+
)
501+
502+
client_cls_mock = mock.MagicMock()
503+
client_cls_mock.return_value = client_mock
504+
lightning_cloud_mock.rest_client.LightningClient = client_cls_mock
505+
506+
with pytest.raises(ValueError, match="name `my_dataset`"):
507+
resolver._resolve_dir("/teamspace/lightning_storage/my_dataset")
508+
509+
auth.clear()

0 commit comments

Comments
 (0)