Skip to content

Commit 516427e

Browse files
authored
refactor: move NullComm from NDSL to pace (#168)
We decided to remove `NullComm` from NDSL in favor of `MPIComm` and `LocalComm`, see NOAA-GFDL/NDSL#318 for context. pace exposes a `NullCommConfig` and is heavily relying on `NullComm` in tests. We thus suggest to move `NullComm` to `pace/comm.py`. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
1 parent 0aa69af commit 516427e

File tree

13 files changed

+141
-33
lines changed

13 files changed

+141
-33
lines changed

pace/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
CreatesComm,
55
CreatesCommSelector,
66
MPICommConfig,
7+
NullComm,
78
NullCommConfig,
89
ReaderCommConfig,
910
WriterCommConfig,
@@ -25,6 +26,7 @@
2526
__version__ = "0.2.0"
2627

2728
__all__ = [
29+
"NullComm",
2830
"CreatesComm",
2931
"CreatesCommSelector",
3032
"MPICommConfig",

pace/comm.py

Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,126 @@
11
import abc
2+
import copy
23
import dataclasses
34
import os
4-
from typing import Any, ClassVar, List
5-
6-
from ndsl import MPIComm, NullComm
7-
from ndsl.comm import CachingCommReader, CachingCommWriter, Comm
5+
from typing import Any, ClassVar, List, Mapping, TypeVar, cast
6+
7+
from ndsl import MPIComm
8+
from ndsl.comm import (
9+
CachingCommReader,
10+
CachingCommWriter,
11+
Comm,
12+
ReductionOperator,
13+
Request,
14+
)
815
from pace.registry import Registry
916

1017

18+
T = TypeVar("T")
19+
20+
21+
class NullAsyncResult(Request):
22+
def __init__(self, recvbuf: Any = None) -> None:
23+
self._recvbuf = recvbuf
24+
25+
def wait(self) -> None:
26+
if self._recvbuf is not None:
27+
self._recvbuf[:] = 0.0
28+
29+
30+
class NullComm(Comm[T]):
31+
"""
32+
A class with a subset of the mpi4py Comm API, but which
33+
'receives' a fill value (default zero) instead of using MPI.
34+
"""
35+
36+
default_fill_value: T = cast(T, 0)
37+
38+
def __init__(self, rank: int, total_ranks: int, fill_value: T = default_fill_value):
39+
"""
40+
Args:
41+
rank: rank to mock
42+
total_ranks: number of total MPI ranks to mock
43+
fill_value: fill halos with this value when performing
44+
halo updates.
45+
"""
46+
self.rank = rank
47+
self.total_ranks = total_ranks
48+
self._fill_value = fill_value
49+
self._split_comms: Mapping[Any, list[NullComm]] = {}
50+
51+
def __repr__(self) -> str:
52+
return f"NullComm(rank={self.rank}, total_ranks={self.total_ranks})"
53+
54+
def Get_rank(self) -> int:
55+
return self.rank
56+
57+
def Get_size(self) -> int:
58+
return self.total_ranks
59+
60+
def bcast(self, value: T | None, root: int = 0) -> T | None:
61+
return value
62+
63+
def barrier(self) -> None:
64+
return
65+
66+
def Barrier(self) -> None:
67+
return
68+
69+
def Scatter(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
70+
if recvbuf is not None:
71+
recvbuf[:] = self._fill_value
72+
73+
def Gather(self, sendbuf, recvbuf, root: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
74+
if recvbuf is not None:
75+
recvbuf[:] = self._fill_value
76+
77+
def allgather(self, sendobj: T) -> list[T]:
78+
return [copy.deepcopy(sendobj) for _ in range(self.total_ranks)]
79+
80+
def Send(self, sendbuf, dest, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
81+
pass
82+
83+
def Isend(self, sendbuf, dest, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def]
84+
return NullAsyncResult()
85+
86+
def Recv(self, recvbuf, source, tag: int = 0, **kwargs: dict): # type: ignore[no-untyped-def]
87+
recvbuf[:] = self._fill_value
88+
89+
def Irecv(self, recvbuf, source, tag: int = 0, **kwargs: dict) -> Request: # type: ignore[no-untyped-def]
90+
return NullAsyncResult(recvbuf)
91+
92+
def sendrecv(self, sendbuf, dest, **kwargs: dict): # type: ignore[no-untyped-def]
93+
return sendbuf
94+
95+
def Split(self, color, key) -> Comm: # type: ignore[no-untyped-def]
96+
# key argument is ignored, assumes we're calling the ranks from least to
97+
# greatest when mocking Split
98+
self._split_comms[color] = self._split_comms.get(color, []) # type: ignore[index]
99+
rank = len(self._split_comms[color])
100+
total_ranks = rank + 1
101+
new_comm = NullComm(
102+
rank=rank, total_ranks=total_ranks, fill_value=self._fill_value
103+
)
104+
for comm in self._split_comms[color]:
105+
# won't know how many ranks there are until everything is split
106+
comm.total_ranks = total_ranks
107+
self._split_comms[color].append(new_comm)
108+
return new_comm
109+
110+
def allreduce(
111+
self, sendobj: T, op: ReductionOperator = ReductionOperator.NO_OP
112+
) -> T:
113+
return self._fill_value
114+
115+
def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
116+
# TODO: what about reduction operator `op`?
117+
recvobj = sendobj
118+
return recvobj
119+
120+
def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T:
121+
raise NotImplementedError("NullComm.Allreduce_inplace")
122+
123+
11124
class CreatesComm(abc.ABC):
12125
"""
13126
Retrieves and does cleanup for a mpi4py-style Comm object.

tests/main/driver/test_driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
import pytest
66

7-
from ndsl import NullComm, StencilConfig
7+
from ndsl import StencilConfig
88
from ndsl.performance.report import (
99
TimeReport,
1010
gather_hit_counts,
1111
gather_timing_data,
1212
get_sypd,
1313
)
14-
from pace import CreatesCommSelector, DriverConfig, NullCommConfig
14+
from pace import CreatesCommSelector, DriverConfig, NullComm, NullCommConfig
1515

1616

1717
def get_driver_config(

tests/main/driver/test_restart_fortran.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
CubedSphereCommunicator,
66
CubedSpherePartitioner,
77
LocalComm,
8-
NullComm,
98
QuantityFactory,
109
SubtileGridSizer,
1110
TilePartitioner,
1211
)
13-
from pace import FortranRestartInit, GeneratedGridConfig
12+
from pace import FortranRestartInit, GeneratedGridConfig, NullComm
1413
from pyshield import PHYSICS_PACKAGES
1514
from tests.paths import REPO_ROOT
1615

tests/main/driver/test_restart_serial.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,16 @@
1010
from ndsl import (
1111
CubedSphereCommunicator,
1212
CubedSpherePartitioner,
13-
NullComm,
1413
Quantity,
1514
QuantityFactory,
1615
SubtileGridSizer,
1716
TilePartitioner,
1817
)
1918
from pace import (
2019
AnalyticInit,
21-
CreatesComm,
2220
DriverConfig,
2321
GeneratedGridConfig,
22+
NullComm,
2423
RestartConfig,
2524
)
2625
from pyshield import PHYSICS_PACKAGES
@@ -30,21 +29,6 @@
3029
DIR = os.path.dirname(os.path.abspath(__file__))
3130

3231

33-
class NullCommConfig(CreatesComm):
34-
def __init__(self, layout):
35-
self.layout = layout
36-
37-
def get_comm(self):
38-
return NullComm(
39-
rank=0,
40-
total_ranks=6 * self.layout[0] * self.layout[1],
41-
fill_value=0.0,
42-
)
43-
44-
def cleanup(self, comm):
45-
pass
46-
47-
4832
def test_default_save_restart():
4933
restart_config = RestartConfig()
5034
assert restart_config.save_restart is False

tests/main/fv3core/test_cartesian_grid.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import numpy as np
44
import pytest
55

6-
from ndsl import NullComm, TileCommunicator, TilePartitioner
6+
from ndsl import TileCommunicator, TilePartitioner
77
from ndsl.constants import PI
88
from ndsl.grid import MetricTerms
9+
from pace import NullComm
910

1011

1112
@pytest.mark.parametrize("npx", [8])

tests/main/fv3core/test_dycore_baroclinic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
CubedSpherePartitioner,
1515
DaceConfig,
1616
GridIndexing,
17-
NullComm,
1817
QuantityFactory,
1918
StencilConfig,
2019
StencilFactory,
@@ -24,6 +23,7 @@
2423
)
2524
from ndsl.grid import DampingCoefficients, GridData, MetricTerms
2625
from ndsl.performance.timer import NullTimer
26+
from pace import NullComm
2727
from pyfv3 import DycoreState, DynamicalCore, DynamicalCoreConfig
2828

2929

tests/main/fv3core/test_dycore_call.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
CubedSpherePartitioner,
1212
DaceConfig,
1313
GridIndexing,
14-
NullComm,
1514
Quantity,
1615
QuantityFactory,
1716
StencilConfig,
@@ -22,6 +21,7 @@
2221
from ndsl.grid import DampingCoefficients, GridData, MetricTerms
2322
from ndsl.performance.timer import NullTimer, Timer
2423
from ndsl.stencils.testing import assert_same_temporaries, copy_temporaries
24+
from pace import NullComm
2525
from pyfv3 import DycoreState, DynamicalCore, DynamicalCoreConfig
2626
from pyfv3.initialization.analytic_init import AnalyticCase
2727

tests/main/fv3core/test_init_from_geos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import pytest # noqa
44

5-
from ndsl import NullComm
5+
from pace import NullComm
66
from pyfv3 import DynamicalCore
77
from pyfv3.wrappers import GeosDycoreWrapper
88

tests/main/physics/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
DaceConfig,
1212
DaCeOrchestration,
1313
GridIndexing,
14-
NullComm,
1514
QuantityFactory,
1615
StencilConfig,
1716
StencilFactory,
@@ -20,6 +19,7 @@
2019
)
2120
from ndsl.grid import GridData, MetricTerms
2221
from ndsl.stencils.testing import assert_same_temporaries, copy_temporaries
22+
from pace import NullComm
2323
from pyshield import PHYSICS_PACKAGES, Physics, PhysicsConfig, PhysicsState
2424

2525

0 commit comments

Comments
 (0)