Skip to content

Commit b9e60b4

Browse files
authored
storage: index as a dir if no glob (#108)
* storage: index as a dir if no glob
* add tests for enlist source on globs, dirs, file paths
1 parent e5be7e6 commit b9e60b4

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

src/datachain/catalog/catalog.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import glob
23
import io
34
import json
45
import logging
@@ -709,7 +710,12 @@ def enlist_source(
709710

710711
client_config = client_config or self.client_config
711712
client, path = self.parse_url(source, **client_config)
712-
prefix = posixpath.dirname(path)
713+
stem = os.path.basename(os.path.normpath(path))
714+
prefix = (
715+
posixpath.dirname(path)
716+
if glob.has_magic(stem) or client.fs.isfile(source)
717+
else path
718+
)
713719
storage_dataset_name = Storage.dataset_name(
714720
client.uri, posixpath.join(prefix, "")
715721
)

tests/func/test_catalog.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1068,6 +1068,45 @@ def test_storage_stats(cloud_test_catalog):
10681068
assert stats.size == 15
10691069

10701070

1071+
@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
1072+
def test_enlist_source_handles_slash(cloud_test_catalog):
1073+
catalog = cloud_test_catalog.catalog
1074+
src_uri = cloud_test_catalog.src_uri
1075+
1076+
catalog.enlist_source(f"{src_uri}/dogs", ttl=1234)
1077+
stats = catalog.storage_stats(src_uri)
1078+
assert stats.num_objects == len(DEFAULT_TREE["dogs"])
1079+
assert stats.size == 15
1080+
1081+
catalog.enlist_source(f"{src_uri}/dogs/", ttl=1234, force_update=True)
1082+
stats = catalog.storage_stats(src_uri)
1083+
assert stats.num_objects == len(DEFAULT_TREE["dogs"])
1084+
assert stats.size == 15
1085+
1086+
1087+
@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
1088+
def test_enlist_source_handles_glob(cloud_test_catalog):
1089+
catalog = cloud_test_catalog.catalog
1090+
src_uri = cloud_test_catalog.src_uri
1091+
1092+
catalog.enlist_source(f"{src_uri}/dogs/*.jpg", ttl=1234)
1093+
stats = catalog.storage_stats(src_uri)
1094+
1095+
assert stats.num_objects == len(DEFAULT_TREE["dogs"])
1096+
assert stats.size == 15
1097+
1098+
1099+
@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
1100+
def test_enlist_source_handles_file(cloud_test_catalog):
1101+
catalog = cloud_test_catalog.catalog
1102+
src_uri = cloud_test_catalog.src_uri
1103+
1104+
catalog.enlist_source(f"{src_uri}/dogs/dog1", ttl=1234)
1105+
stats = catalog.storage_stats(src_uri)
1106+
assert stats.num_objects == len(DEFAULT_TREE["dogs"])
1107+
assert stats.size == 15
1108+
1109+
10711110
@pytest.mark.parametrize("from_cli", [False, True])
10721111
def test_garbage_collect(cloud_test_catalog, from_cli, capsys):
10731112
catalog = cloud_test_catalog.catalog

tests/unit/lib/test_datachain.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -660,7 +660,7 @@ def test_parse_tabular_partitions(tmp_dir, catalog):
660660

661661
def test_parse_tabular_empty(tmp_dir, catalog):
662662
path = tmp_dir / "test.parquet"
663-
with pytest.raises(DataChainParamsError):
663+
with pytest.raises(FileNotFoundError):
664664
DataChain.from_storage(path.as_uri()).parse_tabular()
665665

666666

0 commit comments

Comments
 (0)