
Commit 2e9e4ff

fix: update get_file_column_names to take a file reference (#68)
* make get_file_column_names take explicit file reference
* add skip_instance_cache back
* add hf filesystem logic
* update docstring
1 parent 060773c commit 2e9e4ff

File tree

2 files changed: +98, -52 lines changed


src/data_designer/config/datastore.py

Lines changed: 70 additions & 34 deletions
@@ -31,47 +31,74 @@ class DatastoreSettings(BaseModel):
     token: Optional[str] = Field(default=None, description="If needed, token to use for authentication.")
 
 
-def get_file_column_names(file_path: Union[str, Path], file_type: str) -> list[str]:
-    """Extract column names based on file type. Supports glob patterns like '../path/*.parquet'."""
-    file_path = Path(file_path)
-    if "*" in str(file_path):
-        matching_files = sorted(file_path.parent.glob(file_path.name))
-        if not matching_files:
-            raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
-        logger.debug(f"0️⃣ Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
-        file_path = matching_files[0]
+def get_file_column_names(file_reference: Union[str, Path, HfFileSystem], file_type: str) -> list[str]:
+    """Get column names from a dataset file.
+
+    Args:
+        file_reference: Path to the dataset file, or an HfFileSystem object.
+        file_type: Type of the dataset file. Must be one of: 'parquet', 'json', 'jsonl', 'csv'.
 
+    Raises:
+        InvalidFilePathError: If the file type is not supported.
+
+    Returns:
+        List of column names.
+    """
     if file_type == "parquet":
         try:
-            schema = pq.read_schema(file_path)
+            schema = pq.read_schema(file_reference)
             if hasattr(schema, "names"):
                 return schema.names
             else:
                 return [field.name for field in schema]
         except Exception as e:
-            logger.warning(f"Failed to process parquet file {file_path}: {e}")
+            logger.warning(f"Failed to process parquet file {file_reference}: {e}")
             return []
     elif file_type in ["json", "jsonl"]:
-        return pd.read_json(file_path, orient="records", lines=True, nrows=1).columns.tolist()
+        return pd.read_json(file_reference, orient="records", lines=True, nrows=1).columns.tolist()
     elif file_type == "csv":
         try:
-            df = pd.read_csv(file_path, nrows=1)
+            df = pd.read_csv(file_reference, nrows=1)
             return df.columns.tolist()
         except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:
-            logger.warning(f"Failed to process CSV file {file_path}: {e}")
+            logger.warning(f"Failed to process CSV file {file_reference}: {e}")
             return []
     else:
         raise InvalidFilePathError(f"🛑 Unsupported file type: {file_type!r}")
 
 
 def fetch_seed_dataset_column_names(seed_dataset_reference: SeedDatasetReference) -> list[str]:
     if hasattr(seed_dataset_reference, "datastore_settings"):
-        return _fetch_seed_dataset_column_names_from_datastore(
+        return fetch_seed_dataset_column_names_from_datastore(
             seed_dataset_reference.repo_id,
             seed_dataset_reference.filename,
             seed_dataset_reference.datastore_settings,
         )
-    return _fetch_seed_dataset_column_names_from_local_file(seed_dataset_reference.dataset)
+    return fetch_seed_dataset_column_names_from_local_file(seed_dataset_reference.dataset)
+
+
+def fetch_seed_dataset_column_names_from_datastore(
+    repo_id: str,
+    filename: str,
+    datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
+) -> list[str]:
+    file_type = filename.split(".")[-1]
+    if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
+        raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")
+
+    datastore_settings = resolve_datastore_settings(datastore_settings)
+    fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token, skip_instance_cache=True)
+
+    file_path = _extract_single_file_path_from_glob_pattern_if_present(f"datasets/{repo_id}/{filename}", fs=fs)
+
+    with fs.open(file_path) as f:
+        return get_file_column_names(f, file_type)
+
+
+def fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
+    dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
+    dataset_path = _extract_single_file_path_from_glob_pattern_if_present(dataset_path)
+    return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])
 
 
 def resolve_datastore_settings(datastore_settings: DatastoreSettings | dict | None) -> DatastoreSettings:
@@ -114,25 +141,34 @@ def upload_to_hf_hub(
     return f"{repo_id}/{filename}"
 
 
-def _fetch_seed_dataset_column_names_from_datastore(
-    repo_id: str,
-    filename: str,
-    datastore_settings: Optional[Union[DatastoreSettings, dict]] = None,
-) -> list[str]:
-    file_type = filename.split(".")[-1]
-    if f".{file_type}" not in VALID_DATASET_FILE_EXTENSIONS:
-        raise InvalidFileFormatError(f"🛑 Unsupported file type: {filename!r}")
-
-    datastore_settings = resolve_datastore_settings(datastore_settings)
-    fs = HfFileSystem(endpoint=datastore_settings.endpoint, token=datastore_settings.token, skip_instance_cache=True)
-
-    with fs.open(f"datasets/{repo_id}/{filename}") as f:
-        return get_file_column_names(f, file_type)
-
+def _extract_single_file_path_from_glob_pattern_if_present(
+    file_path: str | Path,
+    fs: HfFileSystem | None = None,
+) -> Path:
+    file_path = Path(file_path)
 
-def _fetch_seed_dataset_column_names_from_local_file(dataset_path: str | Path) -> list[str]:
-    dataset_path = _validate_dataset_path(dataset_path, allow_glob_pattern=True)
-    return get_file_column_names(dataset_path, str(dataset_path).split(".")[-1])
+    # no glob pattern
+    if "*" not in str(file_path):
+        return file_path
+
+    # glob pattern with HfFileSystem
+    if fs is not None:
+        file_to_check = None
+        file_extension = file_path.name.split(".")[-1]
+        for file in fs.ls(str(file_path.parent)):
+            filename = file["name"]
+            if filename.endswith(f".{file_extension}"):
+                file_to_check = filename
+        if file_to_check is None:
+            raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+        logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+        return Path(file_to_check)
+
+    # glob pattern with local file system
+    if not (matching_files := sorted(file_path.parent.glob(file_path.name))):
+        raise InvalidFilePathError(f"🛑 No files found matching pattern: {str(file_path)!r}")
+    logger.debug(f"Using the first matching file in {str(file_path)!r} to determine column names in seed dataset")
+    return matching_files[0]
 
 
 def _validate_dataset_path(dataset_path: Union[str, Path], allow_glob_pattern: bool = False) -> Path:
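As a minimal usage sketch (not part of the commit), the two now-public helpers can be called roughly as follows; the repo id, file names, token, and the assumption that DatastoreSettings can be constructed from a token alone are illustrative placeholders:

from data_designer.config.datastore import (
    DatastoreSettings,
    fetch_seed_dataset_column_names_from_datastore,
    fetch_seed_dataset_column_names_from_local_file,
)

# Local file; a glob pattern is resolved to the first matching file before the columns are read.
local_columns = fetch_seed_dataset_column_names_from_local_file("seeds/*.parquet")

# Datastore-hosted file; the file is opened through HfFileSystem and the open file object is
# handed to get_file_column_names. Repo id, filename, and token below are placeholders, and
# DatastoreSettings is assumed to accept just a token (other fields defaulting).
datastore_columns = fetch_seed_dataset_column_names_from_datastore(
    repo_id="my-org/my-seed-dataset",
    filename="train.parquet",
    datastore_settings=DatastoreSettings(token="hf_xxx"),
)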

tests/config/test_datastore.py

Lines changed: 28 additions & 18 deletions
@@ -1,7 +1,6 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-from pathlib import Path
 from unittest.mock import MagicMock, patch
 
 import numpy as np
@@ -13,6 +12,7 @@
 from data_designer.config.datastore import (
     DatastoreSettings,
     fetch_seed_dataset_column_names,
+    fetch_seed_dataset_column_names_from_local_file,
     get_file_column_names,
     resolve_datastore_settings,
     upload_to_hf_hub,
@@ -127,22 +127,6 @@ def test_get_file_column_names_unicode(tmp_path, file_type):
     assert get_file_column_names(str(unicode_path), file_type) == df_unicode.columns.tolist()
 
 
-@pytest.mark.parametrize("file_type", ["parquet", "csv", "json", "jsonl"])
-def test_get_file_column_names_with_glob_pattern(tmp_path, file_type):
-    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
-    for i in range(5):
-        _write_file(df, tmp_path / f"{i}.{file_type}", file_type)
-    assert get_file_column_names(f"{tmp_path}/*.{file_type}", file_type) == ["col1", "col2"]
-
-
-def test_get_file_column_names_with_glob_pattern_error(tmp_path):
-    df = pd.DataFrame({"col1": [1, 2, 3], "col2": [4, 5, 6]})
-    for i in range(5):
-        _write_file(df, tmp_path / f"{i}.parquet", "parquet")
-    with pytest.raises(InvalidFilePathError, match="No files found matching pattern"):
-        get_file_column_names(f"{tmp_path}/*.csv", "csv")
-
-
 def test_get_file_column_names_with_filesystem_parquet():
     """Test get_file_column_names with filesystem parameter for parquet files."""
     mock_schema = MagicMock()
@@ -153,7 +137,7 @@ def test_get_file_column_names_with_filesystem_parquet():
         result = get_file_column_names("datasets/test/file.parquet", "parquet")
 
     assert result == ["col1", "col2", "col3"]
-    mock_read_schema.assert_called_once_with(Path("datasets/test/file.parquet"))
+    mock_read_schema.assert_called_once_with("datasets/test/file.parquet")
 
 
 @pytest.mark.parametrize("file_type", ["json", "jsonl", "csv"])
@@ -274,3 +258,29 @@ def test_upload_to_hf_hub_error_handling(datastore_settings):
     with patch("data_designer.config.datastore.Path.is_file", autospec=True) as mock_is_file:
         mock_is_file.return_value = True
         upload_to_hf_hub("test.text", "test.txt", "test/repo", datastore_settings)
+
+
+@pytest.mark.parametrize("file_type", ["parquet", "json", "jsonl", "csv"])
+def test_fetch_seed_dataset_column_names_from_local_file_with_glob(tmp_path, file_type):
+    """Test fetch_seed_dataset_column_names_from_local_file with glob pattern matching multiple files."""
+    test_data = pd.DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 6]})
+
+    # Create multiple files with the same schema
+    for i in range(3):
+        file_path = tmp_path / f"data_{i}.{file_type}"
+        _write_file(test_data, file_path, file_type)
+
+    # Test glob pattern that matches all files
+    glob_pattern = str(tmp_path / f"*.{file_type}")
+    result = fetch_seed_dataset_column_names_from_local_file(glob_pattern)
+
+    assert result == ["col1", "col2", "col3"]
+
+
+@pytest.mark.parametrize("file_type", ["parquet", "csv"])
+def test_fetch_seed_dataset_column_names_from_local_file_with_glob_no_matches(tmp_path, file_type):
+    """Test fetch_seed_dataset_column_names_from_local_file with glob pattern that matches no files."""
+    glob_pattern = str(tmp_path / f"nonexistent_*.{file_type}")
+
+    with pytest.raises(InvalidFilePathError, match="does not contain files of type"):
+        fetch_seed_dataset_column_names_from_local_file(glob_pattern)
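The new tests above exercise the local-file path. A complementary sketch (hypothetical, not part of this commit) of a datastore-path test could patch HfFileSystem in the module under test and feed an in-memory parquet buffer to the reader; the test name and the reuse of the existing datastore_settings fixture are assumptions:

import io
from unittest.mock import MagicMock, patch

import pandas as pd

from data_designer.config.datastore import fetch_seed_dataset_column_names_from_datastore


def test_fetch_seed_dataset_column_names_from_datastore_mocked(datastore_settings):
    """Hypothetical sketch: resolve column names through a patched HfFileSystem."""
    # Write a tiny parquet table into memory so pq.read_schema has real bytes to parse.
    buffer = io.BytesIO()
    pd.DataFrame({"col1": [1], "col2": [2]}).to_parquet(buffer)
    buffer.seek(0)

    # fs.open(...) is used as a context manager by the code under test.
    mock_fs = MagicMock()
    mock_fs.open.return_value.__enter__.return_value = buffer

    with patch("data_designer.config.datastore.HfFileSystem", return_value=mock_fs):
        result = fetch_seed_dataset_column_names_from_datastore("test/repo", "data.parquet", datastore_settings)

    assert result == ["col1", "col2"]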
