Skip to content

Commit b815179

Browse files
committed
Support wildcard path pattern for seed dataset
1 parent eb6614c commit b815179

File tree

3 files changed

+86
-4
lines changed

3 files changed

+86
-4
lines changed

src/data_designer/config/seed.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010

1111
from .base import ConfigBase
1212
from .datastore import DatastoreSettings
13-
from .utils.io_helpers import validate_dataset_file_path
13+
from .utils.io_helpers import (
14+
VALID_DATASET_FILE_EXTENSIONS,
15+
validate_dataset_file_path,
16+
validate_path_contains_files_of_type,
17+
)
1418

1519

1620
class SamplingStrategy(str, Enum):
@@ -130,4 +134,12 @@ def filename(self) -> str:
130134
class LocalSeedDatasetReference(SeedDatasetReference):
    """Seed dataset reference whose ``dataset`` is validated against the local filesystem."""

    @field_validator("dataset", mode="after")
    def validate_dataset_is_file(cls, v: str) -> str:
        """Validate that ``dataset`` is a dataset file or a ``*.<ext>`` wildcard with matches.

        Args:
            v: Either a path to a single dataset file, or a path ending in
                ``*<ext>`` for one of ``VALID_DATASET_FILE_EXTENSIONS``.

        Returns:
            The value unchanged, so a wildcard pattern is preserved as given.

        Raises:
            InvalidFilePathError: If the file does not exist, or the wildcard's
                directory holds no entries with the requested extension.
            InvalidFileFormatError: If a non-wildcard path has an unsupported extension.
        """
        wildcard_suffixes = tuple(f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS)
        # Split on the LAST "*." so the directory and the extension always come
        # from the same wildcard, even if "*." also appears earlier in the path.
        directory, separator, file_extension = v.rpartition("*.")
        if separator and v.endswith(wildcard_suffixes):
            validate_path_contains_files_of_type(directory, file_extension)
        else:
            validate_dataset_file_path(v)
        return v

src/data_designer/config/utils/io_helpers.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,11 @@ def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True)
6868
Args:
6969
file_path: The path to validate, either as a string or Path object.
7070
should_exist: If True, verify that the file exists. Defaults to True.
71-
7271
Returns:
7372
The validated path as a Path object.
73+
Raises:
74+
InvalidFilePathError: If the path is not a file.
75+
InvalidFileFormatError: If the path does not have a valid extension.
7476
"""
7577
file_path = Path(file_path)
7678
if should_exist and not Path(file_path).is_file():
@@ -82,6 +84,21 @@ def validate_dataset_file_path(file_path: str | Path, should_exist: bool = True)
8284
return file_path
8385

8486

87+
def validate_path_contains_files_of_type(path: str | Path, file_extension: str) -> None:
88+
"""Validate that a path contains files of a specific type.
89+
90+
Args:
91+
path: The path to validate. Can contain wildcards like `*.parquet`.
92+
file_extension: The extension of the files to validate (without the dot, e.g., "parquet").
93+
Returns:
94+
None if the path contains files of the specified type, raises an error otherwise.
95+
Raises:
96+
InvalidFilePathError: If the path does not contain files of the specified type.
97+
"""
98+
if not any(Path(path).glob(f"*.{file_extension}")):
99+
raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.")
100+
101+
85102
def smart_load_dataframe(dataframe: str | Path | pd.DataFrame) -> pd.DataFrame:
86103
"""Load a dataframe from file if a path is given, otherwise return the dataframe.
87104

tests/config/test_seed.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,30 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from pathlib import Path
5+
import tempfile
6+
7+
import pandas as pd
48
import pytest
59

6-
from data_designer.config.seed import IndexRange, PartitionBlock
10+
from data_designer.config.errors import InvalidFilePathError
11+
from data_designer.config.seed import IndexRange, LocalSeedDatasetReference, PartitionBlock
12+
13+
14+
def create_partitions_in_path(temp_dir: Path, extension: str, num_files: int = 2) -> Path:
    """Write ``num_files`` small partition files named ``partition_<i>.<extension>``.

    Supported extensions are "parquet", "csv", "json" and "jsonl"; any other
    extension writes nothing. Returns ``temp_dir`` unchanged.
    """
    frame = pd.DataFrame({"col": [1, 2, 3]})

    # One writer per supported extension; each takes the frame and the target path.
    writers = {
        "parquet": lambda data, target: data.to_parquet(target),
        "csv": lambda data, target: data.to_csv(target, index=False),
        "json": lambda data, target: data.to_json(target, orient="records"),
        "jsonl": lambda data, target: data.to_json(target, orient="records", lines=True),
    }
    write = writers.get(extension)

    for index in range(num_files):
        target = temp_dir / f"partition_{index}.{extension}"
        if write is not None:
            write(frame, target)
    return temp_dir
728

829

930
def test_index_range_validation():
@@ -54,3 +75,35 @@ def test_partition_block_to_index_range():
5475
assert index_range.start == 90
5576
assert index_range.end == 104
5677
assert index_range.size == 15
78+
79+
80+
def test_local_seed_dataset_reference_validation():
    """Check single-file and wildcard validation of ``LocalSeedDatasetReference.dataset``."""
    with pytest.raises(InvalidFilePathError, match="🛑 Path test/dataset.parquet is not a file."):
        LocalSeedDatasetReference(dataset="test/dataset.parquet")

    supported_extensions = ("parquet", "csv", "json", "jsonl")

    # Should not raise an error when referencing supported extensions with wildcard pattern.
    # No try/except wrapper: an unexpected exception fails the test with its full traceback.
    with tempfile.TemporaryDirectory() as temp_dir:
        for extension in supported_extensions:
            create_partitions_in_path(Path(temp_dir), extension)

        for extension in supported_extensions:
            dataset = f"{temp_dir}/*.{extension}"
            reference = LocalSeedDatasetReference(dataset=dataset)
            assert reference.dataset == dataset

    # Should raise an error when referencing a path that does not contain files of the specified type.
    with tempfile.TemporaryDirectory() as temp_dir:
        create_partitions_in_path(Path(temp_dir), "parquet")
        with pytest.raises(InvalidFilePathError, match="does not contain files of type 'csv'"):
            LocalSeedDatasetReference(dataset=f"{temp_dir}/*.csv")

0 commit comments

Comments
 (0)