Skip to content

Commit de5f482

Browse files
committed
enable single stretched grid + callable
1 parent aac8e18 commit de5f482

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

yt_xarray/accessor/accessor.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,6 @@ def load_grid(
7272
sel_dict=sel_dict,
7373
sel_dict_type=sel_dict_type,
7474
)
75-
if sel_info.grid_type == _xr_to_yt._GridType.STRETCHED and use_callable:
76-
# why not? this should work now, shouldnt it?
77-
raise NotImplementedError(
78-
"Detected a stretched grid, which is not yet supported for callables, "
79-
"set use_callable=False."
80-
)
8175

8276
if geometry is None:
8377
geometry = self.geometry

yt_xarray/tests/test_accesor.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ def test_geom_kwarg(ds_xr):
189189
_ = ds_xr.yt.load_grid(fields=flds, geometry="cartesian")
190190

191191

192-
def test_stretched_grid():
192+
@pytest.mark.parametrize("use_callable", [True, False])
193+
def test_stretched_grid(use_callable):
193194
ds = construct_minimal_ds(
194195
x_stretched=False,
195196
x_name="x",
@@ -199,16 +200,24 @@ def test_stretched_grid():
199200
z_name="z",
200201
)
201202

202-
with pytest.raises(NotImplementedError, match="Detected a stretched grid"):
203-
_ = ds.yt.load_grid(
204-
fields=[
205-
"test_field",
206-
]
207-
)
208-
209-
_ = ds.yt.load_grid(
203+
ds_yt = ds.yt.load_grid(
210204
fields=[
211205
"test_field",
212206
],
213-
use_callable=False,
207+
use_callable=use_callable,
214208
)
209+
210+
# stretched grid will interpolate using raw values as nodal values, so
211+
# end up with n - 1 cells in each dimension
212+
expected_n = np.prod([n - 1 for n in ds.test_field.shape])
213+
ad = ds_yt.all_data()
214+
215+
tst_fld = ad[("stream", "test_field")]
216+
assert tst_fld.size == expected_n
217+
218+
# check that the grid vals match
219+
for dim in "xyz":
220+
dims = ds.coords[dim].values
221+
cell_centers = (dims[:-1] + dims[1:]) / 2
222+
dimvals = np.unique(ad[("index", dim)].d)
223+
assert np.all(dimvals == cell_centers)

0 commit comments

Comments
 (0)