|
1 | 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
| 4 | +from pathlib import Path |
4 | 5 | from unittest.mock import MagicMock, patch |
5 | 6 |
|
6 | 7 | import numpy as np |
@@ -142,6 +143,36 @@ def test_get_file_column_names_with_glob_pattern_error(tmp_path): |
142 | 143 | get_file_column_names(f"{tmp_path}/*.csv", "csv") |
143 | 144 |
|
144 | 145 |
|
def test_get_file_column_names_with_filesystem_parquet():
    """Parquet column names come from pq.read_schema, called with a Path-coerced argument."""
    schema_stub = MagicMock()
    schema_stub.names = ["col1", "col2", "col3"]

    with patch("data_designer.config.datastore.pq.read_schema", return_value=schema_stub) as read_schema_mock:
        columns = get_file_column_names("datasets/test/file.parquet", "parquet")

        assert columns == ["col1", "col2", "col3"]
        # The string path is expected to be wrapped in pathlib.Path before the schema read.
        read_schema_mock.assert_called_once_with(Path("datasets/test/file.parquet"))
| 158 | + |
@pytest.mark.parametrize("file_type", ["json", "jsonl", "csv"])
def test_get_file_column_names_with_filesystem_non_parquet(tmp_path, file_type):
    """Column names are read back from real json/jsonl/csv files written to disk."""
    frame = pd.DataFrame({"col1": [1], "col2": [2], "col3": [3]})

    file_path = tmp_path / f"test_file.{file_type}"
    # NOTE(review): the "json" case is also written as JSON Lines (lines=True) —
    # confirm the reader under test treats "json" and "jsonl" files identically.
    if file_type == "csv":
        frame.to_csv(file_path, index=False)
    else:
        frame.to_json(file_path, orient="records", lines=True)

    assert get_file_column_names(str(file_path), file_type) == ["col1", "col2", "col3"]
| 175 | + |
145 | 176 | def test_get_file_column_names_error_handling(): |
146 | 177 | with pytest.raises(InvalidFilePathError, match="🛑 Unsupported file type: 'txt'"): |
147 | 178 | get_file_column_names("test.txt", "txt") |
@@ -177,20 +208,29 @@ def test_fetch_seed_dataset_column_names_local_file(mock_get_file_column_names, |
177 | 208 | assert fetch_seed_dataset_column_names(LocalSeedDatasetReference(dataset="test.parquet")) == ["col1", "col2"] |
178 | 209 |
|
179 | 210 |
|
180 | | -@patch("data_designer.config.datastore.HfFileSystem.open") |
@patch("data_designer.config.datastore.HfFileSystem")
@patch("data_designer.config.datastore.get_file_column_names", autospec=True)
def test_fetch_seed_dataset_column_names_remote_file(mock_get_file_column_names, mock_hf_fs, datastore_settings):
    """Remote seed datasets are opened through HfFileSystem built from the datastore settings."""
    mock_get_file_column_names.return_value = ["col1", "col2"]
    fs_stub = MagicMock()
    mock_hf_fs.return_value = fs_stub

    reference = DatastoreSeedDatasetReference(
        dataset="test/repo/test.parquet",
        datastore_settings=datastore_settings,
    )
    assert fetch_seed_dataset_column_names(reference) == ["col1", "col2"]

    # The filesystem must be constructed exactly once, from the settings fixture,
    # bypassing the fsspec instance cache.
    mock_hf_fs.assert_called_once_with(
        endpoint=datastore_settings.endpoint,
        token=datastore_settings.token,
        skip_instance_cache=True,
    )

    # get_file_column_names receives a file-like object from fs.open(), so only
    # the file-type positional argument is pinned exactly.
    assert mock_get_file_column_names.call_count == 1
    assert mock_get_file_column_names.call_args[0][1] == "parquet"
194 | 234 |
|
195 | 235 | def test_resolve_datastore_settings(datastore_settings): |
196 | 236 | with pytest.raises(InvalidConfigError, match="Datastore settings are required"): |
|
0 commit comments