diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py
index 576bc5447..3523fc0cb 100644
--- a/src/datachain/catalog/catalog.py
+++ b/src/datachain/catalog/catalog.py
@@ -681,10 +681,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
         for src in sources:  # Opt: parallel
             listing: Optional[Listing]
             if src.startswith("ds://"):
-                ds_name, ds_version = parse_dataset_uri(src)
-                ds_namespace, ds_project, ds_name = parse_dataset_name(ds_name)
-                assert ds_namespace
-                assert ds_project
+                (ds_namespace, ds_project, ds_name, ds_version) = parse_dataset_uri(src)
                 dataset = self.get_dataset(
                     ds_name, namespace_name=ds_namespace, project_name=ds_project
                 )
@@ -1515,13 +1512,12 @@ def _instantiate(ds_uri: str) -> None:
         studio_client = StudioClient()
 
         try:
-            remote_ds_name, version = parse_dataset_uri(remote_ds_uri)
+            (remote_namespace, remote_project, remote_ds_name, version) = (
+                parse_dataset_uri(remote_ds_uri)
+            )
         except Exception as e:
             raise DataChainError("Error when parsing dataset uri") from e
 
-        remote_namespace, remote_project, remote_ds_name = parse_dataset_name(
-            remote_ds_name
-        )
         if not remote_namespace or not remote_project:
             raise DataChainError(
                 f"Invalid fully qualified dataset name {remote_ds_name}, namespace"
diff --git a/src/datachain/dataset.py b/src/datachain/dataset.py
index 01d733d86..175c470b3 100644
--- a/src/datachain/dataset.py
+++ b/src/datachain/dataset.py
@@ -1,5 +1,6 @@
 import builtins
 import json
+import re
 from dataclasses import dataclass, fields
 from datetime import datetime
 from functools import cached_property
@@ -10,7 +11,6 @@
     TypeVar,
     Union,
 )
-from urllib.parse import urlparse
 
 from packaging.specifiers import SpecifierSet
 from packaging.version import Version
@@ -43,25 +43,34 @@
 StorageURI = NewType("StorageURI", str)
 
 
-def parse_dataset_uri(uri: str) -> tuple[str, Optional[str]]:
+def parse_dataset_uri(
+    uri: str,
+) -> tuple[Optional[str], Optional[str], str, Optional[str]]:
     """
-    Parse dataser uri to extract name and version out of it (if version is defined)
-    Example:
-        Input: ds://zalando@v3.0.1
-        Output: (zalando, 3.0.1)
+    Parse a dataset URI of the form:
+
+        ds://[<namespace>.][<project>.]<name>[@v<version>]
+
+    Returns:
+        (namespace, project, name, version)
     """
-    p = urlparse(uri)
-    if p.scheme != "ds":
-        raise Exception("Dataset uri should start with ds://")
-    s = p.netloc.split("@v")
-    name = s[0]
-    if len(s) == 1:
-        return name, None
-    if len(s) != 2:
-        raise Exception(
-            "Wrong dataset uri format, it should be: ds://<dataset_name>@v<version>"
-        )
-    return name, s[1]
+
+    if not uri.startswith("ds://"):
+        raise ValueError(f"Invalid dataset URI: {uri}")
+
+    body = uri[len("ds://") :]
+
+    # Split off optional @v<version>
+    match = re.match(r"^(?P<name>.+?)(?:@v(?P<version>\d+\.\d+\.\d+))?$", body)
+    if not match:
+        raise ValueError(f"Invalid dataset URI format: {uri}")
+
+    dataset_name = match.group("name")
+    version = match.group("version")
+
+    namespace, project, name = parse_dataset_name(dataset_name)
+
+    return namespace, project, name, version
 
 
 def create_dataset_uri(
diff --git a/tests/func/test_pull.py b/tests/func/test_pull.py
index aec06841e..c63609b7e 100644
--- a/tests/func/test_pull.py
+++ b/tests/func/test_pull.py
@@ -286,9 +286,9 @@ def test_pull_dataset_wrong_version(
 
     with pytest.raises(DataChainError) as exc_info:
         catalog.pull_dataset(
-            f"ds://{REMOTE_NAMESPACE_NAME}.{REMOTE_PROJECT_NAME}.dogs@v5"
+            f"ds://{REMOTE_NAMESPACE_NAME}.{REMOTE_PROJECT_NAME}.dogs@v5.0.0"
         )
-    assert str(exc_info.value) == "Dataset dogs doesn't have version 5 on server"
+    assert str(exc_info.value) == "Dataset dogs doesn't have version 5.0.0 on server"
 
 
 @pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py
index 5e96ee8f4..4c5336974 100644
--- a/tests/unit/test_dataset.py
+++ b/tests/unit/test_dataset.py
@@ -13,6 +13,7 @@
     DatasetRecord,
     DatasetVersion,
     parse_dataset_name,
+    parse_dataset_uri,
 )
 from datachain.error import InvalidDatasetNameError
 from datachain.sql.types import (
@@ -180,3 +181,24 @@ def test_parse_dataset_name(full_name, namespace, project, name):
 def test_parse_dataset_name_empty_name():
     with pytest.raises(InvalidDatasetNameError):
         assert parse_dataset_name(None)
+
+
+@pytest.mark.parametrize(
+    "uri,namespace,project,name,version",
+    [
+        ("ds://result", None, None, "result", None),
+        ("ds://result@v1.0.5", None, None, "result", "1.0.5"),
+        ("ds://dev.result", None, "dev", "result", None),
+        ("ds://dev.result@v1.0.5", None, "dev", "result", "1.0.5"),
+        ("ds://global.dev.result", "global", "dev", "result", None),
+        ("ds://global.dev.result@v1.0.5", "global", "dev", "result", "1.0.5"),
+        ("ds://@ilongin.dev.result", "@ilongin", "dev", "result", None),
+        ("ds://@ilongin.dev.result@v1.0.4", "@ilongin", "dev", "result", "1.0.4"),
+        ("ds://@vlad.dev.result", "@vlad", "dev", "result", None),
+        ("ds://@vlad.dev.result@v1.0.5", "@vlad", "dev", "result", "1.0.5"),
+        ("ds://@vlad.@vlad.result@v1.0.5", "@vlad", "@vlad", "result", "1.0.5"),
+        ("ds://@vlad.@vlad.@vlad@v1.0.5", "@vlad", "@vlad", "@vlad", "1.0.5"),
+    ],
+)
+def test_parse_dataset_uri(uri, namespace, project, name, version):
+    assert parse_dataset_uri(uri) == (namespace, project, name, version)
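For reviewers: a minimal usage sketch of the new parse_dataset_uri contract, drawn from the parametrized cases in tests/unit/test_dataset.py above. The snippet is illustrative only and not part of the patch.

    from datachain.dataset import parse_dataset_uri

    # The parser now returns namespace and project directly, so callers no longer
    # need a follow-up parse_dataset_name() call; missing parts come back as None.
    namespace, project, name, version = parse_dataset_uri("ds://global.dev.result@v1.0.5")
    assert (namespace, project, name, version) == ("global", "dev", "result", "1.0.5")

    # Namespace and project are optional, and the @v<version> suffix may be omitted.
    assert parse_dataset_uri("ds://result") == (None, None, "result", None)
    assert parse_dataset_uri("ds://dev.result") == (None, "dev", "result", None)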