Commit 690b00f

fix check during column name fetch
1 parent 4e6281a commit 690b00f


3 files changed: +59 -37 lines changed


src/data_designer/config/datastore.py

Lines changed: 18 additions & 4 deletions
@@ -32,7 +32,15 @@ class DatastoreSettings(BaseModel):
 
 
 def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
-    """Extract column names based on file type."""
+    """Extract column names based on file type. Supports glob patterns like '../path/*.parquet'."""
+    file_path = Path(file_path)
+    if "*" in str(file_path):
+        matching_files = sorted(file_path.parent.glob(file_path.name))
+        if not matching_files:
+            raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+        logger.info(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+        file_path = matching_files[0]
+
     if file_type == "parquet":
         try:
             schema = pq.read_schema(file_path)
@@ -123,11 +131,17 @@ def _fetch_seed_dataset_column_names_from_datastore(
 
 
 def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
-    dataset_path = _validate_dataset_path(dataset_path)
-    return get_file_column_names(dataset_path, dataset_path.suffix.lower()[1:])
+    dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
+    return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])
 
 
-def _validate_dataset_path(dataset_path: Union[str, Path]) -> Path:
+def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
+    if allow_glob_pattern and "*" in str(dataset_path):
+        valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS}
+        if not any(dataset_path.endswith(wildcard) for wildcard in valid_wild_card_versions):
+            file_extension = dataset_path.split("*.")[-1]
+            raise InvalidFilePathError(f"🛑 Path {dataset_path!r} does not contain files of type {file_extension!r}.")
+        return Path(dataset_path)
     if not Path(dataset_path).is_file():
         raise InvalidFilePathError("🛑 To upload a dataset to the datastore, you must provide a valid file path.")
     if not Path(dataset_path).name.endswith(tuple(VALID_DATASET_FILE_EXTENSIONS)):
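
For orientation, a minimal usage sketch of the new glob support in get_file_column_names, assuming the module is importable as data_designer.config.datastore (inferred from the file path above) and using a throwaway seed_data/ directory created only for this example:

    from pathlib import Path

    import pandas as pd

    # Import path is an assumption based on src/data_designer/config/datastore.py.
    from data_designer.config.datastore import get_file_column_names

    # Hypothetical scratch directory with a few parquet partitions.
    seed_dir = Path("seed_data")
    seed_dir.mkdir(exist_ok=True)
    for i in range(3):
        pd.DataFrame({"id": [i], "text": ["example"]}).to_parquet(seed_dir / f"part-{i}.parquet")

    # With this commit, a glob pattern is accepted: matching files are sorted and the
    # first one is used to read the schema and return the column names.
    print(get_file_column_names(f"{seed_dir}/*.parquet", "parquet"))  # ['id', 'text']

A pattern that matches no files raises InvalidFilePathError with "No files found matching pattern", which the new tests below exercise.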

tests/config/test_datastore.py

Lines changed: 22 additions & 6 deletions
@@ -36,7 +36,7 @@ def _write_file(df, path, file_type):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_basic_parquet(tmp_path, file_type):
-    """Test _get_file_column_names with basic parquet file."""
+    """Test get_file_column_names with basic parquet file."""
     test_data = {
         "id": [1, 2, 3],
         "name": ["Alice", "Bob", "Charlie"],
@@ -51,7 +51,7 @@ def test_get_file_column_names_basic_parquet(tmp_path, file_type):
 
 
 def test_get_file_column_names_nested_fields(tmp_path):
-    """Test _get_file_column_names with nested fields in parquet."""
+    """Test get_file_column_names with nested fields in parquet."""
     schema = pa.schema(
         [
             pa.field(
@@ -72,7 +72,7 @@ def test_get_file_column_names_nested_fields(tmp_path):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_empty_parquet(tmp_path, file_type):
-    """Test _get_file_column_names with empty parquet file."""
+    """Test get_file_column_names with empty parquet file."""
     empty_df = pd.DataFrame()
     empty_path = tmp_path / f"empty.{file_type}"
     _write_file(empty_df, empty_path, file_type)
@@ -83,7 +83,7 @@ def test_get_file_column_names_empty_parquet(tmp_path, file_type):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_large_schema(tmp_path, file_type):
-    """Test _get_file_column_names with many columns."""
+    """Test get_file_column_names with many columns."""
     num_columns = 50
     test_data = {f"col_{i}": np.random.randn(10) for i in range(num_columns)}
     df = pd.DataFrame(test_data)
@@ -98,7 +98,7 @@ def test_get_file_column_names_large_schema(tmp_path, file_type):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_special_characters(tmp_path, file_type):
-    """Test _get_file_column_names with special characters in column names."""
+    """Test get_file_column_names with special characters in column names."""
     special_data = {
         "column with spaces": [1],
         "column-with-dashes": [2],
@@ -117,7 +117,7 @@ def test_get_file_column_names_special_characters(tmp_path, file_type):
 
 @pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
 def test_get_file_column_names_unicode(tmp_path, file_type):
-    """Test _get_file_column_names with unicode column names."""
+    """Test get_file_column_names with unicode column names."""
     unicode_data = {"café": [1], "résumé": [2], "naïve": [3], "façade": [4], "garçon": [5], "über": [6], "schön": [7]}
     df_unicode = pd.DataFrame(unicode_data)
 
@@ -126,6 +126,22 @@ def test_get_file_column_names_unicode(tmp_path, file_type):
     assert get_file_column_names(str(unicode_path), file_type) == df_unicode.columns.tolist()
 
 
+@pytest.mark.parametrize("file_type", ["parquet", "csv", "json", "jsonl"])
+def test_get_file_column_names_with_glob_pattern(tmp_path, file_type):
+    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
+    for i in range(5):
+        _write_file(df, tmp_path / f"{i}.{file_type}", file_type)
+    assert get_file_column_names(f"{tmp_path}/*.{file_type}", file_type) == ["col1", "col2"]
+
+
+def test_get_file_column_names_with_glob_pattern_error(tmp_path):
+    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
+    for i in range(5):
+        _write_file(df, tmp_path / f"{i}.parquet", "parquet")
+    with pytest.raises(InvalidFilePathError, match="No files found matching pattern"):
+        get_file_column_names(f"{tmp_path}/*.csv", "csv")
+
+
 def test_get_file_column_names_error_handling():
     with pytest.raises(InvalidFilePathError, match="🛑 Unsupported file type: 'txt'"):
         get_file_column_names("test.txt", "txt")

tests/config/test_seed.py

Lines changed: 19 additions & 27 deletions
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from pathlib import Path
-import tempfile
 
 import pandas as pd
 import pytest
@@ -77,33 +76,26 @@ def test_partition_block_to_index_range():
     assert index_range.size == 15
 
 
-def test_local_seed_dataset_reference_validation():
+def test_local_seed_dataset_reference_validation(tmp_path: Path):
     with pytest.raises(InvalidFilePathError, match="🛑 Path test/dataset.parquet is not a file."):
         LocalSeedDatasetReference(dataset="test/dataset.parquet")
 
     # Should not raise an error when referencing supported extensions with wildcard pattern.
-    with tempfile.TemporaryDirectory() as temp_dir:
-        create_partitions_in_path(Path(temp_dir), "parquet")
-        create_partitions_in_path(Path(temp_dir), "csv")
-        create_partitions_in_path(Path(temp_dir), "json")
-        create_partitions_in_path(Path(temp_dir), "jsonl")
-
-        test_cases = [
-            (temp_dir, "parquet"),
-            (temp_dir, "csv"),
-            (temp_dir, "json"),
-            (temp_dir, "jsonl"),
-        ]
-
-        try:
-            for temp_dir, extension in test_cases:
-                reference = LocalSeedDatasetReference(dataset=f"{temp_dir}/*.{extension}")
-                assert reference.dataset == f"{temp_dir}/*.{extension}"
-        except Exception as e:
-            pytest.fail(f"Expected no exception, but got {e}")
-
-    # Should raise an error when referencing a path that does not contain files of the specified type.
-    with tempfile.TemporaryDirectory() as temp_dir:
-        create_partitions_in_path(Path(temp_dir), "parquet")
-        with pytest.raises(InvalidFilePathError, match="does not contain files of type 'csv'"):
-            LocalSeedDatasetReference(dataset=f"{temp_dir}/*.csv")
+    create_partitions_in_path(tmp_path, "parquet")
+    create_partitions_in_path(tmp_path, "csv")
+    create_partitions_in_path(tmp_path, "json")
+    create_partitions_in_path(tmp_path, "jsonl")
+
+    test_cases = ["parquet", "csv", "json", "jsonl"]
+    try:
+        for extension in test_cases:
+            reference = LocalSeedDatasetReference(dataset=f"{tmp_path}/*.{extension}")
+            assert reference.dataset == f"{tmp_path}/*.{extension}"
+    except Exception as e:
+        pytest.fail(f"Expected no exception, but got {e}")
+
+
+def test_local_seed_dataset_reference_validation_error(tmp_path: Path):
+    create_partitions_in_path(tmp_path, "parquet")
+    with pytest.raises(InvalidFilePathError, match="does not contain files of type 'csv'"):
+        LocalSeedDatasetReference(dataset=f"{tmp_path}/*.csv")
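
Both tests lean on the create_partitions_in_path helper imported by this test module; its implementation is not part of the diff. A minimal stand-in consistent with how it is called here (a directory plus a file extension) could look like:

    from pathlib import Path

    import pandas as pd

    def create_partitions_in_path(path: Path, extension: str, num_partitions: int = 3) -> None:
        # Hypothetical stand-in: write a few small partition files with the given extension
        # so that a wildcard such as f"{path}/*.{extension}" matches at least one file.
        df = pd.DataFrame({"col1": [1, 2, 3]})
        for i in range(num_partitions):
            target = path / f"partition-{i}.{extension}"
            if extension == "parquet":
                df.to_parquet(target)
            elif extension == "csv":
                df.to_csv(target, index=False)
            else:  # "json" or "jsonl"
                df.to_json(target, orient="records", lines=(extension == "jsonl"))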
