Skip to content

Commit ac525d6

Browse files
philipc2jpan84
andauthored
Update slicing to only update edges when possible (#1256)
* created docstring for azimuthal_mean * wrote azimuthal mean computation in public function * draft azimuthal mean ready to test * fixed typos, return hit count for radial bins * made azimuthal mean more robust to axis ordering * added central coord as attribute in output UxDataArray of azimuthal_mean * run pre-commit * update parameters and set default values to nan * fix edge slicing issue * remove code from other brahnch * fix slicing on edge arrays * use the correct dict --------- Co-authored-by: Joshua Pan <[email protected]>
1 parent e4ab53b commit ac525d6

File tree

2 files changed

+71
-28
lines changed

2 files changed

+71
-28
lines changed

test/test_cross_sections.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,25 @@ def test_latitude_along_arc(self):
198198
out2 = uxgrid.get_faces_at_constant_latitude(lat=25.41)
199199

200200
nt.assert_array_equal(out1, out2)
201+
202+
203+
204+
def test_double_cross_section():
205+
uxgrid = ux.open_grid(quad_hex_grid_path)
206+
207+
# construct edges
208+
sub_lat = uxgrid.cross_section.constant_latitude(0.0)
209+
210+
sub_lat_lon = sub_lat.cross_section.constant_longitude(0.0)
211+
212+
assert "n_edge" not in sub_lat_lon._ds.dims
213+
214+
_ = uxgrid.face_edge_connectivity
215+
_ = uxgrid.edge_node_connectivity
216+
_ = uxgrid.edge_lon
217+
218+
sub_lat = uxgrid.cross_section.constant_latitude(0.0)
219+
220+
sub_lat_lon = sub_lat.cross_section.constant_longitude(0.0)
221+
222+
assert "n_edge" in sub_lat_lon._ds.dims

uxarray/grid/slice.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -94,64 +94,85 @@ def _slice_face_indices(
9494
Indicates whether to store the original grids indices. Passing `True` stores the original face centers,
9595
other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True)
9696
"""
97+
from uxarray.grid import Grid
98+
9799
if inclusive is False:
98100
raise ValueError("Exclusive slicing is not yet supported.")
99101

100-
from uxarray.grid import Grid
101-
102102
ds = grid._ds
103-
104-
indices = np.asarray(indices, dtype=INT_DTYPE)
105-
106-
if indices.ndim == 0:
107-
indices = np.expand_dims(indices, axis=0)
108-
109-
face_indices = indices
103+
face_indices = np.atleast_1d(np.asarray(indices, dtype=INT_DTYPE))
110104

111105
# nodes of each face (inclusive)
112106
node_indices = np.unique(grid.face_node_connectivity.values[face_indices].ravel())
113107
node_indices = node_indices[node_indices != INT_FILL_VALUE]
114108

115-
# index original dataset to obtain a 'subgrid'
109+
# Index Node and Face variables
116110
ds = ds.isel(n_node=node_indices)
117111
ds = ds.isel(n_face=face_indices)
118112

119-
# Only slice edge dimension if we already have the connectivity
120-
if "face_edge_connectivity" in grid._ds:
113+
# Only slice edge dimension if we have the face edge connectivity
114+
if "face_edge_connectivity" in ds:
121115
edge_indices = np.unique(
122116
grid.face_edge_connectivity.values[face_indices].ravel()
123117
)
124118
edge_indices = edge_indices[edge_indices != INT_FILL_VALUE]
125119
ds = ds.isel(n_edge=edge_indices)
126120
ds["subgrid_edge_indices"] = xr.DataArray(edge_indices, dims=["n_edge"])
121+
# Otherwise, drop any edge variables
127122
else:
123+
if "n_edge" in ds.dims:
124+
ds = ds.drop_dims(["n_edge"])
128125
edge_indices = None
129126

130127
ds["subgrid_node_indices"] = xr.DataArray(node_indices, dims=["n_node"])
131128
ds["subgrid_face_indices"] = xr.DataArray(face_indices, dims=["n_face"])
132129

133-
# mapping to update existing connectivity
134-
node_indices_dict = {
135-
key: val for key, val in zip(node_indices, np.arange(0, len(node_indices)))
136-
}
130+
# Construct updated Node Index Map
131+
node_indices_dict = {orig: new for new, orig in enumerate(node_indices)}
137132
node_indices_dict[INT_FILL_VALUE] = INT_FILL_VALUE
138133

139-
for conn_name in grid._ds.data_vars:
140-
# update or drop connectivity variables to correctly point to the new index of each element
134+
# Construct updated Edge Index Map
135+
if edge_indices is not None:
136+
edge_indices_dict = {orig: new for new, orig in enumerate(edge_indices)}
137+
edge_indices_dict[INT_FILL_VALUE] = INT_FILL_VALUE
138+
else:
139+
edge_indices_dict = None
140+
141+
def map_node_indices(i):
142+
return node_indices_dict.get(i, INT_FILL_VALUE)
143+
144+
if edge_indices is not None:
145+
146+
def map_edge_indices(i):
147+
return edge_indices_dict.get(i, INT_FILL_VALUE)
148+
else:
149+
map_edge_indices = None
150+
151+
for conn_name in list(ds.data_vars):
152+
if conn_name.endswith("_node_connectivity"):
153+
map_fn = map_node_indices
141154

142-
if "_node_connectivity" in conn_name:
143-
# update connectivity vars that index into nodes
144-
ds[conn_name] = xr.DataArray(
145-
np.vectorize(node_indices_dict.__getitem__, otypes=[INT_DTYPE])(
146-
ds[conn_name].values
147-
),
148-
dims=ds[conn_name].dims,
149-
attrs=ds[conn_name].attrs,
150-
)
155+
elif conn_name.endswith("_edge_connectivity"):
156+
if edge_indices_dict is None:
157+
ds = ds.drop_vars(conn_name)
158+
continue
159+
map_fn = map_edge_indices
151160

152161
elif "_connectivity" in conn_name:
153-
# drop any conn that would require re-computation
162+
# anything else we can't remap
154163
ds = ds.drop_vars(conn_name)
164+
continue
165+
166+
else:
167+
# not a connectivity var, skip
168+
continue
169+
170+
# Apply Remapping
171+
ds[conn_name] = xr.DataArray(
172+
np.vectorize(map_fn, otypes=[INT_DTYPE])(ds[conn_name].values),
173+
dims=ds[conn_name].dims,
174+
attrs=ds[conn_name].attrs,
175+
)
155176

156177
if inverse_indices:
157178
inverse_indices_ds = xr.Dataset()

0 commit comments

Comments
 (0)