Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,10 +681,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
for src in sources: # Opt: parallel
listing: Optional[Listing]
if src.startswith("ds://"):
ds_name, ds_version = parse_dataset_uri(src)
ds_namespace, ds_project, ds_name = parse_dataset_name(ds_name)
assert ds_namespace
assert ds_project
(ds_namespace, ds_project, ds_name, ds_version) = parse_dataset_uri(src)
dataset = self.get_dataset(
ds_name, namespace_name=ds_namespace, project_name=ds_project
)
Expand Down Expand Up @@ -1515,13 +1512,12 @@ def _instantiate(ds_uri: str) -> None:
studio_client = StudioClient()

try:
remote_ds_name, version = parse_dataset_uri(remote_ds_uri)
(remote_namespace, remote_project, remote_ds_name, version) = (
parse_dataset_uri(remote_ds_uri)
)
except Exception as e:
raise DataChainError("Error when parsing dataset uri") from e
Comment on lines 1514 to 1519
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Exception handling could be more specific than a generic Exception.

Catching Exception may hide unrelated errors. Catch ValueError or a custom exception from parse_dataset_uri instead.

Suggested change
try:
remote_ds_name, version = parse_dataset_uri(remote_ds_uri)
(remote_namespace, remote_project, remote_ds_name, version) = (
parse_dataset_uri(remote_ds_uri)
)
except Exception as e:
raise DataChainError("Error when parsing dataset uri") from e
try:
(remote_namespace, remote_project, remote_ds_name, version) = (
parse_dataset_uri(remote_ds_uri)
)
except ValueError as e:
raise DataChainError("Error when parsing dataset uri") from e


remote_namespace, remote_project, remote_ds_name = parse_dataset_name(
remote_ds_name
)
if not remote_namespace or not remote_project:
raise DataChainError(
f"Invalid fully qualified dataset name {remote_ds_name}, namespace"
Expand Down
45 changes: 27 additions & 18 deletions src/datachain/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import builtins
import json
import re
from dataclasses import dataclass, fields
from datetime import datetime
from functools import cached_property
Expand All @@ -10,7 +11,6 @@
TypeVar,
Union,
)
from urllib.parse import urlparse

from packaging.specifiers import SpecifierSet
from packaging.version import Version
Expand Down Expand Up @@ -43,25 +43,34 @@
StorageURI = NewType("StorageURI", str)


def parse_dataset_uri(uri: str) -> tuple[str, Optional[str]]:
def parse_dataset_uri(
uri: str,
) -> tuple[Optional[str], Optional[str], str, Optional[str]]:
"""
Parse dataser uri to extract name and version out of it (if version is defined)
Example:
Input: ds://[email protected]
Output: (zalando, 3.0.1)
Parse a dataset URI of the form:

ds://[<namespace>.][<project>.]<name>[@v<semver>]

Returns:
(namespace, project, name, version)
"""
p = urlparse(uri)
if p.scheme != "ds":
raise Exception("Dataset uri should start with ds://")
s = p.netloc.split("@v")
name = s[0]
if len(s) == 1:
return name, None
if len(s) != 2:
raise Exception(
"Wrong dataset uri format, it should be: ds://<name>@v<version>"
)
return name, s[1]

if not uri.startswith("ds://"):
raise ValueError(f"Invalid dataset URI: {uri}")

body = uri[len("ds://") :]

# Split off optional @v<version>
match = re.match(r"^(?P<name>.+?)(?:@v(?P<version>\d+\.\d+\.\d+))?$", body)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still have quite a lot of cases where this regexp will not works (and we have no tests for all these cases), I would prefer to move parsing dataset name, version, namespace, project, etc to separate module and cover it all with tests, so we can reuse it all over the code, but same time it looks like a separate task to refactor this part and it is good enough for now for me.

if not match:
raise ValueError(f"Invalid dataset URI format: {uri}")

dataset_name = match.group("name")
version = match.group("version")

namespace, project, name = parse_dataset_name(dataset_name)

return namespace, project, name, version


def create_dataset_uri(
Expand Down
4 changes: 2 additions & 2 deletions tests/func/test_pull.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ def test_pull_dataset_wrong_version(

with pytest.raises(DataChainError) as exc_info:
catalog.pull_dataset(
f"ds://{REMOTE_NAMESPACE_NAME}.{REMOTE_PROJECT_NAME}.dogs@v5"
f"ds://{REMOTE_NAMESPACE_NAME}.{REMOTE_PROJECT_NAME}.dogs@v5.0.0"
)
assert str(exc_info.value) == "Dataset dogs doesn't have version 5 on server"
assert str(exc_info.value) == "Dataset dogs doesn't have version 5.0.0 on server"


@pytest.mark.parametrize("cloud_type, version_aware", [("s3", False)], indirect=True)
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DatasetRecord,
DatasetVersion,
parse_dataset_name,
parse_dataset_uri,
)
from datachain.error import InvalidDatasetNameError
from datachain.sql.types import (
Expand Down Expand Up @@ -180,3 +181,24 @@ def test_parse_dataset_name(full_name, namespace, project, name):
def test_parse_dataset_name_empty_name():
with pytest.raises(InvalidDatasetNameError):
assert parse_dataset_name(None)


@pytest.mark.parametrize(
"uri,namespace,project,name,version",
[
("ds://result", None, None, "result", None),
("ds://[email protected]", None, None, "result", "1.0.5"),
("ds://dev.result", None, "dev", "result", None),
("ds://[email protected]", None, "dev", "result", "1.0.5"),
("ds://global.dev.result", "global", "dev", "result", None),
("ds://[email protected]", "global", "dev", "result", "1.0.5"),
("ds://@ilongin.dev.result", "@ilongin", "dev", "result", None),
("ds://@[email protected]", "@ilongin", "dev", "result", "1.0.4"),
("ds://@vlad.dev.result", "@vlad", "dev", "result", None),
("ds://@[email protected]", "@vlad", "dev", "result", "1.0.5"),
("ds://@[email protected]@v1.0.5", "@vlad", "@vlad", "result", "1.0.5"),
("ds://@vlad.@vlad.@[email protected]", "@vlad", "@vlad", "@vlad", "1.0.5"),
],
)
def test_parse_dataset_uri(uri, namespace, project, name, version):
assert parse_dataset_uri(uri) == (namespace, project, name, version)
Loading