Skip to content

Commit bce15ce

Browse files
committed
apply the reversal with xr commands when reading
1 parent 26a2ace commit bce15ce

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

yt_xarray/accessor/_readers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,15 @@ def _reader(grid, field_name):
8181
# load into memory (if its not) as xr DataArray
8282
datavals = datavals.load()
8383

84+
# reverse axis ordering if needed
85+
for axname in sel_info.reverse_axis_names:
86+
dimvals = getattr(datavals, axname)
87+
datavals = datavals.sel({axname: dimvals[::-1]})
88+
8489
if interp_required:
8590
# interpolate from nodes to cell centers across all remaining dims
8691
datavals = _xr_to_yt._interpolate_to_cell_centers(datavals)
8792

88-
# final flips to account for all the index reversing
89-
for idim in range(sel_info.ndims):
90-
if sel_info.reverse_axis[idim]:
91-
datavals = np.flip(datavals, axis=idim)
92-
9393
# return the plain values
9494
vals = datavals.values.astype(np.float64)
9595
if sel_info.ndims == 2:

yt_xarray/accessor/_xr_to_yt.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def _process_selection(self, xr_ds):
132132
cell_widths = [] # cell widths after selection
133133
grid_type = _GridType.UNIFORM # start with uniform assumption
134134
reverse_axis = [] # axes must be positive-monitonic for yt
135+
reverse_axis_names = []
135136
global_dims = [] # the global shape
136137
for c in full_coords:
137138
coord_da = getattr(xr_ds, c) # the full coordinate data array
@@ -140,6 +141,8 @@ def _process_selection(self, xr_ds):
140141
if coord_da.size > 1:
141142
rev_ax = coord_da[1] <= coord_da[0]
142143
reverse_axis.append(bool(rev_ax.values))
144+
if rev_ax:
145+
reverse_axis_names.append(c)
143146

144147
# store the global ranges
145148
global_dims.append(coord_da.size)
@@ -204,6 +207,7 @@ def _process_selection(self, xr_ds):
204207
self.grid_type = grid_type
205208
self.cell_widths = cell_widths
206209
self.reverse_axis = reverse_axis
210+
self.reverse_axis_names = reverse_axis_names
207211
self.global_dims = np.array(global_dims)
208212
# self.coord_selected_arrays = coord_selected_arrays
209213

@@ -250,9 +254,15 @@ def _validate_fields(self, xr_ds, fields: List[str]) -> List[str]:
250254

251255
def select_from_xr(self, xr_ds, field):
252256
if self.sel_dict_type == "isel":
253-
return xr_ds[field].isel(self.sel_dict)
257+
vars = xr_ds[field].isel(self.sel_dict)
254258
else:
255-
return xr_ds[field].sel(self.sel_dict)
259+
vars = xr_ds[field].sel(self.sel_dict)
260+
261+
for axname in self.reverse_axis_names:
262+
dimvals = getattr(vars, axname)
263+
vars = vars.sel({axname: dimvals[::-1]})
264+
265+
return vars
256266

257267
def interp_validation(self, geometry):
258268
# checks if yt will need to interpolate to cell center
@@ -439,12 +449,6 @@ def _load_full_field_from_xr(
439449

440450
if interp_required:
441451
vals = _interpolate_to_cell_centers(vals)
442-
if any(sel_info.reverse_axis):
443-
# if any dims are in decreaseing order, flip that axis
444-
# after reading in the data
445-
for idim, flip_it in enumerate(sel_info.reverse_axis):
446-
if flip_it:
447-
vals = np.flip(vals, axis=idim)
448452

449453
vals = vals.values.astype(np.float64)
450454
if sel_info.ndims == 2:

0 commit comments

Comments
 (0)