@@ -74,39 +74,36 @@ def get_grid_plan( # noqa: C901, PLR0913
7474 template = template ,
7575 )
7676
77- # Determine which dimensions are non-binned (converted to coordinates)
77+ # After grid overrides, determine final spatial dimensions and their chunk sizes
7878 non_binned_dims = set ()
7979 if "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides :
8080 non_binned_dims = set (grid_overrides ["non_binned_dims" ])
8181
82- # Use the spatial dimension names from horizontal_coordinates (which may have been modified by grid overrides)
83- # Extract only the dimension names (not including non-dimension coordinates or non-binned dimensions)
84- # After grid overrides, trace might have been added to horizontal_coordinates
85- # Compute transformed spatial dims: drop non-binned dims, insert trace if present in headers
86- transformed_spatial_dims = []
82+ # Create mapping from dimension name to original chunk size for easy lookup
83+ original_spatial_dims = list (template .spatial_dimension_names )
84+ original_chunks = list (template .full_chunk_shape [:- 1 ]) # Exclude vertical (sample/time) dimension
85+ dim_to_chunk = dict (zip (original_spatial_dims , original_chunks , strict = True ))
86+
87+ # Final spatial dimensions: keep trace and original dims, exclude non-binned dims
88+ final_spatial_dims = []
89+ final_spatial_chunks = []
8790 for name in horizontal_coordinates :
8891 if name in non_binned_dims :
89- continue
90- if name == "trace" or name in horizontal_dimensions :
91- transformed_spatial_dims .append (name )
92-
93- # Recompute chunksize to match transformed dims
94- original_spatial_dims = list (template .spatial_dimension_names )
95- original_chunks = list (template .full_chunk_shape )
96- new_spatial_chunks : list [int ] = []
97- # Insert trace chunk size at N-1 when present, otherwise map remaining dims
98- for dim_name in transformed_spatial_dims :
99- if dim_name == "trace" :
92+ continue # Skip dimensions that became coordinates
93+ if name == "trace" :
94+ # Special handling for trace dimension
10095 chunk_val = int (grid_overrides .get ("chunksize" , 1 )) if "NonBinned" in grid_overrides else 1
101- new_spatial_chunks .append (chunk_val )
102- continue
103- if dim_name in original_spatial_dims :
104- idx = original_spatial_dims .index (dim_name )
105- new_spatial_chunks .append (original_chunks [idx ])
106- chunksize = tuple (new_spatial_chunks + [original_chunks [- 1 ]])
96+ final_spatial_dims .append (name )
97+ final_spatial_chunks .append (chunk_val )
98+ elif name in dim_to_chunk :
99+ # Use original chunk size for known dimensions
100+ final_spatial_dims .append (name )
101+ final_spatial_chunks .append (dim_to_chunk [name ])
102+
103+ chunksize = tuple (final_spatial_chunks + [template .full_chunk_shape [- 1 ]])
107104
108105 dimensions = []
109- for dim_name in transformed_spatial_dims :
106+ for dim_name in final_spatial_dims :
110107 if dim_name not in headers_subset .dtype .names :
111108 continue
112109 dim_unique = np .unique (headers_subset [dim_name ])
0 commit comments