3030
3131logger = logging .getLogger (__name__ )
3232
33- def communicator_array (size ):
34- logical_size = size + 1
35- padded_size = ((logical_size + 7 ) // 8 ) * 8
36-
37- itemsize = np .dtype ('d' ).itemsize
38- mem = MPI .Alloc_mem (padded_size * itemsize )
39-
40- full_arr = np .frombuffer (mem , dtype = 'd' )
33+ def communicator_array (data_length : int ):
34+ """
35+ Allocate an MPI memory region with a padded length (multiple of 8 doubles = 64B),
36+ but expose a logical view of length (data_length + 1) where the last element is
37+ the read/write id.
38+ Returns:
39+ full_arr: padded array (used for SPWindow put/get)
40+ logical_arr: logical view (data + id), last element is id
41+ data_length: number of data entries (excluding id)
42+ logical_len: data_length + 1
43+ padded_len: multiple-of-8 length used in the MPI window
44+ """
45+ logical_len = data_length + 1
46+ padded_len = ((logical_len + 7 ) // 8 ) * 8
47+
48+ itemsize = np .dtype ("d" ).itemsize
49+ mem = MPI .Alloc_mem (padded_len * itemsize )
50+
51+ full_arr = np .frombuffer (mem , dtype = "d" , count = padded_len )
4152 full_arr [:] = np .nan
42-
43- # Return the full array for internal use, but record the logical size
44- arr = full_arr [: logical_size ]
45- arr [ - 1 ] = 0
46- return arr , size # Note: size here is the 'field_length * buffer_size'
53+
54+ logical_arr = full_arr [: logical_len ]
55+ logical_arr [ - 1 ] = 0.0
56+
57+ return full_arr , logical_arr , data_length , logical_len , padded_len
4758
4859
4960class FieldArray :
5061 """
51- Notes: Buffer that tracks new/old state as well. Light-weight wrapper around a numpy array.
52-
53- The intention here is that these are passive data holding classes. That is, other classes are
54- expected to update the internal fields. The lone exception to this is the read/write id field.
55- See the `SendArray` and `RecvArray` classes for how that field is updated.
62+ Wrapper around an MPI-allocated numpy buffer with:
63+ - a padded "window" array used for MPI RMA (Design A)
64+ - a logical view used by mpi-sppy code (data + id)
5665 """
5766
5867 def __init__ (self , length : int ):
59- # Store both the array (logical size + 1) and the original data length
60- self ._array , self ._data_length = communicator_array (length )
68+ # length is the data length (excluding the id)
69+ (self ._full_array ,
70+ self ._array ,
71+ self ._data_length ,
72+ self ._logical_len ,
73+ self ._padded_len ) = communicator_array (length )
6174 self ._id = 0
6275
63- def value_array (self ) -> np .typing .NDArray :
64- """ Returns only the data portion, excluding the ID field and padding """
65- return self ._array [:self ._data_length ]
66-
67- def __getitem__ (self , key ):
68- # TODO: Should probably be hiding the read/write id field but there are many functions
69- # that expect it to be there and being able to read it is not really a problem.
70- np_array = self .array ()
71- return np_array [key ]
76+ def window_array (self ) -> np .typing .NDArray :
77+ """Full padded array (used for SPWindow get/put)."""
78+ return self ._full_array
7279
7380 def array (self ) -> np .typing .NDArray :
74- """
75- Returns the numpy array for the field data including the read id
76- """
81+ """Logical array (data + id)."""
7782 return self ._array
7883
7984 def value_array (self ) -> np .typing .NDArray :
80- """
81- Returns the numpy array for the field data without the read id
82- """
83- return self ._array [:self ._logical_size ]
85+ """Data only (excludes id)."""
86+ return self ._array [:self ._data_length ]
87+
88+ def padded_len (self ) -> int :
89+ return self ._padded_len
90+
91+ def logical_len (self ) -> int :
92+ return self ._logical_len
93+
94+ def data_len (self ) -> int :
95+ return self ._data_length
96+
97+ def __getitem__ (self , key ):
98+ # Preserve old behavior: indexing into the logical view.
99+ return self ._array [key ]
84100
85101 def id (self ) -> int :
86102 return self ._id
87103
104+
88105class SendArray (FieldArray ):
89106
90107 def __init__ (self , length : int ):
@@ -105,10 +122,6 @@ def _next_write_id(self) -> int:
105122 """
106123 self ._id += 1
107124 self ._array [- 1 ] = self ._id
108- # This ensures the update to the ID (the malloc-sensitive boundary)
109- # is pushed to the window and synchronized before any other
110- # part of the code (like Gurobi or another MPI rank) touches it.
111- self .win .Flush (self .strata_rank )
112125 return self ._id
113126
114127
@@ -155,7 +168,8 @@ def __init__(self, data: FieldArray, field_length: int, buffer_size: int):
155168
156169 def _get_value_array (self , read_write_index ):
157170 position = read_write_index % self ._buffer_size
158- return self .data ._array [(position * self ._field_length ):((position + 1 )* self ._field_length )]
171+ arr = self .data .array () # logical view
172+ return arr [(position * self ._field_length ):((position + 1 )* self ._field_length )]
159173
160174
161175class SendCircularBuffer (_CircularBuffer ):
@@ -260,13 +274,12 @@ def _split_key(self, key) -> tuple[Field, int]:
260274 """
261275 return key
262276
263- def _build_window_spec (self ) -> dict [Field , int ]:
264- """ Build dict with fields and lengths needed for local MPI window
277+ def _build_window_spec (self ) -> dict [Field , tuple [ int , int ] ]:
278+ """ Build dict with fields and padded lengths needed for local MPI window
265279 """
266280 window_spec = dict ()
267- for (field ,buf ) in self .send_buffers .items ():
268- window_spec [field ] = np .size (buf .array ())
269- ## End for
281+ for (field , buf ) in self .send_buffers .items ():
282+ window_spec [field ] = (buf .logical_len (), buf .padded_len ())
270283 return window_spec
271284
272285 def _create_field_rank_mappings (self ) -> None :
@@ -277,7 +290,9 @@ def _create_field_rank_mappings(self) -> None:
277290 if rank == self .strata_rank :
278291 continue
279292 self .ranks_to_fields [rank ] = []
280- for field in buffer_layout :
293+ for field in buffer_layout .keys ():
294+ if field == Field .WHOLE :
295+ continue
281296 if field not in self .fields_to_ranks :
282297 self .fields_to_ranks [field ] = []
283298 self .fields_to_ranks [field ].append (rank )
@@ -288,18 +303,23 @@ def _create_field_rank_mappings(self) -> None:
288303 def _validate_recv_field (self , field : Field , origin : int , length : int ):
289304 remote_buffer_layout = self .window .strata_buffer_layouts [origin ]
290305 if field not in remote_buffer_layout :
291- raise RuntimeError (f"{ self .__class__ .__name__ } on local { self .strata_rank = } "
292- f"could not find { field = } on remote rank { origin } with "
293- f"class { self .communicators [origin ]['spcomm_class' ]} ."
294- )
295- _ , remote_length = remote_buffer_layout [field ]
296- if (length + 1 ) != remote_length :
297- raise RuntimeError (f"{ self .__class__ .__name__ } on local { self .strata_rank = } "
298- f"{ field = } has length { length } on local "
299- f"{ self .strata_rank = } and length { remote_length } "
300- f"on remote rank { origin } with class "
301- f"{ self .communicators [origin ]['spcomm_class' ]} ."
302- )
306+ raise RuntimeError (
307+ f"{ self .__class__ .__name__ } on local { self .strata_rank = } "
308+ f"could not find { field = } on remote rank { origin } with "
309+ f"class { self .communicators [origin ]['spcomm_class' ]} ."
310+ )
311+
312+ _ , remote_logical_len , remote_padded_len = remote_buffer_layout [field ]
313+ expected_logical_len = length + 1
314+ expected_padded_len = ((expected_logical_len + 7 ) // 8 ) * 8
315+
316+ if remote_logical_len != expected_logical_len or remote_padded_len != expected_padded_len :
317+ raise RuntimeError (
318+ f"{ self .__class__ .__name__ } on local { self .strata_rank = } "
319+ f"{ field = } expects (logical={ expected_logical_len } , padded={ expected_padded_len } ) "
320+ f"but remote rank { origin } advertises (logical={ remote_logical_len } , padded={ remote_padded_len } ) "
321+ f"with class { self .communicators [origin ]['spcomm_class' ]} ."
322+ )
303323
304324 def register_recv_field (self , field : Field , origin : int , length : int = - 1 ) -> RecvArray :
305325 # print(f"{self.__class__.__name__}.register_recv_field, {field=}, {origin=}")
@@ -308,7 +328,10 @@ def register_recv_field(self, field: Field, origin: int, length: int = -1) -> Re
308328 length = self ._field_lengths [field ]
309329 if key in self .receive_buffers :
310330 my_fa = self .receive_buffers [key ]
311- assert (length + 1 == np .size (my_fa .array ()))
331+ expected_logical_len = length + 1
332+ expected_padded_len = ((expected_logical_len + 7 ) // 8 ) * 8
333+ assert expected_logical_len == my_fa .logical_len ()
334+ assert expected_padded_len == my_fa .padded_len ()
312335 else :
313336 self ._validate_recv_field (field , origin , length )
314337 my_fa = RecvArray (length )
@@ -407,7 +430,7 @@ def register_receive_fields(self) -> None:
407430 if strata_rank == self .strata_rank :
408431 continue
409432 cls = comm ["spcomm_class" ]
410- if field in self .ranks_to_fields [strata_rank ]:
433+ if field != Field . WHOLE and field in self .ranks_to_fields [strata_rank ]:
411434 buff = self .register_recv_field (field , strata_rank )
412435 self .receive_field_spcomms [field ].append ((strata_rank , cls , buff ))
413436
@@ -419,7 +442,7 @@ def put_send_buffer(self, buf: SendArray, field: Field):
419442 This automatically updates handles the write id.
420443 """
421444 buf ._next_write_id ()
422- self .window .put (buf .array (), field )
445+ self .window .put (buf .window_array (), field )
423446 return
424447
425448 def get_receive_buffer (self ,
@@ -449,9 +472,9 @@ def get_receive_buffer(self,
449472
450473 last_id = buf .id ()
451474
452- self .window .get (buf .array (), origin , field )
475+ self .window .get (buf .window_array (), origin , field ) # padded view
453476
454- new_id = int (buf .array ()[- 1 ])
477+ new_id = int (buf .array ()[- 1 ]) # logical view
455478
456479 if synchronize :
457480 local_val = np .array ((new_id ,), 'i' )
0 commit comments