import shutil
import os
import sys
import datetime

sys.path.insert(1, os.path.join(sys.path[0], ".."))

from mabel.adapters.disk import DiskReader, DiskWriter
from mabel.data import BatchWriter, Reader


def test_parquet_row_group_size():
| 13 | + """Test that parquet_row_group_size parameter is passed and used correctly""" |
    shutil.rmtree("_temp_rowgroup", ignore_errors=True)

    # Create a writer with custom row group size
    w = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_rowgroup",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
        parquet_row_group_size=100,  # Small row group for testing
    )

    # Write 1000 records
    for i in range(1000):
        w.append({"id": i, "value": f"value_{i}"})

    w.finalize()

    # Read back and verify we can read the data
    r = Reader(inner_reader=DiskReader, dataset="_temp_rowgroup")
    records = list(r)
    assert len(records) == 1000, f"Expected 1000 records, got {len(records)}"

    # Verify the parquet file has multiple row groups
    import glob
    import pyarrow.parquet as pq

    parquet_files = glob.glob("_temp_rowgroup/**/*.parquet", recursive=True)
    assert len(parquet_files) > 0, "No parquet files found"

    # Check row groups in the first file
    parquet_file = pq.ParquetFile(parquet_files[0])
    num_row_groups = parquet_file.num_row_groups

    # With 1000 records and row_group_size=100 we expect several row groups
    # (the exact count depends on how records are distributed across blobs),
    # and no single group should hold more than 100 rows
    assert num_row_groups > 0, f"Expected at least 1 row group, got {num_row_groups}"
    for i in range(num_row_groups):
        rows_in_group = parquet_file.metadata.row_group(i).num_rows
        assert rows_in_group <= 100, f"Row group {i} has {rows_in_group} rows, expected at most 100"

    shutil.rmtree("_temp_rowgroup", ignore_errors=True)
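

# A minimal reference sketch, separate from the mabel tests above: it shows the
# row-group behaviour those tests assume, using only the public pyarrow API.
# The one assumption about mabel itself is that parquet_row_group_size is
# handed through to the underlying parquet writer.
def test_pyarrow_row_group_reference():
    import pyarrow as pa
    import pyarrow.parquet as pq

    shutil.rmtree("_temp_reference", ignore_errors=True)
    os.makedirs("_temp_reference", exist_ok=True)

    # 1000 rows written with row_group_size=100 slice into exactly 10 row groups
    table = pa.table({"id": list(range(1000))})
    pq.write_table(table, "_temp_reference/example.parquet", row_group_size=100)
    assert pq.ParquetFile("_temp_reference/example.parquet").num_row_groups == 10

    shutil.rmtree("_temp_reference", ignore_errors=True)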


def test_parquet_sorting():
| 56 | + """Test that sort_by parameter sorts records correctly""" |
    shutil.rmtree("_temp_sort", ignore_errors=True)

    # Create a writer with sorting
    w = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_sort",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
        sort_by="id",  # Sort by id column
        parquet_row_group_size=5000,
    )

    # Write records in reverse order
    for i in range(100, 0, -1):
        w.append({"id": i, "value": f"value_{i}"})

    w.finalize()

    # Read back and verify the data is sorted
    r = Reader(inner_reader=DiskReader, dataset="_temp_sort")
    records = list(r)

    assert len(records) == 100, f"Expected 100 records, got {len(records)}"

    # Check that records are sorted by id
    ids = [record["id"] for record in records]
    assert ids == list(range(1, 101)), f"Records are not sorted correctly: {ids[:10]}..."

    shutil.rmtree("_temp_sort", ignore_errors=True)


def test_parquet_sorting_descending():
| 90 | + """Test that sort_by parameter can sort in descending order""" |
    shutil.rmtree("_temp_sort_desc", ignore_errors=True)

    # Create a writer with descending sorting
    w = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_sort_desc",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
        sort_by=[("id", "descending")],  # Sort by id in descending order
        parquet_row_group_size=5000,
    )

    # Write records in shuffled order
    for i in [50, 10, 90, 30, 70, 20, 80, 40, 60, 100]:
        w.append({"id": i, "value": f"value_{i}"})

    w.finalize()

    # Read back and verify the data is sorted in descending order
    r = Reader(inner_reader=DiskReader, dataset="_temp_sort_desc")
    records = list(r)

    assert len(records) == 10, f"Expected 10 records, got {len(records)}"

    # Check that records are sorted by id in descending order
    ids = [record["id"] for record in records]
    expected_ids = [100, 90, 80, 70, 60, 50, 40, 30, 20, 10]
    assert ids == expected_ids, f"Records are not sorted correctly in descending order: {ids}"

    shutil.rmtree("_temp_sort_desc", ignore_errors=True)
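

# Another minimal reference sketch, outside mabel: pyarrow's Table.sort_by
# accepts the same two forms exercised above, a bare column name (ascending)
# or a list of (column, order) tuples. The assumption, stated plainly, is
# that mabel forwards sort_by to pyarrow unchanged.
def test_pyarrow_sort_by_reference():
    import pyarrow as pa

    table = pa.table({"id": [3, 1, 2]})

    # a bare column name sorts ascending
    assert table.sort_by("id")["id"].to_pylist() == [1, 2, 3]

    # (column, order) tuples allow an explicit descending sort
    assert table.sort_by([("id", "descending")])["id"].to_pylist() == [3, 2, 1]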


def test_parquet_default_row_group_size():
| 125 | + """Test that default row group size is 5000""" |
    shutil.rmtree("_temp_default", ignore_errors=True)

    # Create a writer without specifying row group size
    w = BatchWriter(
        inner_writer=DiskWriter,
        dataset="_temp_default",
        format="parquet",
        date=datetime.datetime.utcnow().date(),
        schema=[{"name": "id", "type": "INTEGER"}, {"name": "value", "type": "VARCHAR"}],
    )

    # Write some records
    for i in range(100):
        w.append({"id": i, "value": f"value_{i}"})

    w.finalize()

    # Read back and verify we can read the data
    r = Reader(inner_reader=DiskReader, dataset="_temp_default")
    records = list(r)
    assert len(records) == 100, f"Expected 100 records, got {len(records)}"
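
    # With the default row group size (5000) and only 100 rows written, each
    # parquet blob should hold a single row group; this assumes the writer
    # only starts a new row group once a blob exceeds the configured size.
    import glob
    import pyarrow.parquet as pq

    for path in glob.glob("_temp_default/**/*.parquet", recursive=True):
        assert pq.ParquetFile(path).num_row_groups == 1, f"{path}: expected one row group for 100 rows"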

    shutil.rmtree("_temp_default", ignore_errors=True)


if __name__ == "__main__":  # pragma: no cover
    from tests.helpers.runner import run_tests

    run_tests()