Skip to content

Commit 0aa69af

Browse files
authored
refactor: prefer backend over gt4py_backend (#170)
On `Quantities`, we currently allow both `backend` and `gt4py_backend` where the later is deprecated and about to be removed. See NOAA-GFDL/NDSL#312 and NOAA-GFDL/NDSL#314 for context. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
1 parent 421ecb9 commit 0aa69af

File tree

5 files changed

+13
-19
lines changed

5 files changed

+13
-19
lines changed

examples/notebooks/functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def init_quantity(
106106
units=units,
107107
origin=(nhalo, nhalo, 0)[:skip_z],
108108
extent=(nx, ny, nz)[:skip_z],
109-
gt4py_backend=backend,
109+
backend=backend,
110110
)
111111

112112
if grid == VariableGrid.CellCorners:
@@ -116,7 +116,7 @@ def init_quantity(
116116
units=units,
117117
origin=(nhalo, nhalo, 0)[:skip_z],
118118
extent=(nx + 1, ny + 1, nz)[:skip_z],
119-
gt4py_backend=backend,
119+
backend=backend,
120120
)
121121

122122
elif grid == VariableGrid.StaggeredInX:
@@ -126,7 +126,7 @@ def init_quantity(
126126
units=units,
127127
origin=(nhalo, nhalo, 0)[:skip_z],
128128
extent=(nx + 1, ny, nz)[:skip_z],
129-
gt4py_backend=backend,
129+
backend=backend,
130130
)
131131

132132
elif grid == VariableGrid.StaggeredInY:
@@ -136,7 +136,7 @@ def init_quantity(
136136
units=units,
137137
origin=(nhalo, nhalo, 0)[:skip_z],
138138
extent=(nx, ny + 1, nz)[:skip_z],
139-
gt4py_backend=backend,
139+
backend=backend,
140140
)
141141

142142
return variable

examples/notebooks/grid_generation.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,15 @@
267267
" units=\"degrees\",\n",
268268
" origin=(nhalo, nhalo),\n",
269269
" extent=(nx + 1, ny + 1),\n",
270-
" gt4py_backend=backend,\n",
270+
" backend=backend,\n",
271271
")\n",
272272
"lat = Quantity(\n",
273273
" metric_terms.lat.data * 180 / np.pi,\n",
274274
" dims=(\"x_interface\", \"y_interface\"),\n",
275275
" units=\"degrees\",\n",
276276
" origin=(nhalo, nhalo),\n",
277277
" extent=(nx + 1, ny + 1),\n",
278-
" gt4py_backend=backend,\n",
278+
" backend=backend,\n",
279279
")\n",
280280
"\n",
281281
"# gather the distributed fields into a global field on the root rank\n",
@@ -357,7 +357,7 @@
357357
" units=\"m2\",\n",
358358
" origin=(nhalo, nhalo),\n",
359359
" extent=(nx, ny),\n",
360-
" gt4py_backend=backend,\n",
360+
" backend=backend,\n",
361361
")\n",
362362
"\n",
363363
"# rescale to 10^3 km2\n",

pace/grid.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,8 @@ def get_grid(
181181
quantity_factory: QuantityFactory,
182182
communicator: Communicator,
183183
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:
184-
backend = quantity_factory.zeros(
185-
dims=[X_DIM, Y_DIM], units="unknown"
186-
).gt4py_backend
187-
188184
ndsl_log.info("Using serialized grid data")
189-
grid = self._get_serialized_grid(communicator, backend)
185+
grid = self._get_serialized_grid(communicator, quantity_factory.backend)
190186
grid_data = grid.grid_data
191187
driver_grid_data = grid.driver_grid_data
192188
damping_coefficients = grid.damping_coefficients

pace/initialization.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,9 @@ def get_driver_state(
286286
grid_data: GridData,
287287
schemes: List[PHYSICS_PACKAGES],
288288
) -> DriverState:
289-
backend = quantity_factory.zeros(
290-
dims=[X_DIM, Y_DIM], units="unknown"
291-
).gt4py_backend
292-
293-
dycore_state = self._initialize_dycore_state(communicator, backend)
289+
dycore_state = self._initialize_dycore_state(
290+
communicator, quantity_factory.backend
291+
)
294292
physics_state = PhysicsState.init_zeros(
295293
quantity_factory=quantity_factory,
296294
schemes=schemes,

tests/main/driver/test_safety_checks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_check_state_domain_only():
6565
"unknown",
6666
origin=(1, 1, 0),
6767
extent=(3, 3, 2),
68-
gt4py_backend="numpy",
68+
backend="numpy",
6969
)
7070
dycore_state = unittest.mock.MagicMock(u=u_quantity)
7171
checker.check_state(dycore_state)
@@ -83,7 +83,7 @@ def test_check_nan_value():
8383
"unknown",
8484
origin=(0, 0, 0),
8585
extent=(4, 4, 2),
86-
gt4py_backend="numpy",
86+
backend="numpy",
8787
)
8888
dycore_state = unittest.mock.MagicMock(u=u_quantity)
8989
with pytest.raises(RuntimeError):

0 commit comments

Comments
 (0)