Merged

Changes from all commits (40 commits)
b429464
add IndexRange and PartitionBlock
nabinchha Nov 3, 2025
0d4b9a0
make check-all-fix
nabinchha Nov 3, 2025
57d500b
add support for IndexRange and PartitionBlock
nabinchha Nov 3, 2025
6bd7760
linting
nabinchha Nov 3, 2025
12301b3
update tests + test e2e
nabinchha Nov 4, 2025
03c048e
run ruff
nabinchha Nov 4, 2025
6da2c3c
license check header
nabinchha Nov 4, 2025
80fca51
partition_index -> index
nabinchha Nov 4, 2025
4e3bd3e
update log message
nabinchha Nov 4, 2025
c32b278
Remove sub subquery alias notneeded
nabinchha Nov 4, 2025
483363d
Optimize duckdb seed dataset select based on on limit and offset
nabinchha Nov 4, 2025
9765643
Add docstring to seedconfig
nabinchha Nov 4, 2025
98993de
add guide
johnnygreco Oct 30, 2025
b4be821
feat req updates
johnnygreco Oct 31, 2025
e9f97d6
git branch pattern update
johnnygreco Oct 31, 2025
5765301
Update CONTRIBUTING.md
johnnygreco Oct 31, 2025
2c38781
agent md blurb
johnnygreco Oct 31, 2025
f405dbd
pr feedback
johnnygreco Oct 31, 2025
13f9527
some rewording
johnnygreco Nov 3, 2025
c84a70b
punctuation
johnnygreco Nov 3, 2025
17657e6
missing quote
johnnygreco Nov 3, 2025
c00953b
add IndexRange and PartitionBlock
nabinchha Nov 3, 2025
8f2d502
make check-all-fix
nabinchha Nov 3, 2025
e854b06
add support for IndexRange and PartitionBlock
nabinchha Nov 3, 2025
ff9202f
linting
nabinchha Nov 3, 2025
371365b
update tests + test e2e
nabinchha Nov 4, 2025
8965096
run ruff
nabinchha Nov 4, 2025
3c03626
license check header
nabinchha Nov 4, 2025
7fa8904
partition_index -> index
nabinchha Nov 4, 2025
a76a21b
update log message
nabinchha Nov 4, 2025
4e0d86c
Remove sub subquery alias notneeded
nabinchha Nov 4, 2025
28567d1
Optimize duckdb seed dataset select based on on limit and offset
nabinchha Nov 4, 2025
eb6614c
Add docstring to seedconfig
nabinchha Nov 4, 2025
b815179
Support wildcard path pattern for seed dataset
nabinchha Nov 4, 2025
6b14c31
Merge branch 'main' into nm/seed-config-partition-strategy
nabinchha Nov 4, 2025
710bd69
Merge branch 'nm/seed-config-partition-strategy' into nabinchha/bug/2…
nabinchha Nov 4, 2025
4e6281a
Merge branch 'main' into nabinchha/bug/2-support-seed-path-with-parti…
nabinchha Nov 4, 2025
690b00f
fix check during column name fetch
nabinchha Nov 5, 2025
75d556c
dump json as jsonl
nabinchha Nov 5, 2025
f684e1c
PR feedback
nabinchha Nov 5, 2025
21 changes: 16 additions & 5 deletions src/data_designer/config/datastore.py
@@ -13,7 +13,7 @@
from pydantic import BaseModel, Field

from .errors import InvalidConfigError, InvalidFileFormatError, InvalidFilePathError
from .utils.io_helpers import VALID_DATASET_FILE_EXTENSIONS
from .utils.io_helpers import VALID_DATASET_FILE_EXTENSIONS, validate_path_contains_files_of_type

if TYPE_CHECKING:
from .seed import SeedDatasetReference
@@ -32,7 +32,15 @@ class DatastoreSettings(BaseModel):


def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
"""Extract column names based on file type."""
"""Extract column names based on file type. Supports glob patterns like '../path/*.parquet'."""
file_path = Path(file_path)
if "*" in str(file_path):
matching_files = sorted(file_path.parent.glob(file_path.name))
if not matching_files:
raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
logger.debug(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
file_path = matching_files[0]

if file_type == "parquet":
try:
schema = pq.read_schema(file_path)
@@ -123,11 +131,14 @@ def _fetch_seed_dataset_column_names_from_datastore(


def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
dataset_path = _validate_dataset_path(dataset_path)
return get_file_column_names(dataset_path, dataset_path.suffix.lower()[1:])
dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])


def _validate_dataset_path(dataset_path: Union[str, Path]) -> Path:
def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
if allow_glob_pattern and "*" in str(dataset_path):
validate_path_contains_files_of_type(dataset_path, str(dataset_path).split(".")[-1])
return Path(dataset_path)
if not Path(dataset_path).is_file():
raise InvalidFilePathError("🛑 To upload a dataset to the datastore, you must provide a valid file path.")
if not Path(dataset_path).name.endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)):
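
For context, a minimal usage sketch of the new glob handling in get_file_column_names (the directory and file names below are illustrative, not part of the diff): when the path contains a wildcard, the first matching file is used to read the schema.

from pathlib import Path

import pandas as pd

from data_designer.config.datastore import get_file_column_names

# Hypothetical partitioned seed dataset: seed_data/part-0.parquet, part-1.parquet, ...
data_dir = Path("seed_data")
data_dir.mkdir(exist_ok=True)
for i in range(3):
    pd.DataFrame({"id": [i], "name": ["example"]}).to_parquet(data_dir / f"part-{i}.parquet")

# The wildcard is resolved with glob; the first matching file supplies the column names.
print(get_file_column_names(f"{data_dir}/*.parquet", "parquet"))  # ['id', 'name']
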
16 changes: 14 additions & 2 deletions src/data_designer/config/seed.py
@@ -10,7 +10,11 @@

from .base import ConfigBase
from .datastore import DatastoreSettings
from .utils.io_helpers import validate_dataset_file_path
from .utils.io_helpers import (
VALID_DATASET_FILE_EXTENSIONS,
validate_dataset_file_path,
validate_path_contains_files_of_type,
)


class SamplingStrategy(str, Enum):
@@ -130,4 +134,12 @@ def filename(self) -> str:
class LocalSeedDatasetReference(SeedDatasetReference):
@field_validator("dataset", mode="after")
def validate_dataset_is_file(cls, v: str) -> str:
return str(validate_dataset_file_path(v))
valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS}
if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions):
parts = v.split("*.")
file_path = parts[0]
file_extension = parts[-1]
validate_path_contains_files_of_type(file_path, file_extension)
else:
validate_dataset_file_path(v)
return v
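
A short sketch of how the wildcard validator above behaves, mirroring the tests added later in this PR (paths are illustrative):

from pathlib import Path

import pandas as pd

from data_designer.config.seed import LocalSeedDatasetReference

seed_dir = Path("seed_partitions")
seed_dir.mkdir(exist_ok=True)
for i in range(2):
    pd.DataFrame({"col": [1, 2, 3]}).to_parquet(seed_dir / f"partition_{i}.parquet")

# A wildcard over a supported extension passes validation as long as matching files exist.
ref = LocalSeedDatasetReference(dataset=f"{seed_dir}/*.parquet")

# A wildcard with no matching files raises InvalidFilePathError, e.g.:
# LocalSeedDatasetReference(dataset=f"{seed_dir}/*.csv")
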
19 changes: 18 additions & 1 deletion src/data_designer/config/utils/io_helpers.py
@@ -69,9 +69,11 @@ def validate_dataset_file_path(file_path: Union[str, Path], should_exist: bool =
Args:
file_path: The path to validate, either as a string or Path object.
should_exist: If True, verify that the file exists. Defaults to True.

Returns:
The validated path as a Path object.
Raises:
InvalidFilePathError: If the path is not a file.
InvalidFileFormatError: If the path does not have a valid extension.
"""
file_path = Path(file_path)
if should_exist and not Path(file_path).is_file():
@@ -83,6 +85,21 @@
return file_path


def validate_path_contains_files_of_type(path: str | Path, file_extension: str) -> None:
"""Validate that a path contains files of a specific type.

Args:
path: The path to validate. Can contain wildcards like `*.parquet`.
file_extension: The extension of the files to validate (without the dot, e.g., "parquet").
Returns:
None if the path contains files of the specified type, raises an error otherwise.
Raises:
InvalidFilePathError: If the path does not contain files of the specified type.
"""
if not any(Path(path).glob(f"*.{file_extension}")):
raise InvalidFilePathError(f"🛑 Path {path!r} does not contain files of type {file_extension!r}.")


def smart_load_dataframe(dataframe: Union[str, Path, pd.DataFrame]) -> pd.DataFrame:
"""Load a dataframe from file if a path is given, otherwise return the dataframe.

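
A brief sketch of the new helper in isolation (the directory name is an assumption for illustration):

from pathlib import Path

import pandas as pd

from data_designer.config.errors import InvalidFilePathError
from data_designer.config.utils.io_helpers import validate_path_contains_files_of_type

data_dir = Path("partitions")
data_dir.mkdir(exist_ok=True)
pd.DataFrame({"col": [1]}).to_csv(data_dir / "part_0.csv", index=False)

validate_path_contains_files_of_type(data_dir, "csv")  # passes silently

try:
    validate_path_contains_files_of_type(data_dir, "parquet")
except InvalidFilePathError as err:
    print(err)  # 🛑 Path ... does not contain files of type 'parquet'.
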
28 changes: 22 additions & 6 deletions tests/config/test_datastore.py
@@ -36,7 +36,7 @@ def _write_file(df, path, file_type):

@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
def test_get_file_column_names_basic_parquet(tmp_path, file_type):
"""Test _get_file_column_names with basic parquet file."""
"""Test get_file_column_names with basic parquet file."""
test_data = {
"id": [1, 2, 3],
"name": ["Alice", "Bob", "Charlie"],
@@ -51,7 +51,7 @@


def test_get_file_column_names_nested_fields(tmp_path):
"""Test _get_file_column_names with nested fields in parquet."""
"""Test get_file_column_names with nested fields in parquet."""
schema = pa.schema(
[
pa.field(
@@ -72,7 +72,7 @@

@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
def test_get_file_column_names_empty_parquet(tmp_path, file_type):
"""Test _get_file_column_names with empty parquet file."""
"""Test get_file_column_names with empty parquet file."""
empty_df = pd.DataFrame()
empty_path = tmp_path / f"empty.{file_type}"
_write_file(empty_df, empty_path, file_type)
@@ -83,7 +83,7 @@

@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
def test_get_file_column_names_large_schema(tmp_path, file_type):
"""Test _get_file_column_names with many columns."""
"""Test get_file_column_names with many columns."""
num_columns = 50
test_data = {f"col_{i}": np.random.randn(10) for i in range(num_columns)}
df = pd.DataFrame(test_data)
@@ -98,7 +98,7 @@

@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
def test_get_file_column_names_special_characters(tmp_path, file_type):
"""Test _get_file_column_names with special characters in column names."""
"""Test get_file_column_names with special characters in column names."""
special_data = {
"column with spaces": [1],
"column-with-dashes": [2],
@@ -117,7 +117,7 @@

@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
def test_get_file_column_names_unicode(tmp_path, file_type):
"""Test _get_file_column_names with unicode column names."""
"""Test get_file_column_names with unicode column names."""
unicode_data = {"café": [1], "résumé": [2], "naïve": [3], "façade": [4], "garçon": [5], "über": [6], "schön": [7]}
df_unicode = pd.DataFrame(unicode_data)

@@ -126,6 +126,22 @@
assert get_file_column_names(str(unicode_path), file_type) == df_unicode.columns.tolist()


@pytest.mark.parametrize("file_type", ["parquet", "csv", "json", "jsonl"])
def test_get_file_column_names_with_glob_pattern(tmp_path, file_type):
df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
for i in range(5):
_write_file(df, tmp_path / f"{i}.{file_type}", file_type)
assert get_file_column_names(f"{tmp_path}/*.{file_type}", file_type) == ["col1", "col2"]


def test_get_file_column_names_with_glob_pattern_error(tmp_path):
df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
for i in range(5):
_write_file(df, tmp_path / f"{i}.parquet", "parquet")
with pytest.raises(InvalidFilePathError, match="No files found matching pattern"):
get_file_column_names(f"{tmp_path}/*.csv", "csv")


def test_get_file_column_names_error_handling():
with pytest.raises(InvalidFilePathError, match="🛑 Unsupported file type: 'txt'"):
get_file_column_names("test.txt", "txt")
47 changes: 46 additions & 1 deletion tests/config/test_seed.py
@@ -1,9 +1,29 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

import pandas as pd
import pytest

from data_designer.config.seed import IndexRange, PartitionBlock
from data_designer.config.errors import InvalidFilePathError
from data_designer.config.seed import IndexRange, LocalSeedDatasetReference, PartitionBlock


def create_partitions_in_path(temp_dir: Path, extension: str, num_files: int = 2) -> Path:
df = pd.DataFrame({"col": [1, 2, 3]})

for i in range(num_files):
file_path = temp_dir / f"partition_{i}.{extension}"
if extension == "parquet":
df.to_parquet(file_path)
elif extension == "csv":
df.to_csv(file_path, index=False)
elif extension == "json":
df.to_json(file_path, orient="records", lines=True)
elif extension == "jsonl":
df.to_json(file_path, orient="records", lines=True)
return temp_dir


def test_index_range_validation():
@@ -54,3 +74,28 @@ def test_partition_block_to_index_range():
assert index_range.start == 90
assert index_range.end == 104
assert index_range.size == 15


def test_local_seed_dataset_reference_validation(tmp_path: Path):
with pytest.raises(InvalidFilePathError, match="🛑 Path test/dataset.parquet is not a file."):
LocalSeedDatasetReference(dataset="test/dataset.parquet")

# Should not raise an error when referencing supported extensions with wildcard pattern.
create_partitions_in_path(tmp_path, "parquet")
create_partitions_in_path(tmp_path, "csv")
create_partitions_in_path(tmp_path, "json")
create_partitions_in_path(tmp_path, "jsonl")

test_cases = ["parquet", "csv", "json", "jsonl"]
try:
for extension in test_cases:
reference = LocalSeedDatasetReference(dataset=f"{tmp_path}/*.{extension}")
assert reference.dataset == f"{tmp_path}/*.{extension}"
except Exception as e:
pytest.fail(f"Expected no exception, but got {e}")


def test_local_seed_dataset_reference_validation_error(tmp_path: Path):
create_partitions_in_path(tmp_path, "parquet")
with pytest.raises(InvalidFilePathError, match="does not contain files of type 'csv'"):
LocalSeedDatasetReference(dataset=f"{tmp_path}/*.csv")