Skip to content

Commit 6afc136

Browse files
committed
we now have padded send/receive fields for 512 bit alignment
1 parent 9b55051 commit 6afc136

File tree

2 files changed

+148
-105
lines changed

2 files changed

+148
-105
lines changed

mpisppy/cylinders/spcommunicator.py

Lines changed: 87 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -30,61 +30,78 @@
3030

3131
logger = logging.getLogger(__name__)
3232

33-
def communicator_array(size):
34-
logical_size = size + 1
35-
padded_size = ((logical_size + 7) // 8) * 8
36-
37-
itemsize = np.dtype('d').itemsize
38-
mem = MPI.Alloc_mem(padded_size * itemsize)
39-
40-
full_arr = np.frombuffer(mem, dtype='d')
33+
def communicator_array(data_length: int):
34+
"""
35+
Allocate an MPI memory region with a padded length (multiple of 8 doubles = 64B),
36+
but expose a logical view of length (data_length + 1) where the last element is
37+
the read/write id.
38+
Returns:
39+
full_arr: padded array (used for SPWindow put/get)
40+
logical_arr: logical view (data + id), last element is id
41+
data_length: number of data entries (excluding id)
42+
logical_len: data_length + 1
43+
padded_len: multiple-of-8 length used in the MPI window
44+
"""
45+
logical_len = data_length + 1
46+
padded_len = ((logical_len + 7) // 8) * 8
47+
48+
itemsize = np.dtype("d").itemsize
49+
mem = MPI.Alloc_mem(padded_len * itemsize)
50+
51+
full_arr = np.frombuffer(mem, dtype="d", count=padded_len)
4152
full_arr[:] = np.nan
42-
43-
# Return the full array for internal use, but record the logical size
44-
arr = full_arr[:logical_size]
45-
arr[-1] = 0
46-
return arr, size # Note: size here is the 'field_length * buffer_size'
53+
54+
logical_arr = full_arr[:logical_len]
55+
logical_arr[-1] = 0.0
56+
57+
return full_arr, logical_arr, data_length, logical_len, padded_len
4758

4859

4960
class FieldArray:
5061
"""
51-
Notes: Buffer that tracks new/old state as well. Light-weight wrapper around a numpy array.
52-
53-
The intention here is that these are passive data holding classes. That is, other classes are
54-
expected to update the internal fields. The lone exception to this is the read/write id field.
55-
See the `SendArray` and `RecvArray` classes for how that field is updated.
62+
Wrapper around an MPI-allocated numpy buffer with:
63+
- a padded "window" array used for MPI RMA (Design A)
64+
- a logical view used by mpi-sppy code (data + id)
5665
"""
5766

5867
def __init__(self, length: int):
59-
# Store both the array (logical size + 1) and the original data length
60-
self._array, self._data_length = communicator_array(length)
68+
# length is the data length (excluding the id)
69+
(self._full_array,
70+
self._array,
71+
self._data_length,
72+
self._logical_len,
73+
self._padded_len) = communicator_array(length)
6174
self._id = 0
6275

63-
def value_array(self) -> np.typing.NDArray:
64-
""" Returns only the data portion, excluding the ID field and padding """
65-
return self._array[:self._data_length]
66-
67-
def __getitem__(self, key):
68-
# TODO: Should probably be hiding the read/write id field but there are many functions
69-
# that expect it to be there and being able to read it is not really a problem.
70-
np_array = self.array()
71-
return np_array[key]
76+
def window_array(self) -> np.typing.NDArray:
77+
"""Full padded array (used for SPWindow get/put)."""
78+
return self._full_array
7279

7380
def array(self) -> np.typing.NDArray:
74-
"""
75-
Returns the numpy array for the field data including the read id
76-
"""
81+
"""Logical array (data + id)."""
7782
return self._array
7883

7984
def value_array(self) -> np.typing.NDArray:
80-
"""
81-
Returns the numpy array for the field data without the read id
82-
"""
83-
return self._array[:self._logical_size]
85+
"""Data only (excludes id)."""
86+
return self._array[:self._data_length]
87+
88+
def padded_len(self) -> int:
89+
return self._padded_len
90+
91+
def logical_len(self) -> int:
92+
return self._logical_len
93+
94+
def data_len(self) -> int:
95+
return self._data_length
96+
97+
def __getitem__(self, key):
98+
# Preserve old behavior: indexing into the logical view.
99+
return self._array[key]
84100

85101
def id(self) -> int:
86102
return self._id
87103

104+
88105
class SendArray(FieldArray):
89106

90107
def __init__(self, length: int):
@@ -105,10 +122,6 @@ def _next_write_id(self) -> int:
105122
"""
106123
self._id += 1
107124
self._array[-1] = self._id
108-
# This ensures the update to the ID (the malloc-sensitive boundary)
109-
# is pushed to the window and synchronized before any other
110-
# part of the code (like Gurobi or another MPI rank) touches it.
111-
self.win.Flush(self.strata_rank)
112125
return self._id
113126

114127

@@ -155,7 +168,8 @@ def __init__(self, data: FieldArray, field_length: int, buffer_size: int):
155168

156169
def _get_value_array(self, read_write_index):
157170
position = read_write_index % self._buffer_size
158-
return self.data._array[(position*self._field_length):((position+1)*self._field_length)]
171+
arr = self.data.array() # logical view
172+
return arr[(position*self._field_length):((position+1)*self._field_length)]
159173

160174

161175
class SendCircularBuffer(_CircularBuffer):
@@ -260,13 +274,12 @@ def _split_key(self, key) -> tuple[Field, int]:
260274
"""
261275
return key
262276

263-
def _build_window_spec(self) -> dict[Field, int]:
264-
""" Build dict with fields and lengths needed for local MPI window
277+
def _build_window_spec(self) -> dict[Field, tuple[int, int]]:
278+
""" Build dict with fields and padded lengths needed for local MPI window
265279
"""
266280
window_spec = dict()
267-
for (field,buf) in self.send_buffers.items():
268-
window_spec[field] = np.size(buf.array())
269-
## End for
281+
for (field, buf) in self.send_buffers.items():
282+
window_spec[field] = (buf.logical_len(), buf.padded_len())
270283
return window_spec
271284

272285
def _create_field_rank_mappings(self) -> None:
@@ -277,7 +290,9 @@ def _create_field_rank_mappings(self) -> None:
277290
if rank == self.strata_rank:
278291
continue
279292
self.ranks_to_fields[rank] = []
280-
for field in buffer_layout:
293+
for field in buffer_layout.keys():
294+
if field == Field.WHOLE:
295+
continue
281296
if field not in self.fields_to_ranks:
282297
self.fields_to_ranks[field] = []
283298
self.fields_to_ranks[field].append(rank)
@@ -288,18 +303,23 @@ def _create_field_rank_mappings(self) -> None:
288303
def _validate_recv_field(self, field: Field, origin: int, length: int):
289304
remote_buffer_layout = self.window.strata_buffer_layouts[origin]
290305
if field not in remote_buffer_layout:
291-
raise RuntimeError(f"{self.__class__.__name__} on local {self.strata_rank=} "
292-
f"could not find {field=} on remote rank {origin} with "
293-
f"class {self.communicators[origin]['spcomm_class']}."
294-
)
295-
_, remote_length = remote_buffer_layout[field]
296-
if (length + 1) != remote_length:
297-
raise RuntimeError(f"{self.__class__.__name__} on local {self.strata_rank=} "
298-
f"{field=} has length {length} on local "
299-
f"{self.strata_rank=} and length {remote_length} "
300-
f"on remote rank {origin} with class "
301-
f"{self.communicators[origin]['spcomm_class']}."
302-
)
306+
raise RuntimeError(
307+
f"{self.__class__.__name__} on local {self.strata_rank=} "
308+
f"could not find {field=} on remote rank {origin} with "
309+
f"class {self.communicators[origin]['spcomm_class']}."
310+
)
311+
312+
_, remote_logical_len, remote_padded_len = remote_buffer_layout[field]
313+
expected_logical_len = length + 1
314+
expected_padded_len = ((expected_logical_len + 7) // 8) * 8
315+
316+
if remote_logical_len != expected_logical_len or remote_padded_len != expected_padded_len:
317+
raise RuntimeError(
318+
f"{self.__class__.__name__} on local {self.strata_rank=} "
319+
f"{field=} expects (logical={expected_logical_len}, padded={expected_padded_len}) "
320+
f"but remote rank {origin} advertises (logical={remote_logical_len}, padded={remote_padded_len}) "
321+
f"with class {self.communicators[origin]['spcomm_class']}."
322+
)
303323

304324
def register_recv_field(self, field: Field, origin: int, length: int = -1) -> RecvArray:
305325
# print(f"{self.__class__.__name__}.register_recv_field, {field=}, {origin=}")
@@ -308,7 +328,10 @@ def register_recv_field(self, field: Field, origin: int, length: int = -1) -> Re
308328
length = self._field_lengths[field]
309329
if key in self.receive_buffers:
310330
my_fa = self.receive_buffers[key]
311-
assert(length + 1 == np.size(my_fa.array()))
331+
expected_logical_len = length + 1
332+
expected_padded_len = ((expected_logical_len + 7) // 8) * 8
333+
assert expected_logical_len == my_fa.logical_len()
334+
assert expected_padded_len == my_fa.padded_len()
312335
else:
313336
self._validate_recv_field(field, origin, length)
314337
my_fa = RecvArray(length)
@@ -407,7 +430,7 @@ def register_receive_fields(self) -> None:
407430
if strata_rank == self.strata_rank:
408431
continue
409432
cls = comm["spcomm_class"]
410-
if field in self.ranks_to_fields[strata_rank]:
433+
if field != Field.WHOLE and field in self.ranks_to_fields[strata_rank]:
411434
buff = self.register_recv_field(field, strata_rank)
412435
self.receive_field_spcomms[field].append((strata_rank, cls, buff))
413436

@@ -419,7 +442,7 @@ def put_send_buffer(self, buf: SendArray, field: Field):
419442
This automatically updates handles the write id.
420443
"""
421444
buf._next_write_id()
422-
self.window.put(buf.array(), field)
445+
self.window.put(buf.window_array(), field)
423446
return
424447

425448
def get_receive_buffer(self,
@@ -449,9 +472,9 @@ def get_receive_buffer(self,
449472

450473
last_id = buf.id()
451474

452-
self.window.get(buf.array(), origin, field)
475+
self.window.get(buf.window_array(), origin, field) # padded view
453476

454-
new_id = int(buf.array()[-1])
477+
new_id = int(buf.array()[-1]) # logical view
455478

456479
if synchronize:
457480
local_val = np.array((new_id,), 'i')

0 commit comments

Comments
 (0)