Skip to content

Commit 22c14ab

Browse files
authored
Merge pull request #116 from romanc/romanc/replace-null-comm
refactor: replace `NullComm` with `LocalComm`
2 parents 3c2c3d4 + 85597e2 commit 22c14ab

File tree

4 files changed

+32
-36
lines changed

4 files changed

+32
-36
lines changed

examples/standalone/runfile/acoustics.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22
# type: ignore
33
from types import SimpleNamespace
4-
from typing import Any, Dict, List, Optional, Tuple
4+
from typing import Any, Dict, List, Tuple
55

66
import click
77
import f90nml
@@ -14,24 +14,20 @@
1414
CubedSphereCommunicator,
1515
CubedSpherePartitioner,
1616
DaceConfig,
17-
NullComm,
17+
LocalComm,
18+
MPIComm,
1819
StencilConfig,
1920
StencilFactory,
2021
TilePartitioner,
2122
)
23+
from ndsl.comm import Comm
2224
from ndsl.performance import Timer
2325
from ndsl.stencils.testing import Grid
2426
from pyfv3 import DynamicalCoreConfig
2527
from pyfv3.stencils import AcousticDynamics
2628
from pyfv3.testing import TranslateDynCore
2729

2830

29-
try:
30-
from mpi4py import MPI
31-
except ImportError:
32-
MPI = None
33-
34-
3531
def dycore_config_from_namelist(data_directory: str) -> DynamicalCoreConfig:
3632
"""
3733
Reads the namelist at the given directory and sets
@@ -89,17 +85,15 @@ def get_state_from_input(
8985
def set_up_communicator(
9086
disable_halo_exchange: bool,
9187
layout: Tuple[int, int],
92-
) -> Tuple[Optional[MPI.Comm], Optional[CubedSphereCommunicator]]:
93-
partitioner = CubedSpherePartitioner(TilePartitioner(layout))
94-
if MPI is not None:
95-
comm = MPI.COMM_WORLD
96-
else:
97-
comm = None
98-
if not disable_halo_exchange:
99-
assert comm is not None
100-
cube_comm = CubedSphereCommunicator(comm, partitioner)
101-
else:
102-
cube_comm = CubedSphereCommunicator(NullComm(0, 0), partitioner)
88+
) -> Tuple[Comm, CubedSphereCommunicator]:
89+
comm = (
90+
LocalComm(rank=0, total_ranks=1, buffer={})
91+
if disable_halo_exchange
92+
else MPIComm()
93+
)
94+
cube_comm = CubedSphereCommunicator(
95+
comm, CubedSpherePartitioner(TilePartitioner(layout))
96+
)
10397
return comm, cube_comm
10498

10599

examples/standalone/runfile/compile.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,12 @@
77

88
import f90nml
99
import gt4py.cartesian.config
10+
from mpi4py import MPI
1011

11-
from ndsl import NullComm
12+
from ndsl import LocalComm
1213
from pyfv3 import DynamicalCoreConfig
1314

1415

15-
try:
16-
from mpi4py import MPI
17-
except ImportError:
18-
MPI = None
19-
2016
local = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
2117
sys.path.insert(0, local)
2218
from runfile.dynamics import get_experiment_info, setup_dycore # noqa: E402
@@ -68,10 +64,8 @@ def parse_args() -> Namespace:
6864
for iteration in range(iterations):
6965
top_tile_rank = global_rank + size * iteration
7066
if top_tile_rank < sub_tiles:
71-
mpi_comm = NullComm(
72-
rank=top_tile_rank,
73-
total_ranks=6 * sub_tiles,
74-
fill_value=0.0,
67+
mpi_comm = LocalComm(
68+
rank=top_tile_rank, total_ranks=6 * sub_tiles, buffer_dict={}
7569
)
7670
gt4py.cartesian.config.cache_settings["dir_name"] = os.environ.get(
7771
"GT_CACHE_ROOT", f".gt_cache_{mpi_comm.Get_rank():06}"

examples/standalone/runfile/dynamics.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
CubedSphereCommunicator,
2323
CubedSpherePartitioner,
2424
DaceConfig,
25-
NullComm,
25+
LocalComm,
26+
MPIComm,
2627
StencilConfig,
2728
StencilFactory,
2829
TilePartitioner,
@@ -286,10 +287,15 @@ def setup_dycore(
286287
namelist = f90nml.read(args.data_dir + "/input.nml")
287288
dycore_config = DynamicalCoreConfig.from_f90nml(namelist)
288289
experiment_name, is_baroclinic_test_case = get_experiment_info(args.data_dir)
289-
if args.disable_halo_exchange:
290-
mpi_comm = NullComm(MPI.COMM_WORLD.Get_rank(), MPI.COMM_WORLD.Get_size())
291-
else:
292-
mpi_comm = MPI.COMM_WORLD
290+
mpi_comm = (
291+
LocalComm(
292+
rank=MPI.COMM_WORLD.Get_rank(),
293+
total_ranks=MPI.COMM_WORLD.Get_size(),
294+
buffer_dict={},
295+
)
296+
if args.disable_halo_exchange
297+
else MPIComm()
298+
)
293299
dycore, state, stencil_factory = setup_dycore(
294300
dycore_config,
295301
mpi_comm,

pyfv3/wrappers/geos_wrapper.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
DaceConfig,
1818
DaCeOrchestration,
1919
GridIndexing,
20-
NullComm,
20+
LocalComm,
2121
PerformanceCollector,
2222
QuantityFactory,
2323
StencilConfig,
@@ -114,7 +114,9 @@ def __init__(
114114
# Look for an override to run on a single node
115115
gtfv3_single_rank_override = int(os.getenv("GTFV3_SINGLE_RANK_OVERRIDE", -1))
116116
if gtfv3_single_rank_override >= 0:
117-
comm = NullComm(gtfv3_single_rank_override, 6, 42)
117+
comm = LocalComm(
118+
rank=gtfv3_single_rank_override, total_ranks=6, buffer_dict={}
119+
)
118120

119121
# Make a custom performance collector for the GEOS wrapper
120122
self.perf_collector = PerformanceCollector("GEOS wrapper", comm)

0 commit comments

Comments
 (0)