Skip to content

Commit 6bf8c3a

Browse files
committed
cleanup datatype tmps for ialltoallw and friends
leverage nbc infrastructure Signed-off-by: Howard Pritchard <[email protected]>
1 parent 2d92fcd commit 6bf8c3a

File tree

3 files changed

+40
-4
lines changed

3 files changed

+40
-4
lines changed

ompi/mpi/bindings/ompi_bindings/c.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,10 @@ def generate_function(prototype, fn_name, internal_fn, out, enable_count=False):
881881
return_type = prototype.return_type.construct(abi_type='standard')
882882
lines.append(f'{return_type.tmp_type_text()} ret_value;')
883883
for param in params:
884+
if param.need_async_cleanup == True:
885+
lines.append('int idx = 0;')
886+
break
887+
for param in params:
884888
# print("param = " + str(param) + " " + str(param.argument))
885889
if param.init_code:
886890
lines.extend(param.init_code)

ompi/mpi/bindings/ompi_bindings/c_type.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,16 +90,24 @@ def callback_wrapper_code(self):
9090
"""Return True if this parameter has callback wrapper code to generate."""
9191
return False
9292

93+
@property
94+
def need_async_cleanup(self):
95+
"""Return True if this parameter generates async memory cleanup code."""
96+
return False
97+
9398
class StandardABIType(Type):
9499

95100
@property
96101
def tmpname(self):
97-
return f'{self.name}_tmp'
102+
return util.abi_tmp_name(self.name)
98103

99104
@property
100105
def argument(self):
101106
return self.tmpname
102107

108+
@staticmethod
109+
def async_callback_index(self):
110+
return "idx"
103111

104112
@Type.add_type('ERROR_CLASS', abi_type=['ompi'])
105113
class TypeErrorClass(Type):
@@ -660,13 +668,23 @@ def argument(self):
660668
@Type.add_type('DATATYPE_ARRAY_ASYNC', abi_type=['standard'])
661669
class TypeDatatypeArrayAsyncStandard(TypeDatatypeArrayStandard):
662670

671+
@property
672+
def need_async_cleanup(self):
673+
return True
674+
663675
@property
664676
def final_code(self):
665-
code = ['{']
677+
request_tmp_name = util.abi_tmp_name('request')
678+
code = []
679+
code.append('if((MPI_SUCCESS == ret_value) && (MPI_REQUEST_NULL != request_tmp)){')
680+
code.append(f'ompi_coll_base_nbc_request_t* nb_request = (ompi_coll_base_nbc_request_t*)&{request_tmp_name};')
681+
code.append(f'nb_request->data.release_arrays[idx++] = (void *){self.tmpname};')
682+
code.append('nb_request->data.release_arrays[idx] = NULL;')
683+
code.append('} else {')
684+
code.append(f'free({self.tmpname});')
666685
code.append('}')
667686
return code
668687

669-
670688
@Type.add_type('DATATYPE_ARRAY_OUT', abi_type=['standard'])
671689
class TypeDatatypeArrayOutStandard(StandardABIType):
672690

@@ -722,9 +740,20 @@ def init_code(self):
722740
@Type.add_type('NEIGHBOR_DATATYPE_ARRAY_ASYNC', abi_type=['standard'])
723741
class NeighborDatatypeArrayAsyncStandard(NeighborDatatypeArrayStandard):
724742

743+
@property
744+
def need_async_cleanup(self):
745+
return True
746+
725747
@property
726748
def final_code(self):
727-
code = ['{']
749+
request_tmp_name = util.abi_tmp_name('request')
750+
code = []
751+
code.append('if((MPI_SUCCESS == ret_value) && (MPI_REQUEST_NULL != request_tmp)){')
752+
code.append(f'ompi_coll_base_nbc_request_t* nb_request = (ompi_coll_base_nbc_request_t*)&{request_tmp_name};')
753+
code.append(f'nb_request->data.release_arrays[idx++] = (void *){self.tmpname};')
754+
code.append('nb_request->data.release_arrays[idx] = NULL;')
755+
code.append('} else {')
756+
code.append(f'free({self.tmpname});')
728757
code.append('}')
729758
return code
730759

ompi/mpi/bindings/ompi_bindings/util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,6 @@ def prototype_needs_callback_wrappers(prototype):
197197
"""Should this prototype need a callback wrappers"""
198198
return any(param.type_name in USER_CALLBACK_NAMES for param in prototype.params)
199199

200+
def abi_tmp_name(name):
201+
"""Generate standardized tmp name for a supplied name"""
202+
return f'{name}_tmp'

0 commit comments

Comments
 (0)