| 
1 | 1 | """Test live mask chunk size calculation."""  | 
2 | 2 | 
 
  | 
3 |  | -import numpy as np  | 
4 |  | - | 
5 |  | -from mdio.core import Dimension  | 
6 |  | -from mdio.core import Grid  | 
7 |  | -from mdio.core.grid import _calculate_live_mask_chunksize  | 
8 |  | -from mdio.core.grid import _calculate_optimal_chunksize  | 
9 |  | -from mdio.core.grid import _EmptyGrid  | 
10 |  | - | 
11 |  | - | 
12 |  | -def test_small_grid_no_chunking():  | 
13 |  | -    """Test that small grids return -1 (no chunking needed)."""  | 
14 |  | -    # Create a small grid that fits within INT32_MAX  | 
15 |  | -    dims = [  | 
16 |  | -        Dimension(coords=range(0, 100, 1), name="dim1"),  | 
17 |  | -        Dimension(coords=range(0, 100, 1), name="dim2"),  | 
18 |  | -        Dimension(coords=range(0, 100, 1), name="sample"),  | 
19 |  | -    ]  | 
20 |  | -    grid = Grid(dims=dims)  | 
21 |  | -    grid.live_mask = _EmptyGrid((100, 100), dtype=np.bool)  | 
22 |  | - | 
23 |  | -    result = _calculate_live_mask_chunksize(grid)  | 
24 |  | -    assert result == (100, 100)  | 
25 |  | - | 
26 |  | - | 
27 |  | -def test_grid_without_live_mask():  | 
28 |  | -    """Test that a grid without a live mask set up yet."""  | 
29 |  | -    dims = [  | 
30 |  | -        Dimension(coords=range(0, 100, 1), name="dim1"),  | 
31 |  | -        Dimension(coords=range(0, 100, 1), name="dim2"),  | 
32 |  | -        Dimension(coords=range(0, 100, 1), name="sample"),  | 
33 |  | -    ]  | 
34 |  | -    grid = Grid(dims=dims)  | 
35 |  | - | 
36 |  | -    result = _calculate_live_mask_chunksize(grid)  | 
37 |  | -    assert result == (100, 100)  | 
38 |  | - | 
39 |  | - | 
40 |  | -def test_large_2d_grid_chunking():  | 
41 |  | -    """Test exact chunk size calculation for a 2D grid that exceeds INT32_MAX."""  | 
42 |  | -    # Create a grid that exceeds INT32_MAX (2,147,483,647)  | 
43 |  | -    # Using 50,000 x 50,000 = 2,500,000,000 elements  | 
44 |  | -    dims = [  | 
45 |  | -        Dimension(coords=range(0, 50000, 1), name="dim1"),  | 
46 |  | -        Dimension(coords=range(0, 50000, 1), name="dim2"),  | 
47 |  | -        Dimension(coords=range(0, 100, 1), name="sample"),  | 
48 |  | -    ]  | 
49 |  | -    grid = Grid(dims=dims)  | 
50 |  | -    grid.live_mask = _EmptyGrid((50000, 50000), dtype=np.bool)  | 
51 |  | - | 
52 |  | -    result = _calculate_live_mask_chunksize(grid)  | 
53 |  | - | 
54 |  | -    # TODO(BrianMichell): Avoid magic numbers.  | 
55 |  | -    assert result == (25000, 25000)  | 
56 |  | - | 
57 |  | - | 
58 |  | -def test_large_3d_grid_chunking():  | 
59 |  | -    """Test exact chunk size calculation for a 3D grid that exceeds INT32_MAX."""  | 
60 |  | -    # Create a 3D grid that exceeds INT32_MAX  | 
61 |  | -    # Using 1500 x 1500 x 1500 = 3,375,000,000 elements  | 
62 |  | -    dims = [  | 
63 |  | -        Dimension(coords=range(0, 1500, 1), name="dim1"),  | 
64 |  | -        Dimension(coords=range(0, 1500, 1), name="dim2"),  | 
65 |  | -        Dimension(coords=range(0, 1500, 1), name="dim3"),  | 
66 |  | -        Dimension(coords=range(0, 100, 1), name="sample"),  | 
67 |  | -    ]  | 
68 |  | -    grid = Grid(dims=dims)  | 
69 |  | -    grid.live_mask = _EmptyGrid((1500, 1500, 1500), dtype=np.bool)  | 
70 |  | - | 
71 |  | -    result = _calculate_live_mask_chunksize(grid)  | 
72 |  | - | 
73 |  | -    # Calculate expected values  | 
74 |  | -    # total_elements = 1500 * 1500 * 1500  | 
75 |  | -    # num_chunks = np.ceil(total_elements / INT32_MAX).astype(int)  | 
76 |  | -    # dim_chunks = int(np.ceil(1500 / np.ceil(np.power(num_chunks, 1 / 3))))  | 
77 |  | -    # expected_chunk_size = int(np.ceil(1500 / dim_chunks))  | 
78 |  | - | 
79 |  | -    # assert result == (expected_chunk_size, expected_chunk_size, expected_chunk_size)  | 
80 |  | -    assert result == (750, 750, 750)  | 
81 |  | - | 
82 |  | - | 
83 |  | -def test_prestack_land_survey_chunking():  | 
84 |  | -    """Test exact chunk size calculation for a dense pre-stack land survey grid."""  | 
85 |  | -    # Create a dense pre-stack land survey grid that exceeds INT32_MAX  | 
86 |  | -    # Using realistic dimensions:  | 
87 |  | -    # - 1000 shot points  | 
88 |  | -    # - 1000 receiver points  | 
89 |  | -    # - 100 offsets  | 
90 |  | -    # - 36 azimuths  | 
91 |  | -    # Total elements: 1000 * 1000 * 100 * 36 = 3,600,000,000 elements  | 
92 |  | -    dims = [  | 
93 |  | -        Dimension(coords=range(0, 1000, 1), name="shot_point"),  | 
94 |  | -        Dimension(coords=range(0, 1000, 1), name="receiver_point"),  | 
95 |  | -        Dimension(coords=range(0, 100, 1), name="offset"),  | 
96 |  | -        Dimension(coords=range(0, 36, 1), name="azimuth"),  | 
97 |  | -        Dimension(coords=range(0, 1000, 1), name="sample"),  | 
98 |  | -    ]  | 
99 |  | -    grid = Grid(dims=dims)  | 
100 |  | -    grid.live_mask = _EmptyGrid((1000, 1000, 100, 36), dtype=np.bool)  | 
101 |  | - | 
102 |  | -    result = _calculate_live_mask_chunksize(grid)  | 
103 |  | -    assert result == (334, 334, 100, 36)  | 
104 |  | - | 
105 |  | - | 
106 |  | -def test_one_dim_full_chunk():  | 
107 |  | -    """Test one-dimensional volume where the whole dimension can be used as chunk."""  | 
108 |  | -    arr = _EmptyGrid((100,), dtype=np.int8)  | 
109 |  | -    # With n_bytes = 100, max_elements_allowed = 100, thus optimal chunk should be (100,)  | 
110 |  | -    result = _calculate_optimal_chunksize(arr, 100)  | 
111 |  | -    assert result == (100,)  | 
112 |  | - | 
113 |  | - | 
114 |  | -def test_two_dim_optimal():  | 
115 |  | -    """Test two-dimensional volume with limited n_bytes.  | 
116 |  | -
  | 
117 |  | -    For a shape of (8,6) with n_bytes=20, the optimal chunk is expected to be (8,2).  | 
118 |  | -    """  | 
119 |  | -    arr = _EmptyGrid((8, 6), dtype=np.int8)  | 
120 |  | -    result = _calculate_optimal_chunksize(arr, 20)  | 
121 |  | -    assert result == (4, 4)  | 
 | 3 | +from typing import TYPE_CHECKING  | 
122 | 4 | 
 
  | 
123 |  | - | 
124 |  | -def test_three_dim_optimal():  | 
125 |  | -    """Test three-dimensional volume optimal chunk calculation.  | 
126 |  | -
  | 
127 |  | -    For a shape of (9,6,4) with n_bytes=100, the expected chunk is (9,2,4).  | 
128 |  | -    """  | 
129 |  | -    arr = _EmptyGrid((9, 6, 4), dtype=np.int8)  | 
130 |  | -    result = _calculate_optimal_chunksize(arr, 100)  | 
131 |  | -    assert result == (5, 5, 4)  | 
132 |  | - | 
133 |  | - | 
134 |  | -def test_minimal_chunk_for_large_dtype():  | 
135 |  | -    """Test that n_bytes forcing minimal chunking returns all ones.  | 
136 |  | -
  | 
137 |  | -    Using int32 (itemsize=4) with shape (4,5) and n_bytes=4 yields (1,1).  | 
138 |  | -    """  | 
139 |  | -    arr = _EmptyGrid((4, 5), dtype=np.int32)  | 
140 |  | -    result = _calculate_optimal_chunksize(arr, 4)  | 
141 |  | -    assert result == (1, 1)  | 
142 |  | - | 
143 |  | - | 
144 |  | -def test_large_nbytes():  | 
145 |  | -    """Test that a very large n_bytes returns the full volume shape as the optimal chunk."""  | 
146 |  | -    arr = _EmptyGrid((10, 10), dtype=np.int8)  | 
147 |  | -    result = _calculate_optimal_chunksize(arr, 1000)  | 
148 |  | -    assert result == (10, 10)  | 
149 |  | - | 
150 |  | - | 
151 |  | -def test_two_dim_non_int8():  | 
152 |  | -    """Test with a non-int8 dtype where n_bytes exactly covers the full volume in bytes."""  | 
153 |  | -    arr = _EmptyGrid((6, 8), dtype=np.int16)  # int16 has itemsize 2  | 
154 |  | -    # Total bytes of full volume = 6*8*2 = 96, so optimal chunk should be (6,8)  | 
155 |  | -    result = _calculate_optimal_chunksize(arr, 96)  | 
156 |  | -    assert result == (6, 8)  | 
157 |  | - | 
158 |  | - | 
159 |  | -def test_irregular_dimensions():  | 
160 |  | -    """Test volume with prime dimensions where divisors are limited.  | 
161 |  | -
  | 
162 |  | -    For shape (7,5) with n_bytes=35, optimal chunk should be (7,5) since 7*5 = 35.  | 
163 |  | -    """  | 
164 |  | -    arr = _EmptyGrid((7, 5), dtype=np.int8)  | 
165 |  | -    result = _calculate_optimal_chunksize(arr, 35)  | 
166 |  | -    assert result == (7, 5)  | 
167 |  | - | 
168 |  | - | 
169 |  | -def test_primes():  | 
170 |  | -    """Test volume with prime dimensions where divisors are limited."""  | 
171 |  | -    arr = _EmptyGrid((7, 5), dtype=np.int8)  | 
172 |  | -    result = _calculate_optimal_chunksize(arr, 23)  | 
173 |  | -    assert result == (4, 4)  | 
174 |  | - | 
175 |  | - | 
176 |  | -def test_grid_gambit():  | 
177 |  | -    """Test various chunk size scenarios with different array dimensions."""  | 
178 |  | -    from mdio.constants import INT32_MAX  | 
179 |  | - | 
180 |  | -    # Dictionary of test cases with different array shapes  | 
181 |  | -    live_mask_chunks_examples = {  | 
182 |  | -        "smaller_2G_asym": (1024, 8192),  | 
183 |  | -        "small_2G_square": (32768, 32768),  | 
184 |  | -        "right_below_2G": (46340, 46340),  | 
185 |  | -        "right_above_2G": (46341, 46341),  | 
186 |  | -        "above_2G_v1": (86341, 96341),  | 
187 |  | -        "above_2G_v2": (55000, 47500),  | 
188 |  | -        "above_2G_v2_asym": (55000, 97500),  | 
189 |  | -        "above_4G_v2_asym": (100000, 100000),  | 
190 |  | -        "below_2G_4D": (215, 215, 215, 215),  | 
191 |  | -        "above_3G_4D": (216, 216, 216, 216),  | 
192 |  | -        "above_3G_4D_asym": (512, 216, 512, 400),  | 
193 |  | -        "below_2G_5D": (73, 73, 73, 73, 73),  | 
194 |  | -        "above_3G_5D": (74, 74, 74, 74, 74),  | 
195 |  | -        "above_3G_5D_asym": (512, 17, 43, 200, 50),  | 
196 |  | -    }  | 
197 |  | - | 
198 |  | -    # Test each case  | 
199 |  | -    for kind, shape in live_mask_chunks_examples.items():  | 
200 |  | -        # Create dimensions for the grid  | 
201 |  | -        dims = [  | 
202 |  | -            Dimension(coords=range(0, dim_size, 1), name=f"dim{i}")  | 
203 |  | -            for i, dim_size in enumerate(shape)  | 
204 |  | -        ]  | 
205 |  | -        # Add sample dimension  | 
206 |  | -        dims.append(Dimension(coords=range(0, 100, 1), name="sample"))  | 
207 |  | - | 
208 |  | -        # Create grid and set live mask  | 
209 |  | -        grid = Grid(dims=dims)  | 
210 |  | -        grid.live_mask = _EmptyGrid(shape, dtype=np.bool)  | 
211 |  | - | 
212 |  | -        result = _calculate_live_mask_chunksize(grid)  | 
213 |  | - | 
214 |  | -        # Verify that the chunk size is valid  | 
215 |  | -        assert all(chunk > 0 for chunk in result), f"Invalid chunk size for {kind}"  | 
216 |  | -        assert len(result) == len(shape), f"Dimension mismatch for {kind}"  | 
217 |  | - | 
218 |  | -        # # Calculate total elements in chunks  | 
 | 5 | +import numpy as np  | 
 | 6 | +import pytest  | 
 | 7 | + | 
 | 8 | +from mdio.constants import INT32_MAX  | 
 | 9 | +from mdio.core.utils_write import get_constrained_chunksize  | 
 | 10 | +from mdio.core.utils_write import get_live_mask_chunksize  | 
 | 11 | + | 
 | 12 | + | 
 | 13 | +if TYPE_CHECKING:  | 
 | 14 | +    from numpy.typing import DTypeLike  | 
 | 15 | + | 
 | 16 | + | 
 | 17 | +@pytest.mark.parametrize(  | 
 | 18 | +    ("shape", "dtype", "limit", "expected_chunks"),  | 
 | 19 | +    [  | 
 | 20 | +        ((100,), "int8", 100, (100,)),  # 1D full chunk  | 
 | 21 | +        ((8, 6), "int8", 20, (4, 4)),  # 2D adjusted int8  | 
 | 22 | +        ((6, 8), "int16", 96, (6, 8)),  # 2D small int16  | 
 | 23 | +        ((9, 6, 4), "int8", 100, (5, 5, 4)),  # 3D adjusted  | 
 | 24 | +        ((4, 5), "int32", 4, (1, 1)),  # test minimum edge case  | 
 | 25 | +        ((10, 10), "int8", 1000, (10, 10)),  # big limit  | 
 | 26 | +        ((7, 5), "int8", 35, (7, 5)),  # test full primes  | 
 | 27 | +        ((7, 5), "int8", 23, (4, 4)),  # test adjusted primes  | 
 | 28 | +    ],  | 
 | 29 | +)  | 
 | 30 | +def test_auto_chunking(  | 
 | 31 | +    shape: tuple[int, ...],  | 
 | 32 | +    dtype: "DTypeLike",  | 
 | 33 | +    limit: int,  | 
 | 34 | +    expected_chunks: tuple[int, ...],  | 
 | 35 | +) -> None:  | 
 | 36 | +    """Test automatic chunking based on size limit and an array spec."""  | 
 | 37 | +    result = get_constrained_chunksize(shape, dtype, limit)  | 
 | 38 | +    assert result == expected_chunks  | 
 | 39 | + | 
 | 40 | + | 
 | 41 | +class TestAutoChunkLiveMask:  | 
 | 42 | +    """Test class for live mask auto chunking."""  | 
 | 43 | + | 
 | 44 | +    @pytest.mark.parametrize(  | 
 | 45 | +        ("shape", "expected_chunks"),  | 
 | 46 | +        [  | 
 | 47 | +            ((100,), (100,)),  # small 1d  | 
 | 48 | +            ((100, 100), (100, 100)),  # small 2d  | 
 | 49 | +            ((50000, 50000), (25000, 25000)),  # large 2d  | 
 | 50 | +            ((1500, 1500, 1500), (750, 750, 750)),  # large 3d  | 
 | 51 | +            ((1000, 1000, 100, 36), (334, 334, 100, 36)),  # large 4d  | 
 | 52 | +        ],  | 
 | 53 | +    )  | 
 | 54 | +    def test_auto_chunk_live_mask(  | 
 | 55 | +        self,  | 
 | 56 | +        shape: tuple[int, ...],  | 
 | 57 | +        expected_chunks: tuple[int, ...],  | 
 | 58 | +    ) -> None:  | 
 | 59 | +        """Test auto chunked live mask is within expected number of bytes."""  | 
 | 60 | +        result = get_live_mask_chunksize(shape)  | 
 | 61 | +        assert result == expected_chunks  | 
 | 62 | + | 
 | 63 | +    @pytest.mark.parametrize(  | 
 | 64 | +        "shape",  | 
 | 65 | +        [  | 
 | 66 | +            # Below are >500MiB. Smaller ones tested above  | 
 | 67 | +            (32768, 32768),  | 
 | 68 | +            (46341, 46341),  | 
 | 69 | +            (86341, 96341),  | 
 | 70 | +            (55000, 97500),  | 
 | 71 | +            (100000, 100000),  | 
 | 72 | +            (1024, 1024, 1024),  | 
 | 73 | +            (215, 215, 215, 215),  | 
 | 74 | +            (512, 216, 512, 400),  | 
 | 75 | +            (74, 74, 74, 74, 74),  | 
 | 76 | +            (512, 17, 43, 200, 50),  | 
 | 77 | +        ],  | 
 | 78 | +    )  | 
 | 79 | +    def test_auto_chunk_live_mask_nbytes(self, shape: tuple[int, ...]) -> None:  | 
 | 80 | +        """Test auto chunked live mask is within expected number of bytes."""  | 
 | 81 | +        result = get_live_mask_chunksize(shape)  | 
219 | 82 |         chunk_elements = np.prod(result)  | 
220 |  | -        if kind in [  | 
221 |  | -            "right_above_2G",  | 
222 |  | -            "above_2G_v2",  | 
223 |  | -            "above_2G_v2_asym",  | 
224 |  | -            "above_4G_v2_asym",  | 
225 |  | -            "above_3G_4D_asym",  | 
226 |  | -        ]:  | 
227 |  | -            # TODO(BrianMichell): Our implementation is taking "limit" pretty liberally.  | 
228 |  | -            # This is not overtly an issue because we are well below the 2GiB limit,  | 
229 |  | -            # but it could be improved.  | 
230 |  | -            assert (  | 
231 |  | -                chunk_elements <= (INT32_MAX // 4) * 1.5  | 
232 |  | -            ), f"Chunk too large for {kind}"  | 
233 |  | -        else:  | 
234 |  | -            assert chunk_elements <= INT32_MAX // 4, f"Chunk too large for {kind}"  | 
 | 83 | + | 
 | 84 | +        # We want them to be 500MB +/- 25%  | 
 | 85 | +        assert chunk_elements > INT32_MAX // 4 * 0.75  | 
 | 86 | +        assert chunk_elements < INT32_MAX // 4 * 1.25  | 
0 commit comments