Skip to content

Commit ebc4f96

Browse files
authored
Namelist Refactor: Utility functions for namelist to dict + conftest --no_legacy_namelist (NOAA-GFDL#246)
* NDSL Issue#64: - Adding helper functions for loading f90nml.Namelist from file and converting to dict. - Modifying conftest.py and ParallelTranslate layout() to rely more on f90nml.Namelist. - Adding temporary 'f90nml_namelist_only' flag to conftest.py to toggle testing using 1. only f90nml.Namelist, or 2. f90nml.Namelist and ndsl.Namelist (Changes will eventually facilitate removal of current ndsl.Namelist class) * Tagging with Issue#64 where appropriate * Changing path str to libpath.Path * Fixes from PR#246 feedback * Changing --legacy_namelist_support to --no_legacy_namelist (PR NOAA-GFDL#246) * linting * type hint fix * linting
1 parent a8b44ff commit ebc4f96

File tree

4 files changed

+194
-43
lines changed

4 files changed

+194
-43
lines changed

ndsl/stencils/testing/conftest.py

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import os
22
import re
3+
from pathlib import Path
34
from typing import Optional, Tuple
45

5-
import f90nml
66
import pytest
77
import xarray as xr
88
import yaml
9+
from f90nml import Namelist
910

1011
from ndsl import CompilationConfig, StencilConfig, StencilFactory
1112
from ndsl.comm.communicator import (
@@ -16,11 +17,14 @@
1617
from ndsl.comm.mpi import MPI, MPIComm
1718
from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner
1819
from ndsl.dsl.dace.dace_config import DaceConfig
19-
from ndsl.namelist import Namelist
20+
21+
# TODO: Remove NdslNamelist import after Issue#64 is resolved.
22+
from ndsl.namelist import Namelist as NdslNamelist
2023
from ndsl.stencils.testing.grid import Grid # type: ignore
2124
from ndsl.stencils.testing.parallel_translate import ParallelTranslate
2225
from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict
2326
from ndsl.stencils.testing.translate import TranslateGrid
27+
from ndsl.utils import grid_params_from_f90nml, load_f90nml
2428

2529

2630
def pytest_addoption(parser):
@@ -73,6 +77,12 @@ def pytest_addoption(parser):
7377
default=1,
7478
help="How many indices of failures to print from worst to best. Default to 1.",
7579
)
80+
parser.addoption(
81+
"--no_legacy_namelist",
82+
action="store_true",
83+
default=False,
84+
help="Removes support for `ndsl.Namelist` in translate tests (which we are trying to get rid off, see NDSL issue #64). Defaults to False.",
85+
)
7686
parser.addoption(
7787
"--grid",
7888
action="store",
@@ -124,9 +134,9 @@ def data_path(pytestconfig):
124134
return data_path_and_namelist_filename_from_config(pytestconfig)
125135

126136

127-
def data_path_and_namelist_filename_from_config(config) -> Tuple[str, str]:
128-
data_path = config.getoption("data_path")
129-
namelist_filename = os.path.join(data_path, "input.nml")
137+
def data_path_and_namelist_filename_from_config(config) -> Tuple[Path, Path]:
138+
data_path = Path(config.getoption("data_path"))
139+
namelist_filename = data_path / "input.nml"
130140
return data_path, namelist_filename
131141

132142

@@ -224,10 +234,6 @@ def get_savepoint_restriction(metafunc):
224234
return int(svpt) if svpt else None
225235

226236

227-
def get_namelist(namelist_filename):
228-
return Namelist.from_f90nml(f90nml.read(namelist_filename))
229-
230-
231237
def get_config(backend: str, communicator: Optional[Communicator]):
232238
stencil_config = StencilConfig(
233239
compilation_config=CompilationConfig(
@@ -243,14 +249,19 @@ def get_config(backend: str, communicator: Optional[Communicator]):
243249

244250
def sequential_savepoint_cases(metafunc, data_path, namelist_filename, *, backend: str):
245251
savepoint_names = get_sequential_savepoint_names(metafunc, data_path)
246-
namelist = get_namelist(namelist_filename)
252+
namelist = load_f90nml(namelist_filename)
253+
grid_params = grid_params_from_f90nml(namelist)
247254
stencil_config = get_config(backend, None)
248-
ranks = get_ranks(metafunc, namelist.layout)
255+
ranks = get_ranks(metafunc, grid_params["layout"])
249256
savepoint_to_replay = get_savepoint_restriction(metafunc)
250257
grid_mode = metafunc.config.getoption("grid")
251258
topology_mode = metafunc.config.getoption("topology")
252259
sort_report = metafunc.config.getoption("sort_report")
253260
no_report = metafunc.config.getoption("no_report")
261+
262+
# Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone.
263+
no_legacy_namelist = metafunc.config.getoption("no_legacy_namelist")
264+
254265
return _savepoint_cases(
255266
savepoint_names,
256267
ranks,
@@ -263,6 +274,7 @@ def sequential_savepoint_cases(metafunc, data_path, namelist_filename, *, backen
263274
topology_mode,
264275
sort_report=sort_report,
265276
no_report=no_report,
277+
no_legacy_namelist=no_legacy_namelist, # Issue#64: tmp flag
266278
)
267279

268280

@@ -273,36 +285,38 @@ def _savepoint_cases(
273285
stencil_config,
274286
namelist: Namelist,
275287
backend: str,
276-
data_path: str,
288+
data_path: Path,
277289
grid_mode: str,
278290
topology_mode: bool,
279291
sort_report: str,
280292
no_report: bool,
293+
no_legacy_namelist: bool, # Issue#64: tmp flag
281294
):
295+
grid_params = grid_params_from_f90nml(namelist)
282296
return_list = []
283297
for rank in ranks:
284298
if grid_mode == "default":
285299
grid = Grid._make(
286-
namelist.npx,
287-
namelist.npy,
288-
namelist.npz,
289-
namelist.layout,
300+
grid_params["npx"],
301+
grid_params["npy"],
302+
grid_params["npz"],
303+
grid_params["layout"],
290304
rank,
291305
backend,
292306
)
293307
elif grid_mode == "file" or grid_mode == "compute":
294-
ds_grid: xr.Dataset = xr.open_dataset(
295-
os.path.join(data_path, "Grid-Info.nc")
296-
).isel(savepoint=0)
308+
ds_grid: xr.Dataset = xr.open_dataset(data_path / "Grid-Info.nc").isel(
309+
savepoint=0
310+
)
297311
grid = TranslateGrid(
298312
dataset_to_dict(ds_grid.isel(rank=rank)),
299313
rank=rank,
300-
layout=namelist.layout,
314+
layout=grid_params["layout"],
301315
backend=backend,
302316
).python_grid()
303317
if grid_mode == "compute":
304318
compute_grid_data(
305-
grid, namelist, backend, namelist.layout, topology_mode
319+
grid, grid_params, backend, grid_params["layout"], topology_mode
306320
)
307321
else:
308322
raise NotImplementedError(f"Grid mode {grid_mode} is unknown.")
@@ -312,12 +326,18 @@ def _savepoint_cases(
312326
grid_indexing=grid.grid_indexing,
313327
)
314328
for test_name in sorted(list(savepoint_names)):
329+
# Temporary check (Issue#64): TODO Remove check and conversion from
330+
# f90nml.Namelist to ndsl.Namelist after ndsl.Namelist is removed
331+
if not no_legacy_namelist: # This means we use NdslNamelist.
332+
if not isinstance(namelist, NdslNamelist):
333+
namelist = NdslNamelist.from_f90nml(namelist)
334+
315335
testobj = get_test_class_instance(
316336
test_name, grid, namelist, stencil_factory
317337
)
318-
n_calls = xr.open_dataset(
319-
os.path.join(data_path, f"{test_name}-In.nc")
320-
).sizes["savepoint"]
338+
n_calls = xr.open_dataset(data_path / f"{test_name}-In.nc").sizes[
339+
"savepoint"
340+
]
321341
if savepoint_to_replay is not None:
322342
savepoint_iterator = range(savepoint_to_replay, savepoint_to_replay + 1)
323343
else:
@@ -337,11 +357,11 @@ def _savepoint_cases(
337357
return return_list
338358

339359

340-
def compute_grid_data(grid, namelist, backend, layout, topology_mode):
360+
def compute_grid_data(grid, grid_params, backend, layout, topology_mode):
341361
grid.make_grid_data(
342-
npx=namelist.npx,
343-
npy=namelist.npy,
344-
npz=namelist.npz,
362+
npx=grid_params["npx"],
363+
npy=grid_params["npy"],
364+
npz=grid_params["npz"],
345365
communicator=get_communicator(MPIComm(), layout, topology_mode),
346366
backend=backend,
347367
)
@@ -350,15 +370,20 @@ def compute_grid_data(grid, namelist, backend, layout, topology_mode):
350370
def parallel_savepoint_cases(
351371
metafunc, data_path, namelist_filename, mpi_rank, *, backend: str, comm
352372
):
353-
namelist = get_namelist(namelist_filename)
373+
namelist = load_f90nml(namelist_filename)
374+
grid_params = grid_params_from_f90nml(namelist)
354375
topology_mode = metafunc.config.getoption("topology")
355376
sort_report = metafunc.config.getoption("sort_report")
356377
no_report = metafunc.config.getoption("no_report")
357-
communicator = get_communicator(comm, namelist.layout, topology_mode)
378+
communicator = get_communicator(comm, grid_params["layout"], topology_mode)
358379
stencil_config = get_config(backend, communicator)
359380
savepoint_names = get_parallel_savepoint_names(metafunc, data_path)
360381
grid_mode = metafunc.config.getoption("grid")
361382
savepoint_to_replay = get_savepoint_restriction(metafunc)
383+
384+
# Temporary flag (Issue#64): TODO Remove once ndsl.Namelist is gone.
385+
no_legacy_namelist = metafunc.config.getoption("no_legacy_namelist")
386+
362387
return _savepoint_cases(
363388
savepoint_names,
364389
[mpi_rank],
@@ -371,6 +396,7 @@ def parallel_savepoint_cases(
371396
topology_mode,
372397
sort_report=sort_report,
373398
no_report=no_report,
399+
no_legacy_namelist=no_legacy_namelist, # Issue#64: tmp flag
374400
)
375401

376402

@@ -388,7 +414,10 @@ def generate_sequential_stencil_tests(metafunc, *, backend: str):
388414
metafunc.config
389415
)
390416
savepoint_cases = sequential_savepoint_cases(
391-
metafunc, data_path, namelist_filename, backend=backend
417+
metafunc,
418+
data_path,
419+
namelist_filename,
420+
backend=backend,
392421
)
393422
metafunc.parametrize(
394423
"case", savepoint_cases, ids=[str(item) for item in savepoint_cases]

ndsl/stencils/testing/parallel_translate.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77

88
from ndsl.constants import HORIZONTAL_DIMS, N_HALO_DEFAULT, X_DIMS, Y_DIMS
99
from ndsl.dsl import gt4py_utils as utils
10+
11+
# TODO: Remove once ndsl.Namelist is gone (Issue#64)
12+
from ndsl.namelist import Namelist as NdslNamelist
1013
from ndsl.quantity import Quantity
1114
from ndsl.stencils.testing.translate import (
1215
TranslateFortranData2Py,
1316
read_serialized_data,
1417
)
18+
from ndsl.utils import grid_params_from_f90nml
1519

1620

1721
class ParallelTranslate:
@@ -129,7 +133,14 @@ def rank_grids(self):
129133

130134
@property
131135
def layout(self):
132-
return self.namelist.layout
136+
# TODO: Once ndsl.namelist.Namelist is gone (Issue#64),
137+
# remove this check in favor of f90nml.namelist.Namelist
138+
if isinstance(self.namelist, NdslNamelist):
139+
return self.namelist.layout
140+
141+
# Assumption: namelist is f90nml.namelist.Namelist
142+
grid_params = grid_params_from_f90nml(self.namelist)
143+
return grid_params["layout"]
133144

134145
def compute_sequential(self, inputs_list, communicator_list):
135146
"""Compute the outputs while iterating over a set of communicator

ndsl/stencils/testing/savepoint.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import dataclasses
2-
import os
2+
from pathlib import Path
33
from typing import Dict, Protocol, Union
44

55
import numpy as np
@@ -22,7 +22,7 @@ def _process_if_scalar(value: np.ndarray) -> Union[np.ndarray, float, int]:
2222

2323

2424
class DataLoader:
25-
def __init__(self, rank: int, data_path: str):
25+
def __init__(self, rank: int, data_path: Path):
2626
self._data_path = data_path
2727
self._rank = rank
2828

@@ -33,7 +33,7 @@ def load(
3333
i_call: int = 0,
3434
) -> Dict[str, Union[np.ndarray, float, int]]:
3535
return dataset_to_dict(
36-
xr.open_dataset(os.path.join(self._data_path, f"{name}{postfix}.nc"))
36+
xr.open_dataset(self._data_path / f"{name}{postfix}.nc")
3737
.isel(rank=self._rank)
3838
.isel(savepoint=i_call)
3939
)
@@ -54,7 +54,7 @@ class SavepointCase:
5454
"""
5555

5656
savepoint_name: str
57-
data_dir: str
57+
data_dir: Path
5858
i_call: int
5959
testobj: Translate
6060
grid: Grid
@@ -67,26 +67,24 @@ def __str__(self):
6767
@property
6868
def exists(self) -> bool:
6969
return (
70-
xr.open_dataset(
71-
os.path.join(self.data_dir, f"{self.savepoint_name}-In.nc")
72-
).sizes["rank"]
70+
xr.open_dataset(self.data_dir / f"{self.savepoint_name}-In.nc").sizes[
71+
"rank"
72+
]
7373
> self.grid.rank
7474
)
7575

7676
@property
7777
def ds_in(self) -> xr.Dataset:
7878
return (
79-
xr.open_dataset(os.path.join(self.data_dir, f"{self.savepoint_name}-In.nc"))
79+
xr.open_dataset(self.data_dir / f"{self.savepoint_name}-In.nc")
8080
.isel(rank=self.grid.rank)
8181
.isel(savepoint=self.i_call)
8282
)
8383

8484
@property
8585
def ds_out(self) -> xr.Dataset:
8686
return (
87-
xr.open_dataset(
88-
os.path.join(self.data_dir, f"{self.savepoint_name}-Out.nc")
89-
)
87+
xr.open_dataset(self.data_dir / f"{self.savepoint_name}-Out.nc")
9088
.isel(rank=self.grid.rank)
9189
.isel(savepoint=self.i_call)
9290
)

0 commit comments

Comments
 (0)