Commit d01e5bf
fix: add support for wildcard pattern in seed dataset path (#12)
* add IndexRange and PartitionBlock
* make check-all-fix
* add support for IndexRange and PartitionBlock
* linting
* update tests + test e2e
* run ruff
* license check header
* partition_index -> index
* update log message
* Remove subquery alias (not needed)
* Optimize duckdb seed dataset select based on limit and offset
* Add docstring to seedconfig
* add guide
* feat req updates
* git branch pattern update
* Update CONTRIBUTING.md (Co-authored-by: Nabin Mulepati <[email protected]>)
* agent md blurb
* pr feedback
* some rewording
* punctuation
* missing quote
* Support wildcard path pattern for seed dataset
* fix check during column name fetch
* dump json as jsonl
* PR feedback

---------

Co-authored-by: Johnny Greco <[email protected]>
1 parent 7268290 commit d01e5bf

File tree

5 files changed: +116 -15 lines
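For orientation, a minimal usage sketch of what this commit enables, mirroring the new tests in tests/config/test_seed.py below; the data/partitions directory and the pandas setup are illustrative assumptions, not part of the change:

from pathlib import Path

import pandas as pd

from data_designer.config.seed import LocalSeedDatasetReference

# Create a few partitioned parquet files (illustrative local layout).
partitions = Path("data/partitions")
partitions.mkdir(parents=True, exist_ok=True)
for i in range(3):
    pd.DataFrame({"col1": [1, 2, 3]}).to_parquet(partitions / f"part_{i}.parquet")

# Previously the validator required a single existing file; with this change a
# wildcard over a supported extension is accepted and validated instead.
seed = LocalSeedDatasetReference(dataset="data/partitions/*.parquet")
print(seed.dataset)  # data/partitions/*.parquet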

src/data_designer/config/datastore.py

Lines changed: 16 additions & 5 deletions
@@ -13,7 +13,7 @@
 from pydantic import BaseModel, Field
 
 from .errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError
-from .utils.io_helpers import VALID_DATASET_FILE_EXTENSIONS
+from .utils.io_helpers import VALID_DATASET_FILE_EXTENSIONS, validate_path_contains_files_of_type
 
 if TYPE_CHECKING:
     from .seed import SeedDatasetReference
@@ -32,7 +32,15 @@ class DatastoreSettings(BaseModel):
 
 
 def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
-    """Extract column names based on file type."""
+    """Extract column names based on file type. Supports glob patterns like '../path/*.parquet'."""
+    file_path = Path(file_path)
+    if "*" in str(file_path):
+        matching_files = sorted(file_path.parent.glob(file_path.name))
+        if not matching_files:
+            raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+        logger.debug(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+        file_path = matching_files[0]
+
     if file_type == "parquet":
         try:
             schema = pq.read_schema(file_path)
@@ -123,11 +131,14 @@ def _fetch_seed_dataset_column_names_from_datastore(
 
 
 def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
-    dataset_path = _validate_dataset_path(dataset_path)
-    return get_file_column_names(dataset_path, dataset_path.suffix.lower()[1:])
+    dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
+    return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])
 
 
-def _validate_dataset_path(dataset_path: Union[str, Path]) -> Path:
+def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
+    if allow_glob_pattern and "*" in str(dataset_path):
+        validate_path_contains_files_of_type(dataset_path, str(dataset_path).split(".")[-1])
+        return Path(dataset_path)
     if not Path(dataset_path).is_file():
         raise InvalidFilePathError("🛑 To upload a dataset to the datastore, you must provide a valid file path.")
     if not Path(dataset_path).name.endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)):
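A short sketch of how the new glob branch in get_file_column_names resolves a pattern; the parts directory is an assumption, and the parametrized test added to tests/config/test_datastore.py exercises the same path:

from pathlib import Path

import pandas as pd

from data_designer.config.datastore import get_file_column_names

# Write a few partitions; the function globs the pattern's parent directory,
# sorts the matches, and reads the schema from the first one.
Path("parts").mkdir(exist_ok=True)
for i in range(3):
    pd.DataFrame({"id": [1], "name": ["a"]}).to_parquet(f"parts/{i}.parquet")

print(get_file_column_names("parts/*.parquet", "parquet"))  # ['id', 'name']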

src/data_designer/config/seed.py

Lines changed: 14 additions & 2 deletions
@@ -10,7 +10,11 @@
 
 from .base import ConfigBase
 from .datastore import DatastoreSettings
-from .utils.io_helpers import validate_dataset_file_path
+from .utils.io_helpers import (
+    VALID_DATASET_FILE_EXTENSIONS,
+    validate_dataset_file_path,
+    validate_path_contains_files_of_type,
+)
 
 
 class SamplingStrategy(str, Enum):
@@ -130,4 +134,12 @@ def filename(self) -> str:
 class LocalSeedDatasetReference(SeedDatasetReference):
     @field_validator("dataset", mode="after")
     def validate_dataset_is_file(cls, v: str) -> str:
-        return str(validate_dataset_file_path(v))
+        valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS}
+        if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions):
+            parts = v.split("*.")
+            file_path = parts[0]
+            file_extension = parts[-1]
+            validate_path_contains_files_of_type(file_path, file_extension)
+        else:
+            validate_dataset_file_path(v)
+        return v
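For reference, how the validator above splits a wildcard value before delegating to validate_path_contains_files_of_type; the value shown is illustrative:

v = "data/partitions/*.parquet"
parts = v.split("*.")                      # ['data/partitions/', 'parquet']
file_path, file_extension = parts[0], parts[-1]
# validate_path_contains_files_of_type then checks that 'data/partitions/'
# holds at least one '*.parquet' file; otherwise InvalidFilePathError is raised.
print(file_path, file_extension)           # data/partitions/ parquet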

src/data_designer/config/utils/io_helpers.py

Lines changed: 18 additions & 1 deletion
@@ -69,9 +69,11 @@ def validate_dataset_file_path(file_path: Union[str, Path], should_exist: bool =
     Args:
         file_path: The path to validate, either as a string or Path object.
         should_exist: If True, verify that the file exists. Defaults to True.
-
     Returns:
         The validated path as a Path object.
+    Raises:
+        InvalidFilePathError: If the path is not a file.
+        InvalidFileFormatError: If the path does not have a valid extension.
     """
     file_path = Path(file_path)
     if should_exist and not Path(file_path).is_file():
@@ -83,6 +85,21 @@ def validate_dataset_file_path(file_path: Union[str, Path], should_exist: bool =
     return file_path
 
 
+def validate_path_contains_files_of_type(path: str | Path, file_extension: str) -> None:
+    """Validate that a path contains files of a specific type.
+
+    Args:
+        path: The path to validate. Can contain wildcards like `*.parquet`.
+        file_extension: The extension of the files to validate (without the dot, e.g., "parquet").
+    Returns:
+        None if the path contains files of the specified type, raises an error otherwise.
+    Raises:
+        InvalidFilePathError: If the path does not contain files of the specified type.
+    """
+    if not any(Path(path).glob(f"*.{file_extension}")):
+        raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.")
+
+
 def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFrame:
     """Load a dataframe from file if a path is given, otherwise return the dataframe.
 
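A small sketch of the new helper's behavior; the directory and file names here are assumptions:

from pathlib import Path

import pandas as pd

from data_designer.config.errors import InvalidFilePathError
from data_designer.config.utils.io_helpers import validate_path_contains_files_of_type

Path("seed_dir").mkdir(exist_ok=True)
pd.DataFrame({"col": [1]}).to_parquet("seed_dir/part_0.parquet")

validate_path_contains_files_of_type("seed_dir", "parquet")  # no error, returns None
try:
    validate_path_contains_files_of_type("seed_dir", "csv")  # no *.csv files present
except InvalidFilePathError as err:
    print(err)  # 🛑 Path 'seed_dir' does not contain files of type 'csv'.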

tests/config/test_datastore.py

Lines changed: 22 additions & 6 deletions
@@ -36,7 +36,7 @@ def _write_file(df, path, file_type):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_basic_parquet(tmp_path, file_type):
-    """Test _get_file_column_names with basic parquet file."""
+    """Test get_file_column_names with basic parquet file."""
     test_data = {
         "id": [1, 2, 3],
         "name": ["Alice", "Bob", "Charlie"],
@@ -51,7 +51,7 @@ def test_get_file_column_names_basic_parquet(tmp_path, file_type):
 
 
 def test_get_file_column_names_nested_fields(tmp_path):
-    """Test _get_file_column_names with nested fields in parquet."""
+    """Test get_file_column_names with nested fields in parquet."""
     schema = pa.schema(
         [
             pa.field(
@@ -72,7 +72,7 @@ def test_get_file_column_names_nested_fields(tmp_path):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_empty_parquet(tmp_path, file_type):
-    """Test _get_file_column_names with empty parquet file."""
+    """Test get_file_column_names with empty parquet file."""
     empty_df = pd.DataFrame()
     empty_path = tmp_path / f"empty.{file_type}"
     _write_file(empty_df, empty_path, file_type)
@@ -83,7 +83,7 @@ def test_get_file_column_names_empty_parquet(tmp_path, file_type):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_large_schema(tmp_path, file_type):
-    """Test _get_file_column_names with many columns."""
+    """Test get_file_column_names with many columns."""
     num_columns = 50
     test_data = {f"col_{i}": np.random.randn(10) for i in range(num_columns)}
     df = pd.DataFrame(test_data)
@@ -98,7 +98,7 @@ def test_get_file_column_names_large_schema(tmp_path, file_type):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_special_characters(tmp_path, file_type):
-    """Test _get_file_column_names with special characters in column names."""
+    """Test get_file_column_names with special characters in column names."""
     special_data = {
         "column with spaces": [1],
         "column-with-dashes": [2],
@@ -117,7 +117,7 @@ def test_get_file_column_names_special_characters(tmp_path, file_type):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_unicode(tmp_path, file_type):
-    """Test _get_file_column_names with unicode column names."""
+    """Test get_file_column_names with unicode column names."""
     unicode_data = {"café": [1], "résumé": [2], "naïve": [3], "façade": [4], "garçon": [5], "über": [6], "schön": [7]}
     df_unicode = pd.DataFrame(unicode_data)
 
@@ -126,6 +126,22 @@ def test_get_file_column_names_unicode(tmp_path, file_type):
     assert get_file_column_names(str(unicode_path), file_type) == df_unicode.columns.tolist()
 
 
+@pytest.mark.parametrize("file_type", ["parquet", "csv", "json", "jsonl"])
+def test_get_file_column_names_with_glob_pattern(tmp_path, file_type):
+    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
+    for i in range(5):
+        _write_file(df, tmp_path / f"{i}.{file_type}", file_type)
+    assert get_file_column_names(f"{tmp_path}/*.{file_type}", file_type) == ["col1", "col2"]
+
+
+def test_get_file_column_names_with_glob_pattern_error(tmp_path):
+    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
+    for i in range(5):
+        _write_file(df, tmp_path / f"{i}.parquet", "parquet")
+    with pytest.raises(InvalidFilePathError, match="No files found matching pattern"):
+        get_file_column_names(f"{tmp_path}/*.csv", "csv")
+
+
 def test_get_file_column_names_error_handling():
     with pytest.raises(InvalidFilePathError, match="🛑 Unsupported file type: 'txt'"):
         get_file_column_names("test.txt", "txt")

tests/config/test_seed.py

Lines changed: 46 additions & 1 deletion
@@ -1,9 +1,29 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+from pathlib import Path
+
+import pandas as pd
 import pytest
 
-from data_designer.config.seed import IndexRange, PartitionBlock
+from data_designer.config.errors import InvalidFilePathError
+from data_designer.config.seed import IndexRange, LocalSeedDatasetReference, PartitionBlock
+
+
+def create_partitions_in_path(temp_dir: Path, extension: str, num_files: int = 2) -> Path:
+    df = pd.DataFrame({"col": [1, 2, 3]})
+
+    for i in range(num_files):
+        file_path = temp_dir / f"partition_{i}.{extension}"
+        if extension == "parquet":
+            df.to_parquet(file_path)
+        elif extension == "csv":
+            df.to_csv(file_path, index=False)
+        elif extension == "json":
+            df.to_json(file_path, orient="records", lines=True)
+        elif extension == "jsonl":
+            df.to_json(file_path, orient="records", lines=True)
+    return temp_dir
 
 
 def test_index_range_validation():
@@ -54,3 +74,28 @@ def test_partition_block_to_index_range():
     assert index_range.start == 90
     assert index_range.end == 104
     assert index_range.size == 15
+
+
+def test_local_seed_dataset_reference_validation(tmp_path: Path):
+    with pytest.raises(InvalidFilePathError, match="🛑 Path test/dataset.parquet is not a file."):
+        LocalSeedDatasetReference(dataset="test/dataset.parquet")
+
+    # Should not raise an error when referencing supported extensions with wildcard pattern.
+    create_partitions_in_path(tmp_path, "parquet")
+    create_partitions_in_path(tmp_path, "csv")
+    create_partitions_in_path(tmp_path, "json")
+    create_partitions_in_path(tmp_path, "jsonl")
+
+    test_cases = ["parquet", "csv", "json", "jsonl"]
+    try:
+        for extension in test_cases:
+            reference = LocalSeedDatasetReference(dataset=f"{tmp_path}/*.{extension}")
+            assert reference.dataset == f"{tmp_path}/*.{extension}"
+    except Exception as e:
+        pytest.fail(f"Expected no exception, but got {e}")
+
+
+def test_local_seed_dataset_reference_validation_error(tmp_path: Path):
+    create_partitions_in_path(tmp_path, "parquet")
+    with pytest.raises(InvalidFilePathError, match="does not contain files of type 'csv'"):
+        LocalSeedDatasetReference(dataset=f"{tmp_path}/*.csv")
