@@ -94,64 +94,85 @@ def _slice_face_indices(
9494 Indicates whether to store the original grids indices. Passing `True` stores the original face centers,
9595 other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True)
9696 """
97+ from uxarray .grid import Grid
98+
9799 if inclusive is False :
98100 raise ValueError ("Exclusive slicing is not yet supported." )
99101
100- from uxarray .grid import Grid
101-
102102 ds = grid ._ds
103-
104- indices = np .asarray (indices , dtype = INT_DTYPE )
105-
106- if indices .ndim == 0 :
107- indices = np .expand_dims (indices , axis = 0 )
108-
109- face_indices = indices
103+ face_indices = np .atleast_1d (np .asarray (indices , dtype = INT_DTYPE ))
110104
111105 # nodes of each face (inclusive)
112106 node_indices = np .unique (grid .face_node_connectivity .values [face_indices ].ravel ())
113107 node_indices = node_indices [node_indices != INT_FILL_VALUE ]
114108
115- # index original dataset to obtain a 'subgrid'
109+ # Index Node and Face variables
116110 ds = ds .isel (n_node = node_indices )
117111 ds = ds .isel (n_face = face_indices )
118112
119- # Only slice edge dimension if we already have the connectivity
120- if "face_edge_connectivity" in grid . _ds :
113+ # Only slice edge dimension if we have the face edge connectivity
114+ if "face_edge_connectivity" in ds :
121115 edge_indices = np .unique (
122116 grid .face_edge_connectivity .values [face_indices ].ravel ()
123117 )
124118 edge_indices = edge_indices [edge_indices != INT_FILL_VALUE ]
125119 ds = ds .isel (n_edge = edge_indices )
126120 ds ["subgrid_edge_indices" ] = xr .DataArray (edge_indices , dims = ["n_edge" ])
121+ # Otherwise, drop any edge variables
127122 else :
123+ if "n_edge" in ds .dims :
124+ ds = ds .drop_dims (["n_edge" ])
128125 edge_indices = None
129126
130127 ds ["subgrid_node_indices" ] = xr .DataArray (node_indices , dims = ["n_node" ])
131128 ds ["subgrid_face_indices" ] = xr .DataArray (face_indices , dims = ["n_face" ])
132129
133- # mapping to update existing connectivity
134- node_indices_dict = {
135- key : val for key , val in zip (node_indices , np .arange (0 , len (node_indices )))
136- }
130+ # Construct updated Node Index Map
131+ node_indices_dict = {orig : new for new , orig in enumerate (node_indices )}
137132 node_indices_dict [INT_FILL_VALUE ] = INT_FILL_VALUE
138133
139- for conn_name in grid ._ds .data_vars :
140- # update or drop connectivity variables to correctly point to the new index of each element
134+ # Construct updated Edge Index Map
135+ if edge_indices is not None :
136+ edge_indices_dict = {orig : new for new , orig in enumerate (edge_indices )}
137+ edge_indices_dict [INT_FILL_VALUE ] = INT_FILL_VALUE
138+ else :
139+ edge_indices_dict = None
140+
141+ def map_node_indices (i ):
142+ return node_indices_dict .get (i , INT_FILL_VALUE )
143+
144+ if edge_indices is not None :
145+
146+ def map_edge_indices (i ):
147+ return edge_indices_dict .get (i , INT_FILL_VALUE )
148+ else :
149+ map_edge_indices = None
150+
151+ for conn_name in list (ds .data_vars ):
152+ if conn_name .endswith ("_node_connectivity" ):
153+ map_fn = map_node_indices
141154
142- if "_node_connectivity" in conn_name :
143- # update connectivity vars that index into nodes
144- ds [conn_name ] = xr .DataArray (
145- np .vectorize (node_indices_dict .__getitem__ , otypes = [INT_DTYPE ])(
146- ds [conn_name ].values
147- ),
148- dims = ds [conn_name ].dims ,
149- attrs = ds [conn_name ].attrs ,
150- )
155+ elif conn_name .endswith ("_edge_connectivity" ):
156+ if edge_indices_dict is None :
157+ ds = ds .drop_vars (conn_name )
158+ continue
159+ map_fn = map_edge_indices
151160
152161 elif "_connectivity" in conn_name :
153- # drop any conn that would require re-computation
162+ # anything else we can't remap
154163 ds = ds .drop_vars (conn_name )
164+ continue
165+
166+ else :
167+ # not a connectivity var, skip
168+ continue
169+
170+ # Apply Remapping
171+ ds [conn_name ] = xr .DataArray (
172+ np .vectorize (map_fn , otypes = [INT_DTYPE ])(ds [conn_name ].values ),
173+ dims = ds [conn_name ].dims ,
174+ attrs = ds [conn_name ].attrs ,
175+ )
155176
156177 if inverse_indices :
157178 inverse_indices_ds = xr .Dataset ()
0 commit comments