Skip to content

Commit 9e1f6a3

Browse files
committed
refactor the index reversing
1 parent a27ee62 commit 9e1f6a3

File tree

5 files changed

+120
-80
lines changed

5 files changed

+120
-80
lines changed

tests/test_xr_to_yt.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -481,19 +481,23 @@ def test_add_3rd_axis_name(yt_geom):
481481
_ = xr2yt._add_3rd_axis_name("bad_geometry", expected[:-1])
482482

483483

484-
@pytest.mark.parametrize(
485-
"stretched,use_callable,chunksizes",
486-
[
487-
(True, False, None),
488-
(False, False, None),
489-
(False, True, None),
490-
(False, True, 10),
491-
(False, False, 10),
492-
],
493-
)
494-
def test_reversed_axis(stretched, use_callable, chunksizes):
495-
# tests for when the incoming data is not positive-monotonic
484+
def _get_pixelized_slice(yt_ds):
485+
slc = yt_ds.slice(
486+
yt_ds.coordinates.axis_id["depth"],
487+
yt_ds.domain_center[yt_ds.coordinates.axis_id["depth"]],
488+
center=yt_ds.domain_center,
489+
)
490+
vals = yt_ds.coordinates.pixelize(
491+
0,
492+
slc,
493+
("stream", "test_field"),
494+
yt_ds.arr([1, 359, -89, 89], "code_length"),
495+
(400, 400),
496+
)
497+
return slc, vals
496498

499+
500+
def _get_ds_for_reverse_tests(stretched, use_callable, chunksizes):
497501
ds = construct_minimal_ds(
498502
min_x=1,
499503
max_x=359,
@@ -505,27 +509,34 @@ def test_reversed_axis(stretched, use_callable, chunksizes):
505509
n_y=100,
506510
n_z=30,
507511
z_stretched=stretched,
512+
npseed=True,
508513
)
509514
yt_ds = ds.yt.load_grid(use_callable=use_callable, chunksizes=chunksizes)
515+
return yt_ds
516+
517+
518+
@pytest.mark.parametrize(
519+
"stretched,use_callable,chunksizes",
520+
[
521+
(True, False, None),
522+
(False, False, None),
523+
(False, True, None),
524+
(False, True, 20),
525+
(False, False, 20),
526+
],
527+
)
528+
def test_reversed_axis(stretched, use_callable, chunksizes):
529+
# tests for when the incoming data is not positive-monotonic
530+
531+
yt_ds = _get_ds_for_reverse_tests(stretched, use_callable, chunksizes)
510532

511533
if stretched:
512534
grid_obj = yt_ds.index.grids[0]
513535
ax_id = yt_ds.coordinates.axis_id["latitude"]
514536
assert np.all(grid_obj.cell_widths[ax_id] > 0)
515537

516-
slc = yt_ds.slice(
517-
yt_ds.coordinates.axis_id["depth"],
518-
yt_ds.domain_center[yt_ds.coordinates.axis_id["depth"]],
519-
center=yt_ds.domain_center,
520-
)
538+
slc, vals = _get_pixelized_slice(yt_ds)
539+
521540
pdy_lats = slc._generate_container_field("pdy")
522541
assert np.all(pdy_lats > 0)
523-
524-
vals = yt_ds.coordinates.pixelize(
525-
0,
526-
slc,
527-
("stream", "test_field"),
528-
yt_ds.arr([1, 359, -89, 89], "code_length"),
529-
(400, 400),
530-
)
531542
assert np.all(np.isfinite(vals))

yt_xarray/accessor/_readers.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,42 +18,48 @@ def _reader(grid, field_name):
1818
# grid: a yt grid object
1919
_, fname = field_name
2020

21-
# this global start index accounts for indexing after any
21+
# first get the internal yt index ranges
22+
si = grid.get_global_startindex()
23+
ei = si + grid.ActiveDimensions
24+
25+
# convert to the xarray indexing in 3 steps
26+
# 1. account for dimension direction
27+
# 2. account for any prior xarray subselections
28+
# 3. account for interpolation requirements
29+
30+
# step 1 (if axis has been reversed, node ordering is also reversed)
31+
for idim in range(sel_info.ndims):
32+
if sel_info.reverse_axis[idim]:
33+
# note that the si, ei are exchanged!
34+
si0 = si.copy()
35+
ei0 = ei.copy()
36+
si[idim] = sel_info.global_dims[idim] - ei0[idim]
37+
ei[idim] = sel_info.global_dims[idim] - si0[idim]
38+
39+
# step 2: this global start index accounts for indexing after any
2240
# subselections on the xarray DataArray are made. might
2341
# be a way to properly set this with yt grid objects.
2442
gsi = sel_info.starting_indices
2543
if sel_info.ndims == 2:
2644
gsi = np.append(gsi, 0)
45+
si = si + gsi
46+
ei = ei + gsi
2747

28-
si = grid.get_global_startindex() + gsi
29-
ei = si + grid.ActiveDimensions
30-
global_dims = sel_info.global_dims.copy()
48+
# step 3: if we are interpolating to cell centers, grab some extra nodes
3149
if interp_required:
32-
# if interpolating, si and ei must be node indices so
33-
# we offset by an additional element
34-
cell_to_node_offset = np.ones((3,), dtype=int)
35-
ei = ei + cell_to_node_offset
50+
for idim in range(sel_info.ndims):
51+
if sel_info.reverse_axis[idim]:
52+
si[idim] = si[idim] - 1
53+
else:
54+
ei[idim] = ei[idim] + 1
55+
56+
# now we can select the data (still accounting for any dimension reversal)
3657

37-
# build the index-selector for our yt grid object
58+
# build the index-selector for each dimension
3859
c_list = sel_info.selected_coords # the xarray coord names
3960
i_select_dict = {}
4061
for idim in range(sel_info.ndims):
41-
if sel_info.reverse_axis[idim]:
42-
# the xarray axis is in negative ordering. a yt index of
43-
# 0 should point to the maximum index in xarray
44-
si_idim = global_dims[idim] - si[idim]
45-
ei_idim = global_dims[idim] - ei[idim]
46-
# when reverse slicing, the final index will not be included...
47-
# if the end index is within bounds, just bump it by one more,
48-
# otherwise if end index is already 0, just pass in None to
49-
# slice
50-
if ei_idim > 0:
51-
ei_idim = ei_idim - 1
52-
elif ei_idim == 0:
53-
ei_idim = None
54-
i_select_dict[c_list[idim]] = slice(si_idim, ei_idim, -1)
55-
else:
56-
i_select_dict[c_list[idim]] = slice(si[idim], ei[idim])
62+
i_select_dict[c_list[idim]] = slice(si[idim], ei[idim])
5763

5864
# set any of the initial selections that will reduce the
5965
# dimensionality or size of the full DataArray
@@ -79,6 +85,11 @@ def _reader(grid, field_name):
7985
# interpolate from nodes to cell centers across all remaining dims
8086
datavals = _xr_to_yt._interpolate_to_cell_centers(datavals)
8187

88+
# final flips to account for all the index reversing
89+
for idim in range(sel_info.ndims):
90+
if sel_info.reverse_axis[idim]:
91+
datavals = np.flip(datavals, axis=idim)
92+
8293
# return the plain values
8394
vals = datavals.values.astype(np.float64)
8495
if sel_info.ndims == 2:

yt_xarray/accessor/_xr_to_yt.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,36 +132,43 @@ def _process_selection(self, xr_ds):
132132
cell_widths = [] # cell widths after selection
133133
grid_type = _GridType.UNIFORM # start with uniform assumption
134134
reverse_axis = [] # axes must be positive-monotonic for yt
135-
global_dims = []
135+
global_dims = [] # the global shape
136136
for c in full_coords:
137137
coord_da = getattr(xr_ds, c) # the full coordinate data array
138+
139+
# check if coordinate values are increasing
138140
rev_ax = coord_da[1] <= coord_da[0]
139141
reverse_axis.append(bool(rev_ax.values))
140-
global_dims.append(coord_da.size)
142+
141143
# store the global ranges
144+
global_dims.append(coord_da.size)
142145
global_min = float(coord_da.min().values)
143146
global_max = float(coord_da.max().values)
144147
full_dimranges.append([global_min, global_max])
145148

146-
si = 0 # starting index
149+
si = 0 # starting xarray-index for any pre-selections from user
147150
coord_select = {} # the selection dictionary just for this coordinate
148151
if c in self.sel_dict:
149152
coord_select[c] = self.sel_dict[c]
150153
si = self._find_starting_index(c, coord_da, coord_select)
151154

155+
# apply any selections and extract coordinates
152156
sel_or_isel = getattr(coord_da, self.sel_dict_type)
153157
coord_vals = sel_or_isel(coord_select).values.astype(np.float64)
154-
if reverse_axis[-1]:
155-
coord_vals = coord_vals[::-1]
156158
is_time_dim = _check_for_time(c, coord_vals)
157159

158160
if coord_vals.size > 1:
161+
162+
# not positive-monotonic? reverse it for cell width calculations
163+
# changes to indexing are accounted for when extracting data.
164+
if reverse_axis[-1]:
165+
coord_vals = coord_vals[::-1]
166+
159167
cell_widths.append(coord_vals[1:] - coord_vals[:-1])
160168
dimranges.append([coord_vals.min(), coord_vals.max()])
161169
n_edges.append(coord_vals.size)
162170
n_cells.append(coord_vals.size - 1)
163171
coord_list.append(c)
164-
# coord_selected_arrays[c] = coord_vals
165172
starting_indices.append(si)
166173

167174
if is_time_dim:
@@ -422,3 +429,23 @@ def _interpolate_to_cell_centers(data: xr.DataArray):
422429
dimvals = data.coords[dim].values
423430
interp_dict[dim] = (dimvals[1:] + dimvals[:-1]) / 2.0
424431
return data.interp(interp_dict)
432+
433+
434+
def _load_full_field_from_xr(
435+
ds_xr, field: str, sel_info: Selection, interp_required: bool = False
436+
):
437+
vals = sel_info.select_from_xr(ds_xr, field).load()
438+
439+
if interp_required:
440+
vals = _interpolate_to_cell_centers(vals)
441+
if any(sel_info.reverse_axis):
442+
# if any dims are in decreasing order, flip that axis
443+
# after reading in the data
444+
for idim, flip_it in enumerate(sel_info.reverse_axis):
445+
if flip_it:
446+
vals = np.flip(vals, axis=idim)
447+
448+
vals = vals.values.astype(np.float64)
449+
if sel_info.ndims == 2:
450+
vals = np.expand_dims(vals, axis=-1)
451+
return vals

yt_xarray/accessor/accessor.py

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from yt_xarray.accessor import _xr_to_yt
1010
from yt_xarray.accessor._readers import _get_xarray_reader
11+
from yt_xarray.accessor._xr_to_yt import _load_full_field_from_xr
1112
from yt_xarray.utilities.logging import ytxr_log
1213

1314

@@ -243,15 +244,9 @@ def _load_single_grid(
243244
if use_callable:
244245
data[field] = (reader, units)
245246
else:
246-
vals = sel_info.select_from_xr(ds_xr, field).load()
247-
if interp_required:
248-
vals = _xr_to_yt._interpolate_to_cell_centers(vals)
249-
vals = vals.values.astype(np.float64)
250-
for ax, rev_ax in enumerate(sel_info.reverse_axis):
251-
if rev_ax:
252-
vals = np.flip(vals, axis=ax)
253-
if sel_info.ndims == 2:
254-
vals = np.expand_dims(vals, axis=-1)
247+
vals = _load_full_field_from_xr(
248+
ds_xr, field, sel_info, interp_required=interp_required
249+
)
255250
data[field] = (vals, units)
256251

257252
if sel_info.ndims == 2:
@@ -368,23 +363,19 @@ def _load_chunked_grid(
368363
c = cnames[idim]
369364
rev_ax = sel_info.reverse_axis[idim]
370365
if rev_ax is False:
371-
372366
le_0 = ds_xr[fld].coords[c].isel({c: si_0}).values
373-
374367
if interp_required is False:
375-
# the left edges get bumped left since we are reading values
376-
# again.
368+
# move the edges so the node is now a cell center
377369
le_0 = le_0 - dxyz[idim] / 2.0
378370

379-
# bbox value below already accounts for interp_required, no need to shift
371+
# bbox value below already accounts for interp_required
380372
max_val = bbox[idim, 1]
381373
re_0 = np.concatenate([le_0[1:], [max_val]])
382374

383375
else:
384376
re_0 = ds_xr[fld].coords[c].isel({c: si_0[::-1]}).values
385377
if interp_required is False:
386-
# the left edges get bumped left since we are reading values
387-
# again.
378+
# move the edges so the node is now a cell center
388379
re_0 = re_0 - dxyz[idim] / 2.0
389380
min_val = bbox[idim, 0]
390381
le_0 = np.concatenate([[min_val], re_0[:-1]])
@@ -424,14 +415,10 @@ def _load_chunked_grid(
424415
if use_callable is False:
425416
full_field_vals = {}
426417
for field in fields:
427-
vals = sel_info.select_from_xr(ds_xr, field).load()
428-
if interp_required:
429-
vals = _xr_to_yt._interpolate_to_cell_centers(vals)
430-
if any(sel_info.reverse_axis):
431-
for idim, flip_it in enumerate(sel_info.reverse_axis):
432-
if flip_it:
433-
vals = np.flip(vals, axis=idim)
434-
full_field_vals[field] = vals.values.astype(np.float64)
418+
vals = _load_full_field_from_xr(
419+
ds_xr, field, sel_info, interp_required=interp_required
420+
)
421+
full_field_vals[field] = vals
435422

436423
for igrid in range(n_grids):
437424
gdict = {

yt_xarray/utilities/_utilities.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def construct_minimal_ds(
2929
n_fields: int = 1,
3030
coord_order: Optional[Tuple[str, str, str]] = None,
3131
dtype: str = "float64",
32+
npseed: bool = False,
3233
) -> xr.Dataset:
3334

3435
if coord_order is None:
@@ -78,6 +79,9 @@ def construct_minimal_ds(
7879
coord_order_rn += (cdict[cname],)
7980
var_shape += (n,)
8081

82+
if npseed:
83+
np.random.seed(0)
84+
8185
vals = np.random.random(var_shape).astype(dtype_to_use)
8286
if n_fields > 1:
8387
data_vars = {}

0 commit comments

Comments
 (0)