|
1 | 1 | # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
| 4 | +from pathlib import Path |
| 5 | +import tempfile |
| 6 | + |
| 7 | +import pandas as pd |
4 | 8 | import pytest |
5 | 9 |
|
6 | | -from data_designer.config.seed import IndexRange, PartitionBlock |
| 10 | +from data_designer.config.errors import InvalidFilePathError |
| 11 | +from data_designer.config.seed import IndexRange, LocalSeedDatasetReference, PartitionBlock |
| 12 | + |
| 13 | + |
def create_partitions_in_path(temp_dir: Path, extension: str, num_files: int = 2) -> Path:
    """Write ``num_files`` small partition files of one format into ``temp_dir``.

    Each file is named ``partition_<i>.<extension>`` and holds the same
    three-row, single-column (``col``) DataFrame.

    Args:
        temp_dir: Existing directory to write the partition files into.
        extension: File format to write; one of ``parquet``, ``csv``,
            ``json`` or ``jsonl``.
        num_files: Number of partition files to create.

    Returns:
        ``temp_dir``, unchanged, for convenient chaining.

    Raises:
        ValueError: If ``extension`` is not one of the supported formats.
    """
    df = pd.DataFrame({"col": [1, 2, 3]})

    # One writer per supported extension; each takes the destination path.
    writers = {
        "parquet": lambda path: df.to_parquet(path),
        "csv": lambda path: df.to_csv(path, index=False),
        "json": lambda path: df.to_json(path, orient="records"),
        "jsonl": lambda path: df.to_json(path, orient="records", lines=True),
    }

    # Fail loudly up front: silently writing nothing for an unknown extension
    # would let a test pass or fail for reasons unrelated to what it checks.
    if extension not in writers:
        supported = ", ".join(sorted(writers))
        raise ValueError(f"Unsupported extension {extension!r}; expected one of: {supported}")

    write = writers[extension]
    for i in range(num_files):
        write(temp_dir / f"partition_{i}.{extension}")
    return temp_dir
7 | 28 |
|
8 | 29 |
|
9 | 30 | def test_index_range_validation(): |
@@ -54,3 +75,35 @@ def test_partition_block_to_index_range(): |
54 | 75 | assert index_range.start == 90 |
55 | 76 | assert index_range.end == 104 |
56 | 77 | assert index_range.size == 15 |
| 78 | + |
| 79 | + |
def test_local_seed_dataset_reference_validation():
    """Check the path validation done when building a ``LocalSeedDatasetReference``.

    Covers three cases:
      * a plain path that is not an existing file raises ``InvalidFilePathError``;
      * a wildcard pattern matching files of a supported extension is accepted
        and stored unchanged in ``dataset``;
      * a wildcard pattern whose extension matches no file in the directory
        raises ``InvalidFilePathError``.
    """
    with pytest.raises(InvalidFilePathError, match="🛑 Path test/dataset.parquet is not a file."):
        LocalSeedDatasetReference(dataset="test/dataset.parquet")

    # Should not raise an error when referencing supported extensions with wildcard pattern.
    extensions = ("parquet", "csv", "json", "jsonl")
    with tempfile.TemporaryDirectory() as temp_dir:
        # Populate the directory with every format before building any reference.
        for extension in extensions:
            create_partitions_in_path(Path(temp_dir), extension)

        # No try/except wrapper: pytest already fails the test on any unexpected
        # exception and reports the original traceback, and assertion failures
        # are not mislabelled as "unexpected exception".
        for extension in extensions:
            pattern = f"{temp_dir}/*.{extension}"
            reference = LocalSeedDatasetReference(dataset=pattern)
            assert reference.dataset == pattern

    # Should raise an error when referencing a path that does not contain files of the specified type.
    with tempfile.TemporaryDirectory() as temp_dir:
        create_partitions_in_path(Path(temp_dir), "parquet")
        with pytest.raises(InvalidFilePathError, match="does not contain files of type 'csv'"):
            LocalSeedDatasetReference(dataset=f"{temp_dir}/*.csv")
0 commit comments