
Commit dcf79de

Initial impl of Halo operator
1 parent 19e873a commit dcf79de

File tree

2 files changed: +236 -0 lines changed


examples/plot_halo.py

Lines changed: 38 additions & 0 deletions
# mpirun --oversubscribe -np 8 python3 plot_halo.py
import math
import numpy as np
import pylops_mpi
from pylops_mpi.basicoperators.Halo import MPIHalo, ScatterType, local_block_split, BoundaryType
from mpi4py import MPI


comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# global grid and cubic process grid (assumes size is a perfect cube, e.g. 8 ranks -> 2x2x2)
gdim = (8, 8, 8)
p_prime = int(round(math.pow(size, 1 / 3)))
g_shape = (p_prime, p_prime, p_prime)

halo_op = MPIHalo(
    dims=gdim,
    halo=1,
    scatter=ScatterType.BLOCK,
    proc_grid_shape=g_shape,
    boundary_mode=BoundaryType.ZERO,
    comm=comm,
)

# scatter the global array in blocks matching the operator's process grid
x_data = np.arange(np.prod(gdim)).astype(np.float64).reshape(gdim)
x_slice = local_block_split(gdim, comm, g_shape)
x_local = x_data[x_slice]
x_dist = pylops_mpi.DistributedArray(global_shape=np.prod(gdim),
                                     local_shapes=comm.allgather(np.prod(x_local.shape)),
                                     base_comm=comm,
                                     partition=pylops_mpi.Partition.SCATTER)
x_dist.local_array[:] = x_local.flatten()

# forward: add halo cells exchanged with neighbouring ranks
x_with_halo = halo_op @ x_dist
print(rank, x_with_halo.local_array.reshape(gdim[0] // p_prime + 2, gdim[1] // p_prime + 2, gdim[2] // p_prime + 2))
# adjoint: strip the halo cells again
x_extracted = halo_op.H @ x_with_halo

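A quick sanity check for this example (a minimal sketch, not part of the commit): since the adjoint only strips the halo cells while the forward leaves the core block untouched, applying halo_op.H right after halo_op should return exactly the core block each rank scattered. Only names defined above (x_extracted, x_local, comm, rank) are used.

# round-trip check: forward adds halos, adjoint strips them again
ok_local = np.allclose(x_extracted.local_array, x_local.flatten())
ok_all = bool(comm.allreduce(int(ok_local), op=MPI.MIN))
if rank == 0:
    print("core recovered on all ranks:", ok_all)
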
pylops_mpi/basicoperators/Halo.py

Lines changed: 198 additions & 0 deletions
import math
import numpy as np
from enum import Enum
from mpi4py import MPI

from pylops.utils.backend import get_module
from pylops_mpi import MPILinearOperator, DistributedArray, Partition


class ScatterType(Enum):
    BLOCK = "BLOCK"
    SLAB = "SLAB"
    STRIPE = "STRIPE"


class BoundaryType(Enum):
    ZERO = "ZERO"
    REFLECT = "REFLECT"
    PERIODIC = "PERIODIC"


def local_block_split(global_shape: tuple, comm, grid_shape: tuple = None) -> tuple:
    """Return the slices of the global array owned by this rank in a block split."""
    ndim = len(global_shape)
    size = comm.Get_size()
    # default: put all ranks on the last axis
    if grid_shape is None:
        grid_shape = (1,) * (ndim - 1) + (size,)
    if math.prod(grid_shape) != size:
        raise ValueError(f"grid_shape {grid_shape} does not match comm size {size}")

    # same Cartesian layout as MPIHalo._build_topo so both sides agree on rank coordinates
    cart = comm.Create_cart(grid_shape, periods=[False] * ndim, reorder=True)
    coords = cart.Get_coords(cart.Get_rank())

    slices = []
    for gdim, procs_on_axis, coord in zip(global_shape, grid_shape, coords):
        block_size = math.ceil(gdim / procs_on_axis)
        start = coord * block_size
        end = min(start + block_size, gdim)
        # the last block along an axis takes whatever remains
        if coord == procs_on_axis - 1:
            sl = slice(start, None)
        else:
            sl = slice(start, end)
        slices.append(sl)
    return tuple(slices)


class MPIHalo(MPILinearOperator):
    def __init__(
        self,
        dims: tuple,
        halo,
        scatter: ScatterType = ScatterType.BLOCK,
        proc_grid_shape: tuple = None,
        comm: MPI.Comm = MPI.COMM_WORLD,
        boundary_mode: BoundaryType = BoundaryType.ZERO,
        dtype=np.float64,
    ):
        self.global_dims = tuple(dims)
        self.ndim = len(dims)
        self.halo = self._parse_halo(halo)
        self.scatter_type = scatter
        self.boundary_mode = boundary_mode
        self.comm = comm
        self.dtype = dtype

        if self.scatter_type == ScatterType.BLOCK:
            proc_grid_shape = proc_grid_shape or tuple([1] * self.ndim)
        self.proc_grid_shape = tuple(proc_grid_shape)

        self.cart_comm, self.neigh = self._build_topo()
        self.local_dims = self._calc_local_dims()
        self.local_extent = self._calc_local_extent()
        gext = self._calc_global_extent()
        self.shape = (int(np.prod(gext)), int(np.prod(self.global_dims)))
        super().__init__(shape=self.shape, dtype=np.dtype(dtype), base_comm=comm)

    def _parse_halo(self, h):
        # accept an int (same width everywhere), one width per axis, or one width per side
        if isinstance(h, int):
            return (h,) * (2 * self.ndim)
        h = tuple(h)
        if len(h) == 1:
            return h * (2 * self.ndim)
        if len(h) == self.ndim:
            return sum(tuple((d, d) for d in h), ())
        if len(h) == 2 * self.ndim:
            return h
        raise ValueError(f"Invalid halo length {len(h)} for ndim={self.ndim}")

    def _build_topo(self):
        periods = [self.boundary_mode == BoundaryType.PERIODIC] * self.ndim
        cart_comm = self.comm.Create_cart(self.proc_grid_shape, periods=periods, reorder=True)
        neigh = {}
        for ax in range(self.ndim):
            before, after = cart_comm.Shift(ax, 1)
            neigh[('-', ax)] = before
            neigh[('+', ax)] = after
        return cart_comm, neigh

    def _calc_local_dims(self):
        rank = self.cart_comm.Get_rank()
        coords = self.cart_comm.Get_coords(rank)
        local = []
        for gdim, coord, grid_procs in zip(self.global_dims, coords, self.proc_grid_shape):
            block_size = math.ceil(gdim / grid_procs)
            start = coord * block_size
            end = min(start + block_size, gdim)
            local.append(end - start)
        return tuple(local)

    def _calc_local_extent(self):
        # local block size including the halo cells on both sides of every axis
        ext = []
        for ax in range(self.ndim):
            minus_halo, plus_halo = self.halo[2 * ax], self.halo[2 * ax + 1]
            ext.append(self.local_dims[ax] + minus_halo + plus_halo)
        return tuple(ext)

    def _calc_global_extent(self):
        # global size of the haloed model: every rank along an axis adds its own padding
        ext = []
        for ax, gdim in enumerate(self.global_dims):
            minus_halo, plus_halo = self.halo[2 * ax], self.halo[2 * ax + 1]
            ext.append(gdim + self.proc_grid_shape[ax] * (minus_halo + plus_halo))
        return tuple(ext)

    def _apply_bc_along_axis(self, ncp, arr, axis):
        minus_halo, plus_halo = self.halo[2 * axis], self.halo[2 * axis + 1]
        slicer = [slice(None)] * self.ndim
        minus_nbr, plus_nbr = self.neigh[('-', axis)], self.neigh[('+', axis)]
        # halo before the core: only filled here when there is no neighbour on that side
        if minus_halo and minus_nbr == MPI.PROC_NULL:
            slicer_minus = slicer.copy()
            slicer_minus[axis] = slice(0, minus_halo)
            if self.boundary_mode == BoundaryType.ZERO:
                arr[tuple(slicer_minus)] = 0
            elif self.boundary_mode == BoundaryType.REFLECT:
                slicer_core_m = slicer.copy()
                slicer_core_m[axis] = slice(minus_halo, 2 * minus_halo)
                arr[tuple(slicer_minus)] = ncp.flip(arr[tuple(slicer_core_m)], axis=axis)
        # halo after the core
        if plus_halo and plus_nbr == MPI.PROC_NULL:
            slicer_plus = slicer.copy()
            slicer_plus[axis] = slice(-plus_halo, None)
            if self.boundary_mode == BoundaryType.ZERO:
                arr[tuple(slicer_plus)] = 0
            elif self.boundary_mode == BoundaryType.REFLECT:
                slicer_core_p = slicer.copy()
                slicer_core_p[axis] = slice(-2 * plus_halo, -plus_halo)
                arr[tuple(slicer_plus)] = ncp.flip(arr[tuple(slicer_core_p)], axis=axis)

    def _exchange_along_axis(self, ncp, arr, axis, before, after):
        # note: the exchange assumes equal halo widths on both sides of an axis, so that
        # send and receive buffers of neighbouring ranks match in size
        minus_nbr, plus_nbr = self.neigh[('-', axis)], self.neigh[('+', axis)]
        slicer = [slice(None)] * self.ndim
        # distinct tags per direction so messages cannot be mismatched, e.g. on a
        # periodic axis with only two ranks where both neighbours are the same rank
        tag_down, tag_up = 2 * axis, 2 * axis + 1
        # send the lower core edge to the minus neighbour, receive into the minus halo
        if before and minus_nbr != MPI.PROC_NULL:
            snd_s = slicer.copy()
            snd_s[axis] = slice(before, 2 * before)
            snd = arr[tuple(snd_s)].copy()
            rcv = ncp.empty_like(snd)
            self.cart_comm.Sendrecv(snd, dest=minus_nbr, sendtag=tag_down,
                                    recvbuf=rcv, source=minus_nbr, recvtag=tag_up)
            rcv_s = slicer.copy()
            rcv_s[axis] = slice(0, before)
            arr[tuple(rcv_s)] = rcv
        # send the upper core edge to the plus neighbour, receive into the plus halo
        if after and plus_nbr != MPI.PROC_NULL:
            snd_s = slicer.copy()
            snd_s[axis] = slice(-2 * after, -after)
            rcv_s = slicer.copy()
            rcv_s[axis] = slice(-after, None)
            snd = arr[tuple(snd_s)].copy()
            rcv = ncp.empty_like(snd)
            self.cart_comm.Sendrecv(snd, dest=plus_nbr, sendtag=tag_up,
                                    recvbuf=rcv, source=plus_nbr, recvtag=tag_down)
            arr[tuple(rcv_s)] = rcv

    def _matvec(self, x):
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}, got {x.partition} instead")

        # place the local core block inside a zero-initialized array of the padded extent
        core = x.local_array.reshape(self.local_dims)
        halo_arr = ncp.zeros(self.local_extent, dtype=self.dtype)
        core_slices = [slice(left, left + ldim) for left, ldim in zip(self.halo[::2], self.local_dims)]
        halo_arr[tuple(core_slices)] = core

        # exchange halos with neighbouring ranks, one axis at a time
        for ax in range(self.ndim):
            before, after = self.halo[2 * ax], self.halo[2 * ax + 1]
            self._exchange_along_axis(ncp, halo_arr, axis=ax, before=before, after=after)

        # apply boundary conditions on sides without a neighbour
        for ax in range(self.ndim):
            self._apply_bc_along_axis(ncp, halo_arr, axis=ax)

        # pack result; pass the per-rank sizes so the distribution matches the padded blocks
        res = DistributedArray(global_shape=self.shape[0],
                               local_shapes=self.comm.allgather(int(np.prod(self.local_extent))),
                               base_comm=self.comm,
                               partition=Partition.SCATTER,
                               engine=x.engine,
                               dtype=self.dtype)
        res[:] = halo_arr.ravel()
        return res

    def _rmatvec(self, x):
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}, got {x.partition} instead")
        res = DistributedArray(global_shape=self.shape[1],
                               local_shapes=self.comm.allgather(int(np.prod(self.local_dims))),
                               base_comm=self.comm,
                               partition=Partition.SCATTER,
                               engine=x.engine,
                               dtype=self.dtype)
        # strip the halo cells and return only the core block
        arr = x.local_array.reshape(self.local_extent)
        core_slices = [slice(left, left + ldim) for left, ldim in zip(self.halo[::2], self.local_dims)]
        core = arr[tuple(core_slices)]
        res[:] = core.ravel()
        return res

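For intuition on the shape bookkeeping in MPIHalo, here is a small standalone sketch (not part of the commit, pure stdlib, assuming an evenly divisible block split): with dims=(8, 8, 8), a 2x2x2 process grid and halo=1 on every side, each rank owns a 4x4x4 core block, its padded extent is 6x6x6, and the operator maps the 512 global samples to 8 * 6**3 = 1728 padded samples, i.e. shape = (1728, 512).

import math

gdim, grid, halo = (8, 8, 8), (2, 2, 2), 1
local = tuple(math.ceil(g / p) for g, p in zip(gdim, grid))    # core block per rank: (4, 4, 4)
extent = tuple(l + 2 * halo for l in local)                    # padded block per rank: (6, 6, 6)
gext = tuple(g + p * 2 * halo for g, p in zip(gdim, grid))     # padded global grid: (12, 12, 12)
assert math.prod(gext) == math.prod(grid) * math.prod(extent)  # 1728 padded samples in total
print((math.prod(gext), math.prod(gdim)))                      # operator shape: (1728, 512)
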