11import numpy as np
2+ import pytest
3+ import yt
24
35import yt_xarray # noqa: F401
46from yt_xarray ._utilities import construct_minimal_ds
57
68
9+ @pytest .fixture ()
10+ def ds_xr ():
11+ # a base xarray ds to be used in various places.
12+ tfield = "a_new_field"
13+ n_x = 3
14+ n_y = 4
15+ n_z = 5
16+ ds = construct_minimal_ds (
17+ field_name = tfield ,
18+ n_fields = 3 ,
19+ n_x = n_x ,
20+ n_y = n_y ,
21+ n_z = n_z ,
22+ z_name = "depth" ,
23+ coord_order = ["z" , "y" , "x" ],
24+ )
25+ return ds
26+
27+
728def test_accessor ():
829
930 tfield = "a_new_field"
@@ -69,33 +90,23 @@ def test_bbox():
6990 # the test dataset.
7091
7192
72- def test_load_uniform_grid ():
93+ def test_load_uniform_grid (ds_xr ):
7394
74- tfield = "a_new_field"
75- n_x = 3
76- n_y = 4
77- n_z = 5
78- ds = construct_minimal_ds (
79- field_name = tfield ,
80- n_fields = 3 ,
81- n_x = n_x ,
82- n_y = n_y ,
83- n_z = n_z ,
84- z_name = "depth" ,
85- coord_order = ["z" , "y" , "x" ],
86- )
87-
88- flds = [tfield + "_0" , tfield + "_1" ]
89- ds_yt = ds .yt .load_uniform_grid (flds )
95+ flds = ["a_new_field_0" , "a_new_field_1" ]
96+ ds_yt = ds_xr .yt .load_uniform_grid (flds )
9097 assert ds_yt .coordinates .name == "internal_geographic"
9198 expected_field_list = [("stream" , f ) for f in flds ]
9299 assert all ([f in expected_field_list ] for f in ds_yt .field_list )
93100
94- ds_yt = ds .yt .ds # should generate a ds with all fields
95- flds = [tfield + "_0" , tfield + "_1" , tfield + "_2" ]
101+ ds_yt = ds_xr .yt .load_uniform_grid () # should generate a ds with all fields
102+ flds = flds + [
103+ "a_new_field_2" ,
104+ ]
96105 expected_field_list = [("stream" , f ) for f in flds ]
97106 assert all ([f in expected_field_list ] for f in ds_yt .field_list )
98107
108+ tfield = "nice_field"
109+ n_x , n_y , n_z = (7 , 5 , 17 )
99110 ds = construct_minimal_ds (
100111 field_name = tfield ,
101112 n_fields = 3 ,
@@ -105,7 +116,7 @@ def test_load_uniform_grid():
105116 z_name = "altitude" ,
106117 coord_order = ["z" , "y" , "x" ],
107118 )
108- ds_yt = ds .yt .ds
119+ ds_yt = ds .yt .load_uniform_grid ()
109120 assert ds_yt .coordinates .name == "geographic"
110121 assert all ([f in expected_field_list ] for f in ds_yt .field_list )
111122
@@ -120,6 +131,35 @@ def test_load_uniform_grid():
120131 y_name = "y" ,
121132 coord_order = ["z" , "y" , "x" ],
122133 )
134+ flds = [
135+ tfield + "_0" ,
136+ ]
123137 ds_yt = ds .yt .load_uniform_grid (flds , length_unit = "km" )
124138 assert ds_yt .coordinates .name == "cartesian"
125139 assert all ([f in expected_field_list ] for f in ds_yt .field_list )
140+
141+
142+ @pytest .mark .skipif (
143+ yt .__version__ .startswith ("4.1" ) is False , reason = "requires yt>=4.1.0"
144+ )
145+ def test_load_grid_from_callable (ds_xr ):
146+ ds = ds_xr .yt .load_grid_from_callable ()
147+ flds = list (ds_xr .data_vars )
148+ for fld in flds :
149+ assert ("stream" , fld ) in ds .field_list
150+
151+ f = ds .all_data ()[flds [0 ]]
152+ assert len (f ) == ds_xr .data_vars [flds [0 ]].size
153+
154+
155+ @pytest .mark .skipif (
156+ yt .__version__ .startswith ("4.1" ) is False , reason = "requires yt>=4.1.0"
157+ )
158+ def test_yt_ds_attr (ds_xr ):
159+ ds = ds_xr .yt .ds ()
160+ flds = list (ds_xr .data_vars )
161+ for fld in flds :
162+ assert ("stream" , fld ) in ds .field_list
163+
164+ f = ds .all_data ()[flds [0 ]]
165+ assert len (f ) == ds_xr .data_vars [flds [0 ]].size
0 commit comments