Skip to content

Commit a27ee62

Browse files
committed
handling revesred dimension arrays
1 parent 099433e commit a27ee62

File tree

4 files changed

+105
-12
lines changed

4 files changed

+105
-12
lines changed

tests/test_xr_to_yt.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,3 +479,53 @@ def test_add_3rd_axis_name(yt_geom):
479479

480480
with pytest.raises(ValueError, match="Unsupported geometry type"):
481481
_ = xr2yt._add_3rd_axis_name("bad_geometry", expected[:-1])
482+
483+
484+
@pytest.mark.parametrize(
485+
"stretched,use_callable,chunksizes",
486+
[
487+
(True, False, None),
488+
(False, False, None),
489+
(False, True, None),
490+
(False, True, 10),
491+
(False, False, 10),
492+
],
493+
)
494+
def test_reversed_axis(stretched, use_callable, chunksizes):
495+
# tests for when the incoming data is not positive-monotonic
496+
497+
ds = construct_minimal_ds(
498+
min_x=1,
499+
max_x=359,
500+
min_z=50,
501+
max_z=650,
502+
min_y=89,
503+
max_y=-89,
504+
n_x=50,
505+
n_y=100,
506+
n_z=30,
507+
z_stretched=stretched,
508+
)
509+
yt_ds = ds.yt.load_grid(use_callable=use_callable, chunksizes=chunksizes)
510+
511+
if stretched:
512+
grid_obj = yt_ds.index.grids[0]
513+
ax_id = yt_ds.coordinates.axis_id["latitude"]
514+
assert np.all(grid_obj.cell_widths[ax_id] > 0)
515+
516+
slc = yt_ds.slice(
517+
yt_ds.coordinates.axis_id["depth"],
518+
yt_ds.domain_center[yt_ds.coordinates.axis_id["depth"]],
519+
center=yt_ds.domain_center,
520+
)
521+
pdy_lats = slc._generate_container_field("pdy")
522+
assert np.all(pdy_lats > 0)
523+
524+
vals = yt_ds.coordinates.pixelize(
525+
0,
526+
slc,
527+
("stream", "test_field"),
528+
yt_ds.arr([1, 359, -89, 89], "code_length"),
529+
(400, 400),
530+
)
531+
assert np.all(np.isfinite(vals))

yt_xarray/accessor/_readers.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def _reader(grid, field_name):
2727

2828
si = grid.get_global_startindex() + gsi
2929
ei = si + grid.ActiveDimensions
30+
global_dims = sel_info.global_dims.copy()
3031
if interp_required:
3132
# if interpolating, si and ei must be node indices so
3233
# we offset by an additional element
@@ -37,7 +38,22 @@ def _reader(grid, field_name):
3738
c_list = sel_info.selected_coords # the xarray coord names
3839
i_select_dict = {}
3940
for idim in range(sel_info.ndims):
40-
i_select_dict[c_list[idim]] = slice(si[idim], ei[idim])
41+
if sel_info.reverse_axis[idim]:
42+
# the xarray axis is in negative ordering. a yt index of
43+
# 0 should point to the maximum index in xarray
44+
si_idim = global_dims[idim] - si[idim]
45+
ei_idim = global_dims[idim] - ei[idim]
46+
# when reverse slicing, the final index will not be included...
47+
# if the end index is within bounds, just bump it by one more,
48+
# otherwise if end index is already 0, just pass in None to
49+
# slice
50+
if ei_idim > 0:
51+
ei_idim = ei_idim - 1
52+
elif ei_idim == 0:
53+
ei_idim = None
54+
i_select_dict[c_list[idim]] = slice(si_idim, ei_idim, -1)
55+
else:
56+
i_select_dict[c_list[idim]] = slice(si[idim], ei[idim])
4157

4258
# set any of the initial selections that will reduce the
4359
# dimensionality or size of the full DataArray

yt_xarray/accessor/_xr_to_yt.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
self.ndims: int = None
5555
self.grid_type = None # one of _GridType members
5656
self.cell_widths: list = None
57+
self.global_dims: list = None
5758
self._process_selection(xr_ds)
5859

5960
self.yt_coord_names = _convert_to_yt_internal_coords(self.selected_coords)
@@ -130,9 +131,13 @@ def _process_selection(self, xr_ds):
130131
starting_indices = [] # global starting index
131132
cell_widths = [] # cell widths after selection
132133
grid_type = _GridType.UNIFORM # start with uniform assumption
134+
reverse_axis = [] # axes must be positive-monitonic for yt
135+
global_dims = []
133136
for c in full_coords:
134137
coord_da = getattr(xr_ds, c) # the full coordinate data array
135-
138+
rev_ax = coord_da[1] <= coord_da[0]
139+
reverse_axis.append(bool(rev_ax.values))
140+
global_dims.append(coord_da.size)
136141
# store the global ranges
137142
global_min = float(coord_da.min().values)
138143
global_max = float(coord_da.max().values)
@@ -146,6 +151,8 @@ def _process_selection(self, xr_ds):
146151

147152
sel_or_isel = getattr(coord_da, self.sel_dict_type)
148153
coord_vals = sel_or_isel(coord_select).values.astype(np.float64)
154+
if reverse_axis[-1]:
155+
coord_vals = coord_vals[::-1]
149156
is_time_dim = _check_for_time(c, coord_vals)
150157

151158
if coord_vals.size > 1:
@@ -188,6 +195,8 @@ def _process_selection(self, xr_ds):
188195
self.selected_time = time
189196
self.grid_type = grid_type
190197
self.cell_widths = cell_widths
198+
self.reverse_axis = reverse_axis
199+
self.global_dims = np.array(global_dims)
191200
# self.coord_selected_arrays = coord_selected_arrays
192201

193202
# set the yt grid dictionary

yt_xarray/accessor/accessor.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@ def _load_single_grid(
247247
if interp_required:
248248
vals = _xr_to_yt._interpolate_to_cell_centers(vals)
249249
vals = vals.values.astype(np.float64)
250+
for ax, rev_ax in enumerate(sel_info.reverse_axis):
251+
if rev_ax:
252+
vals = np.flip(vals, axis=ax)
250253
if sel_info.ndims == 2:
251254
vals = np.expand_dims(vals, axis=-1)
252255
data[field] = (vals, units)
@@ -363,15 +366,28 @@ def _load_chunked_grid(
363366
)
364367

365368
c = cnames[idim]
366-
le_0 = ds_xr[fld].coords[c].isel({c: si_0}).values
367-
if interp_required is False:
368-
# the left edges get bumped left since we are reading values
369-
# again.
370-
le_0 = le_0 - dxyz[idim] / 2.0
369+
rev_ax = sel_info.reverse_axis[idim]
370+
if rev_ax is False:
371371

372-
# bbox value below already accounts for interp_required, no need to shift
373-
max_val = bbox[idim, 1]
374-
re_0 = np.concatenate([le_0[1:], [max_val]])
372+
le_0 = ds_xr[fld].coords[c].isel({c: si_0}).values
373+
374+
if interp_required is False:
375+
# the left edges get bumped left since we are reading values
376+
# again.
377+
le_0 = le_0 - dxyz[idim] / 2.0
378+
379+
# bbox value below already accounts for interp_required, no need to shift
380+
max_val = bbox[idim, 1]
381+
re_0 = np.concatenate([le_0[1:], [max_val]])
382+
383+
else:
384+
re_0 = ds_xr[fld].coords[c].isel({c: si_0[::-1]}).values
385+
if interp_required is False:
386+
# the left edges get bumped left since we are reading values
387+
# again.
388+
re_0 = re_0 - dxyz[idim] / 2.0
389+
min_val = bbox[idim, 0]
390+
le_0 = np.concatenate([[min_val], re_0[:-1]])
375391

376392
# sizes also already account for interp_required
377393
subgrid_size = ei_0 - si_0
@@ -411,6 +427,10 @@ def _load_chunked_grid(
411427
vals = sel_info.select_from_xr(ds_xr, field).load()
412428
if interp_required:
413429
vals = _xr_to_yt._interpolate_to_cell_centers(vals)
430+
if any(sel_info.reverse_axis):
431+
for idim, flip_it in enumerate(sel_info.reverse_axis):
432+
if flip_it:
433+
vals = np.flip(vals, axis=idim)
414434
full_field_vals[field] = vals.values.astype(np.float64)
415435

416436
for igrid in range(n_grids):
@@ -425,10 +445,8 @@ def _load_chunked_grid(
425445
if use_callable:
426446
gdict[field] = (reader, units)
427447
else:
428-
# NO these values need to be chunked too.
429448
si = subgrid_start[igrid]
430449
ei = subgrid_end[igrid]
431-
# this needs to be fixed for two2 fields...
432450
gridvals = full_field_vals[field][
433451
si[0] : ei[0], si[1] : ei[1], si[2] : ei[2]
434452
]

0 commit comments

Comments
 (0)