1+ import math
2+ import numpy as np
3+ from enum import Enum
4+ from mpi4py import MPI
5+
6+ from pylops .utils .backend import get_module
7+ from pylops_mpi import MPILinearOperator , DistributedArray , Partition
8+
9+
class ScatterType(Enum):
    """Names of the supported ways of scattering an array over MPI ranks.

    Each member's value is the upper-case string of its own name, so a
    member can be looked up from that string, e.g. ``ScatterType("SLAB")``.
    """

    BLOCK = "BLOCK"
    SLAB = "SLAB"
    STRIPE = "STRIPE"
14+
class BoundaryType(Enum):
    """Names of the supported fillings for halo cells on the outer boundary.

    Each member's value is the upper-case string of its own name, so a
    member can be looked up from that string, e.g. ``BoundaryType("ZERO")``.
    """

    ZERO = "ZERO"
    REFLECT = "REFLECT"
    PERIODIC = "PERIODIC"
19+
def local_block_split(global_shape: tuple, comm, grid_shape: tuple = None) -> tuple:
    """Return the slices of the global array owned by the calling rank.

    The global array is cut into blocks over a Cartesian process grid. Along
    every axis each process receives ``ceil(n / p)`` consecutive entries,
    except the trailing processes, which receive whatever is left (possibly
    nothing).

    Parameters
    ----------
    global_shape : tuple
        Shape of the global array.
    comm : mpi4py.MPI.Comm
        Communicator whose ranks form the process grid. Only ``Get_size``
        and ``Create_cart`` are used.
    grid_shape : tuple, optional
        Number of processes along each axis. Defaults to placing every rank
        on the last axis.

    Returns
    -------
    tuple of slice
        One slice per axis. The slice of the last process along an axis is
        open-ended (``slice(start, None)``).

    Raises
    ------
    ValueError
        If ``grid_shape`` does not have one entry per axis, or if the product
        of its entries differs from the communicator size.
    """
    ndim = len(global_shape)
    size = comm.Get_size()
    # default: put all ranks on the last axis
    if grid_shape is None:
        grid_shape = (1,) * (ndim - 1) + (size,)
    grid_shape = tuple(grid_shape)
    if len(grid_shape) != ndim:
        # zip() below would otherwise silently drop the extra axes
        raise ValueError(
            f"grid_shape {grid_shape} must have one entry per axis of "
            f"global_shape {tuple(global_shape)}"
        )
    if math.prod(grid_shape) != size:
        raise ValueError(f"grid_shape {grid_shape} does not match comm size {size}")

    # reorder=False keeps the Cartesian rank equal to the rank in ``comm``,
    # so the returned slices describe the caller's own rank.
    cart = comm.Create_cart(grid_shape, periods=[False] * ndim, reorder=False)
    try:
        coords = cart.Get_coords(cart.Get_rank())
    finally:
        # the topology is only needed for the coordinate lookup; release the
        # communicator handle instead of leaking one per call
        cart.Free()

    slices = []
    for gdim, procs_on_axis, coord in zip(global_shape, grid_shape, coords):
        block_size = math.ceil(gdim / procs_on_axis)
        # clamp so that trailing processes with no data get an empty slice
        start = min(coord * block_size, gdim)
        if coord == procs_on_axis - 1:
            slices.append(slice(start, None))
        else:
            slices.append(slice(start, min(start + block_size, gdim)))
    return tuple(slices)
43+
class MPIHalo(MPILinearOperator):
    r"""Halo (ghost-cell) padding of a block-distributed N-dimensional array.

    The global array of shape ``dims`` is split into blocks over a Cartesian
    process grid. In forward mode every rank pads its block with halo cells:
    halo cells facing a neighbouring rank are filled with that neighbour's
    adjacent core cells, and halo cells on the outer edge of the process grid
    are filled according to ``boundary_mode``. In adjoint mode the halo cells
    are dropped and the core block is returned.

    Parameters
    ----------
    dims : tuple
        Shape of the global (un-padded) array.
    halo : int or sequence of int
        Halo widths. Either a single integer (same width everywhere), a
        sequence of length 1 (same), a sequence of length ``ndim`` (one
        symmetric width per axis) or a sequence of length ``2 * ndim``
        (``minus, plus`` width for every axis, in axis order).
    scatter : ScatterType, optional
        Stored as ``scatter_type``. The decomposition computed by this class
        is the block one described in the Notes regardless of this value.
    proc_grid_shape : tuple, optional
        Number of processes along each axis. Its product must equal the size
        of ``comm``. Defaults to placing every rank on the last axis (which
        is ``(1, ..., 1)`` on a single-rank communicator).
    comm : mpi4py.MPI.Comm, optional
        Communicator the Cartesian topology is built on.
    boundary_mode : BoundaryType, optional
        ``ZERO`` fills outer halos with zeros, ``REFLECT`` mirrors the
        adjacent core cells, ``PERIODIC`` makes the process grid periodic so
        that outer halos are filled from the opposite side of the domain.
    dtype : data-type, optional
        Type of the padded arrays produced by the operator.

    Raises
    ------
    ValueError
        If ``halo`` has an unsupported length or a negative entry, or if
        ``proc_grid_shape`` does not have ``ndim`` entries whose product is
        the size of ``comm``.

    Notes
    -----
    Along each axis every rank owns ``ceil(n / p)`` consecutive cells, except
    trailing ranks, which own the remainder (possibly zero cells). Input
    vectors must use ``Partition.SCATTER`` and hold, on every rank, the
    flattened local block (``prod(local_dims)`` values); ``_matvec`` reshapes
    the local array to ``local_dims`` and fails otherwise.

    The halo exchange only talks to direct neighbours, so each block is
    expected to be at least as thick as the halo width requested from it.

    NOTE(review): ``_rmatvec`` crops the core block and does not accumulate
    halo contributions back onto the owning ranks, so it is a left inverse of
    ``_matvec`` rather than its exact transpose — confirm this is intended.
    """

    def __init__(
        self,
        dims: tuple,
        halo,
        scatter: ScatterType = ScatterType.BLOCK,
        proc_grid_shape: tuple = None,
        comm: MPI.Comm = MPI.COMM_WORLD,
        boundary_mode: BoundaryType = BoundaryType.ZERO,
        dtype=np.float64,
    ):
        self.global_dims = tuple(dims)
        self.ndim = len(self.global_dims)
        self.halo = self._parse_halo(halo)
        self.scatter_type = scatter
        self.boundary_mode = boundary_mode
        self.comm = comm
        self.dtype = dtype

        size = comm.Get_size()
        if proc_grid_shape is None or len(proc_grid_shape) == 0:
            # A grid of all ones only covers a single rank; by default spread
            # every rank of ``comm`` along the last axis instead.
            proc_grid_shape = (1,) * (self.ndim - 1) + (size,)
        self.proc_grid_shape = tuple(int(p) for p in proc_grid_shape)
        if len(self.proc_grid_shape) != self.ndim:
            raise ValueError(
                f"proc_grid_shape {self.proc_grid_shape} must have "
                f"{self.ndim} entries"
            )
        if math.prod(self.proc_grid_shape) != size:
            # otherwise some ranks would be left outside the topology
            raise ValueError(
                f"proc_grid_shape {self.proc_grid_shape} does not match "
                f"comm size {size}"
            )

        self.cart_comm, self.neigh = self._build_topo()
        self.local_dims = self._calc_local_dims()
        self.local_extent = self._calc_local_extent()
        # flattened local sizes on every rank, used to lay out the outputs
        self._core_sizes, self._extent_sizes = self._calc_rank_sizes()
        gext = self._calc_global_extent()
        self.shape = (int(math.prod(gext)), int(math.prod(self.global_dims)))
        super().__init__(shape=self.shape, dtype=np.dtype(dtype), base_comm=comm)

    def _parse_halo(self, h):
        """Normalise ``h`` to a tuple of ``2 * ndim`` non-negative widths.

        The result is ordered ``(minus_0, plus_0, minus_1, plus_1, ...)``.
        """
        if isinstance(h, (int, np.integer)):
            widths = (h,) * (2 * self.ndim)
        else:
            h = tuple(h)
            if len(h) == 1:
                widths = h * (2 * self.ndim)
            elif len(h) == self.ndim:
                # one symmetric width per axis -> (minus, plus) pairs
                widths = tuple(w for d in h for w in (d, d))
            elif len(h) == 2 * self.ndim:
                widths = h
            else:
                raise ValueError(
                    f"Invalid halo length {len(h)} for ndim={self.ndim}"
                )
        widths = tuple(int(w) for w in widths)
        if any(w < 0 for w in widths):
            raise ValueError(f"Halo widths must be non-negative, got {widths}")
        return widths

    def _build_topo(self):
        """Create the Cartesian communicator and look up axis neighbours.

        Returns the communicator and a dict mapping ``('-', axis)`` /
        ``('+', axis)`` to the neighbouring rank (``MPI.PROC_NULL`` when
        there is none).
        """
        periodic = self.boundary_mode == BoundaryType.PERIODIC
        # reorder=False keeps Cartesian ranks identical to the ranks of
        # ``comm``, which is what the DistributedArray inputs/outputs use.
        cart_comm = self.comm.Create_cart(
            self.proc_grid_shape, periods=[periodic] * self.ndim, reorder=False
        )
        neigh = {}
        for ax in range(self.ndim):
            before, after = cart_comm.Shift(ax, 1)
            neigh[("-", ax)] = before
            neigh[("+", ax)] = after
        return cart_comm, neigh

    def _block_dims(self, coords):
        """Shape of the core block owned by the process at ``coords``."""
        dims = []
        for gdim, coord, grid_procs in zip(
            self.global_dims, coords, self.proc_grid_shape
        ):
            block_size = math.ceil(gdim / grid_procs)
            # clamp the start so trailing ranks get 0 instead of a negative size
            start = min(coord * block_size, gdim)
            end = min(start + block_size, gdim)
            dims.append(end - start)
        return tuple(dims)

    def _calc_local_dims(self):
        """Shape of the core block owned by the calling rank."""
        rank = self.cart_comm.Get_rank()
        return self._block_dims(self.cart_comm.Get_coords(rank))

    def _calc_local_extent(self):
        """Shape of the calling rank's block once the halos are attached."""
        return tuple(
            ldim + minus_halo + plus_halo
            for ldim, minus_halo, plus_halo in zip(
                self.local_dims, self.halo[::2], self.halo[1::2]
            )
        )

    def _calc_global_extent(self):
        """Per-axis total of the padded block sizes over the process grid."""
        return tuple(
            gdim + procs * (minus_halo + plus_halo)
            for gdim, procs, minus_halo, plus_halo in zip(
                self.global_dims,
                self.proc_grid_shape,
                self.halo[::2],
                self.halo[1::2],
            )
        )

    def _calc_rank_sizes(self):
        """Flattened core and padded block sizes for every rank.

        Computed locally from the Cartesian coordinates of each rank, so no
        communication is needed.
        """
        pads = [lo + hi for lo, hi in zip(self.halo[::2], self.halo[1::2])]
        core_sizes, extent_sizes = [], []
        for rank in range(self.cart_comm.Get_size()):
            dims = self._block_dims(self.cart_comm.Get_coords(rank))
            core_sizes.append(math.prod(dims))
            extent_sizes.append(math.prod(d + p for d, p in zip(dims, pads)))
        return core_sizes, extent_sizes

    def _core_slices(self):
        """Index of the core block inside the padded local array."""
        return tuple(
            slice(left, left + ldim)
            for left, ldim in zip(self.halo[::2], self.local_dims)
        )

    def _apply_bc_along_axis(self, ncp, arr, axis):
        """Fill, in place, the halos of ``arr`` that have no neighbour."""
        minus_halo, plus_halo = self.halo[2 * axis], self.halo[2 * axis + 1]
        minus_nbr, plus_nbr = self.neigh[("-", axis)], self.neigh[("+", axis)]

        def along(sl):
            index = [slice(None)] * self.ndim
            index[axis] = sl
            return tuple(index)

        # before
        if minus_halo and minus_nbr == MPI.PROC_NULL:
            target = along(slice(0, minus_halo))
            if self.boundary_mode == BoundaryType.ZERO:
                arr[target] = 0
            elif self.boundary_mode == BoundaryType.REFLECT:
                # mirror the core cells adjacent to the boundary
                source = along(slice(minus_halo, 2 * minus_halo))
                arr[target] = ncp.flip(arr[source], axis=axis)
        # after
        if plus_halo and plus_nbr == MPI.PROC_NULL:
            target = along(slice(-plus_halo, None))
            if self.boundary_mode == BoundaryType.ZERO:
                arr[target] = 0
            elif self.boundary_mode == BoundaryType.REFLECT:
                source = along(slice(-2 * plus_halo, -plus_halo))
                arr[target] = ncp.flip(arr[source], axis=axis)

    def _exchange_along_axis(self, ncp, arr, axis, before, after):
        """Fill, in place, the halos of ``arr`` along ``axis`` from neighbours.

        ``before`` / ``after`` are the minus / plus halo widths along
        ``axis``. Each halo is filled in one shift of the whole process grid:
        every rank sends towards one side and receives from the other, so the
        message a rank receives always has the size of the halo it fills,
        even when ``before != after``.
        """
        minus_nbr, plus_nbr = self.neigh[("-", axis)], self.neigh[("+", axis)]
        if minus_nbr == MPI.PROC_NULL and plus_nbr == MPI.PROC_NULL:
            return
        n = arr.shape[axis]

        def along(start, stop):
            index = [slice(None)] * self.ndim
            index[axis] = slice(start, stop)
            return tuple(index)

        # Plus halos: the neighbour on the plus side provides its first
        # ``after`` core cells, i.e. data travels towards lower coordinates.
        if after:
            snd = arr[along(before, before + after)].copy()
            rcv = ncp.empty_like(snd)
            self.cart_comm.Sendrecv(
                snd, dest=minus_nbr, sendtag=2 * axis,
                recvbuf=rcv, source=plus_nbr, recvtag=2 * axis,
            )
            # nothing is received from PROC_NULL; leave that halo to the BCs
            if plus_nbr != MPI.PROC_NULL:
                arr[along(n - after, n)] = rcv
        # Minus halos: the neighbour on the minus side provides its last
        # ``before`` core cells, i.e. data travels towards higher coordinates.
        if before:
            snd = arr[along(n - after - before, n - after)].copy()
            rcv = ncp.empty_like(snd)
            self.cart_comm.Sendrecv(
                snd, dest=plus_nbr, sendtag=2 * axis + 1,
                recvbuf=rcv, source=minus_nbr, recvtag=2 * axis + 1,
            )
            if minus_nbr != MPI.PROC_NULL:
                arr[along(0, before)] = rcv

    def _matvec(self, x):
        """Pad the local block of ``x`` with halos and return the result."""
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")

        core = x.local_array.reshape(self.local_dims)
        halo_arr = ncp.zeros(self.local_extent, dtype=self.dtype)
        # insert core
        halo_arr[self._core_slices()] = core

        # exchange along each axis; later axes forward the halos already
        # received along earlier axes, which fills the corners
        for ax in range(self.ndim):
            before, after = self.halo[2 * ax], self.halo[2 * ax + 1]
            self._exchange_along_axis(ncp, halo_arr, axis=ax, before=before, after=after)

        # apply BCs
        for ax in range(self.ndim):
            self._apply_bc_along_axis(ncp, halo_arr, axis=ax)

        # pack result; explicit local shapes are required because the padded
        # blocks are in general not an even split of the global size
        res = DistributedArray(
            global_shape=self.shape[0],
            base_comm=self.comm,
            partition=Partition.SCATTER,
            local_shapes=[(size,) for size in self._extent_sizes],
            engine=x.engine,
            dtype=self.dtype,
        )
        res[:] = halo_arr.ravel()
        return res

    def _rmatvec(self, x):
        """Drop the halos of the local padded block of ``x``."""
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
        res = DistributedArray(
            global_shape=self.shape[1],
            base_comm=self.comm,
            partition=Partition.SCATTER,
            local_shapes=[(size,) for size in self._core_sizes],
            engine=x.engine,
            dtype=self.dtype,
        )
        arr = x.local_array.reshape(self.local_extent)
        res[:] = arr[self._core_slices()].ravel()
        return res
0 commit comments