11"""Test live mask chunk size calculation.""" 
22
33import  numpy  as  np 
4+ import  pytest 
45
5- from  mdio .constants  import  INT32_MAX 
66from  mdio .converters .segy  import  _calculate_live_mask_chunksize 
7+ from  mdio .converters .segy  import  _calculate_optimal_chunksize 
78from  mdio .core  import  Dimension 
89from  mdio .core  import  Grid 
910
@@ -20,7 +21,7 @@ def test_small_grid_no_chunking():
2021    grid .live_mask  =  np .ones ((100 , 100 ), dtype = bool )
2122
2223    result  =  _calculate_live_mask_chunksize (grid )
23-     assert  result  ==  - 1 
24+     assert  result  ==  ( 100 ,  100 ) 
2425
2526
2627def  test_large_2d_grid_chunking ():
@@ -37,13 +38,8 @@ def test_large_2d_grid_chunking():
3738
3839    result  =  _calculate_live_mask_chunksize (grid )
3940
40-     # Calculate expected values 
41-     total_elements  =  50000  *  50000 
42-     num_chunks  =  np .ceil (total_elements  /  INT32_MAX ).astype (int )
43-     dim_chunks  =  int (np .ceil (50000  /  np .ceil (np .power (num_chunks , 1  /  2 ))))
44-     expected_chunk_size  =  int (np .ceil (50000  /  dim_chunks ))
45- 
46-     assert  result  ==  (expected_chunk_size , expected_chunk_size )
41+     # TODO(BrianMichell): Avoid magic numbers. 
42+     assert  result  ==  (50000 , 25000 )
4743
4844
4945def  test_large_3d_grid_chunking ():
@@ -62,12 +58,13 @@ def test_large_3d_grid_chunking():
6258    result  =  _calculate_live_mask_chunksize (grid )
6359
6460    # Calculate expected values 
65-     total_elements  =  1500  *  1500  *  1500 
66-     num_chunks  =  np .ceil (total_elements  /  INT32_MAX ).astype (int )
67-     dim_chunks  =  int (np .ceil (1500  /  np .ceil (np .power (num_chunks , 1  /  3 ))))
68-     expected_chunk_size  =  int (np .ceil (1500  /  dim_chunks ))
61+     #  total_elements = 1500 * 1500 * 1500
62+     #  num_chunks = np.ceil(total_elements / INT32_MAX).astype(int)
63+     #  dim_chunks = int(np.ceil(1500 / np.ceil(np.power(num_chunks, 1 / 3))))
64+     #  expected_chunk_size = int(np.ceil(1500 / dim_chunks))
6965
70-     assert  result  ==  (expected_chunk_size , expected_chunk_size , expected_chunk_size )
66+     # assert result == (expected_chunk_size, expected_chunk_size, expected_chunk_size) 
67+     assert  result  ==  (1500 , 1500 , 750 )
7168
7269
7370def  test_uneven_dimensions_chunking ():
@@ -84,14 +81,7 @@ def test_uneven_dimensions_chunking():
8481    grid .live_mask  =  np .ones ((50000 , 50000 ), dtype = bool )
8582
8683    result  =  _calculate_live_mask_chunksize (grid )
87- 
88-     # Calculate expected values 
89-     total_elements  =  50000  *  50000 
90-     num_chunks  =  np .ceil (total_elements  /  INT32_MAX ).astype (int )
91-     dim_chunks  =  int (np .ceil (50000  /  np .ceil (np .power (num_chunks , 1  /  2 ))))
92-     expected_chunk_size  =  int (np .ceil (50000  /  dim_chunks ))
93- 
94-     assert  result  ==  (expected_chunk_size , expected_chunk_size )
84+     assert  result  ==  (50000 , 25000 )
9585
9686
9787def  test_prestack_land_survey_chunking ():
@@ -114,21 +104,7 @@ def test_prestack_land_survey_chunking():
114104    grid .live_mask  =  np .ones ((1000 , 1000 , 100 , 36 ), dtype = bool )
115105
116106    result  =  _calculate_live_mask_chunksize (grid )
117- 
118-     # Calculate expected values 
119-     total_elements  =  1000  *  1000  *  100  *  36 
120-     num_chunks  =  np .ceil (total_elements  /  INT32_MAX ).astype (int )
121-     dim_chunks  =  int (np .ceil (1000  /  np .ceil (np .power (num_chunks , 1  /  4 ))))
122-     expected_chunk_size  =  int (np .ceil (1000  /  dim_chunks ))
123- 
124-     # For a 4D grid, we expect chunk sizes to be distributed across all dimensions 
125-     # The chunk size should be the same for all dimensions since they're all equally important 
126-     assert  result  ==  (
127-         expected_chunk_size ,
128-         expected_chunk_size ,
129-         expected_chunk_size ,
130-         expected_chunk_size ,
131-     )
107+     assert  result  ==  (1000 , 1000 , 100 , 18 )
132108
133109
134110def  test_edge_case_empty_grid ():
@@ -142,4 +118,91 @@ def test_edge_case_empty_grid():
142118    grid .live_mask  =  np .zeros ((0 , 0 ), dtype = bool )
143119
144120    result  =  _calculate_live_mask_chunksize (grid )
145-     assert  result  ==  - 1   # Empty grid shouldn't need chunking 
121+     assert  result  ==  (0 , 0 )
122+ 
123+ 
124+ # Additional tests for _calculate_optimal_chunksize function 
def test_empty_volume():
    """Check that a volume with a zero-length axis gets its own shape back as the chunk size."""
    volume = np.zeros((0, 10), dtype=np.int8)
    assert _calculate_optimal_chunksize(volume, 100) == (0, 10)
130+ 
131+ 
def test_nbytes_too_small():
    """Check that a zero-byte budget is rejected with a ValueError."""
    one_byte_items = np.zeros((10,), dtype=np.int8)  # int8 elements occupy one byte each
    expected_message = r"n_bytes is too small to hold even one element"
    with pytest.raises(ValueError, match=expected_message):
        _calculate_optimal_chunksize(one_byte_items, 0)
139+ 
140+ 
def test_one_dim_full_chunk():
    """Check a 1-D volume whose entire length fits in a single chunk."""
    volume = np.zeros((100,), dtype=np.int8)
    # 100 one-byte elements against a 100-byte budget: the whole axis is expected back.
    chunks = _calculate_optimal_chunksize(volume, 100)
    assert chunks == (100,)
147+ 
148+ 
def test_two_dim_optimal():
    """Check a 2-D volume under a byte budget smaller than the volume.

    An int8 volume of shape (8, 6) with a 20-byte budget is expected to be
    chunked as (8, 2).
    """
    volume = np.zeros((8, 6), dtype=np.int8)
    assert _calculate_optimal_chunksize(volume, 20) == (8, 2)
157+ 
158+ 
def test_three_dim_optimal():
    """Check a 3-D volume under a byte budget smaller than the volume.

    An int8 volume of shape (9, 6, 4) with a 100-byte budget is expected to be
    chunked as (9, 2, 4).
    """
    volume = np.zeros((9, 6, 4), dtype=np.int8)
    assert _calculate_optimal_chunksize(volume, 100) == (9, 2, 4)
167+ 
168+ 
def test_minimal_chunk_for_large_dtype():
    """Check that a budget of exactly one element yields an all-ones chunk.

    The volume is int32 (4 bytes per element) with shape (4, 5); a 4-byte
    budget is expected to produce (1, 1).
    """
    volume = np.zeros((4, 5), dtype=np.int32)
    chunks = _calculate_optimal_chunksize(volume, 4)
    assert chunks == (1, 1)
177+ 
178+ 
def test_large_nbytes():
    """Check that a budget far larger than the volume returns the full shape."""
    volume = np.zeros((10, 10), dtype=np.int8)
    # 100 one-byte elements against a 1000-byte budget.
    assert _calculate_optimal_chunksize(volume, 1000) == (10, 10)
184+ 
185+ 
def test_two_dim_non_int8():
    """Check a two-byte dtype where the budget equals the volume's size in bytes."""
    volume = np.zeros((6, 8), dtype=np.int16)  # int16 elements occupy two bytes each
    # The full volume is 6 * 8 * 2 = 96 bytes, so the full shape is expected back.
    chunks = _calculate_optimal_chunksize(volume, 96)
    assert chunks == (6, 8)
192+ 
193+ 
def test_irregular_dimensions():
    """Check a volume with prime axis lengths when the budget fits it exactly.

    An int8 volume of shape (7, 5) holds 7 * 5 = 35 bytes; with a 35-byte
    budget the full shape (7, 5) is expected back.
    """
    volume = np.zeros((7, 5), dtype=np.int8)
    assert _calculate_optimal_chunksize(volume, 35) == (7, 5)
202+ 
203+ 
def test_primes():
    """Check a volume with prime axis lengths against a 23-byte budget."""
    volume = np.zeros((7, 5), dtype=np.int8)
    # NOTE(review): the expected chunk (7, 5) holds 35 one-byte elements, which is
    # more than the 23-byte budget passed in. Confirm this expectation against
    # the behavior of _calculate_optimal_chunksize.
    chunks = _calculate_optimal_chunksize(volume, 23)
    assert chunks == (7, 5)
0 commit comments