@@ -132,16 +132,44 @@ def get_axis_dim(self, axis: _XGRID_AXES) -> int:
132132 def localize (self , position : dict [_XGRID_AXES , tuple [int , float ]], dims : list [str ]) -> dict [str , tuple [int , float ]]:
133133 """
134134 Uses the grid context (i.e., the staggering of the grid) to convert a position relative
135- to the F-points in the grid to a position relative to the dimensions of the array
136- of interest.
135+ to the F-points in the grid to a position relative to the staggered grid the array
136+ of interest is defined on.
137+
138+ Uses dimensions of the DataArray to determine the staggered grid.
139+
140+ Parameters
141+ ----------
142+ position : dict
143+ A mapping of the axis to a tuple of (index, barycentric coordinate) for the
144+ F-points in the grid.
145+ dims : list[str]
146+ A list of dimension names that the DataArray is defined on. This is used to determine
147+ the staggering of the grid and which axis each dimension corresponds to.
148+
149+ Returns
150+ -------
151+ dict[str, tuple[int, float]]
152+ A mapping of the dimension names to a tuple of (index, barycentric coordinate) for
153+ the staggered grid the DataArray is defined on.
154+
155+ Example
156+ -------
157+ >>> position = {'X': (5, 0.51), 'Y': (
158+ 10, 0.25), 'Z': (3, 0.75)}
159+ >>> dims = ['time', 'depth', 'YC', 'XC']
160+ >>> grid.localize(position, dims)
161+ {'depth': (3, 0.75), 'YC': (9, 0.75), 'XC': (5, 0.01)}
137162 """
138163 axis_to_var = {get_axis_from_dim_name (self .xgcm_grid .axes , dim ): dim for dim in dims }
139164 var_positions = {
140165 axis : get_xgcm_position_from_dim_name (self .xgcm_grid .axes , dim ) for axis , dim in axis_to_var .items ()
141166 }
142167 return {
143168 axis_to_var [axis ]: _convert_center_pos_to_fpoint (
144- index = index , bcoord = bcoord , position = var_positions [axis ], f_points_position = self ._fpoint_info [axis ]
169+ index = index ,
170+ bcoord = bcoord ,
171+ xgcm_position = var_positions [axis ],
172+ f_points_xgcm_position = self ._fpoint_info [axis ],
145173 )
146174 for axis , (index , bcoord ) in position .items ()
147175 }
@@ -204,12 +232,13 @@ def search(self, z, y, x, ei=None):
204232
205233 @cached_property
206234 def _fpoint_info (self ):
235+ """Returns a mapping of the spatial axes in the Grid to their XGCM positions."""
207236 xgcm_axes = self .xgcm_grid .axes
208237 f_point_positions = ["left" , "right" , "inner" , "outer" ]
209238 axis_position_mapping = {}
210239 for axis in self .axes :
211240 coords = xgcm_axes [axis ].coords
212- edge_positions = list ( filter ( lambda x : x in f_point_positions , coords .keys ()))
241+ edge_positions = [ pos for pos in coords .keys () if pos in f_point_positions ]
213242 assert len (edge_positions ) == 1 , f"Axis { axis } has multiple edge positions: { edge_positions } "
214243 axis_position_mapping [axis ] = edge_positions [0 ]
215244
@@ -370,10 +399,16 @@ def _search_1d_array(
370399
371400
372401def _convert_center_pos_to_fpoint (
373- * , index : int , bcoord : float , position : _XGCM_AXIS_POSITION , f_points_position : _XGCM_AXIS_POSITION
402+ * , index : int , bcoord : float , xgcm_position : _XGCM_AXIS_POSITION , f_points_xgcm_position : _XGCM_AXIS_POSITION
374403) -> tuple [int , float ]:
375- """Converts a position relative to the center point along an axis to a reposition relative to the cell edges."""
376- if position != "center" :
404+ """Converts a physical position relative to the cell edges defined in the grid to be relative to the center point.
405+
406+ This is used to "localize" a position to be relative to the staggered grid at which the field is defined, so that
407+ it can be easily interpolated.
408+
409+ This also handles different model input cell edges and centers are staggered in different directions (e.g., with NEMO and MITgcm).
410+ """
411+ if xgcm_position != "center" : # Data is already defined on the F points
377412 return index , bcoord
378413
379414 bcoord = bcoord - 0.5
@@ -382,7 +417,7 @@ def _convert_center_pos_to_fpoint(
382417 index -= 1
383418
384419 # Correct relative to the f-point position
385- if f_points_position in ["inner" , "right" ]:
420+ if f_points_xgcm_position in ["inner" , "right" ]:
386421 index += 1
387422
388423 return index , bcoord
0 commit comments