Skip to content

Commit 7f41d88

Browse files
committed
add SOURCE_ARRAY type
to help out MPI_Group_translate_ranks Signed-off-by: Howard Pritchard <[email protected]>
1 parent 9c1147b commit 7f41d88

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

ompi/mpi/bindings/ompi_bindings/c_type.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,42 @@ def init_code(self):
892892
def type_text(self, enable_count=False):
893893
return 'int'
894894

895+
@Type.add_type('SOURCE_ARRAY', abi_type=['ompi'])
896+
class TypeSourceArray(Type):
897+
898+
def type_text(self, enable_count=False):
899+
return 'const int*'
900+
901+
def parameter(self, enable_count=False, **kwargs):
902+
return f'const int {self.name}[]'
903+
904+
@Type.add_type('SOURCE_ARRAY', abi_type=['standard'])
905+
class TypeSourceArrayStandard(StandardABIType):
906+
907+
@property
908+
def init_code(self):
909+
code = [(f'int *{self.tmpname} = NULL;')]
910+
code.append('if('+f'{self.name}' + '!= NULL)' + '{')
911+
code.append(f'{self.tmpname} = (int *)malloc(sizeof(int) * {self.count_param});')
912+
code.append(f'for(int i=0;i<{self.count_param};i++){{')
913+
code.append(f'{self.tmpname}[i] = {ConvertFuncs.SOURCE}({self.name}[i]);')
914+
code.append('}')
915+
code.append('}')
916+
return code
917+
918+
@property
919+
def final_code(self):
920+
code = [f'if({self.tmpname} != NULL){{']
921+
code.append(f'free({self.tmpname});')
922+
code.append('}')
923+
return code
924+
925+
def type_text(self, enable_count=False):
926+
return 'int *'
927+
928+
def parameter(self, enable_count=False, **kwargs):
929+
return f'const int {self.name}[]'
930+
895931
@Type.add_type('SOURCE_OUT', abi_type=['ompi'])
896932
class TypeSourceOut(Type):
897933

@@ -912,6 +948,39 @@ def type_text(self, enable_count=False):
912948
def argument(self):
913949
return f'(int *) {self.name}'
914950

951+
@Type.add_type('SOURCE_ARRAY_OUT', abi_type=['ompi'])
952+
class TypeSourceArrayOut(Type):
953+
954+
def type_text(self, enable_count=False):
955+
return 'int *'
956+
957+
def parameter(self, enable_count=False, **kwargs):
958+
return f'int {self.name}[]'
959+
960+
@Type.add_type('SOURCE_ARRAY_OUT', abi_type=['standard'])
961+
class TypeSourceArrayOutStandard(StandardABIType):
962+
963+
@property
964+
def init_code(self):
965+
code = [f'int *{self.tmpname} = (int*)malloc({self.count_param} * sizeof(int));']
966+
return code
967+
968+
@property
969+
def final_code(self):
970+
code = [f'if (NULL != {self.name}){{']
971+
code.append(f'for(int i=0;i<{self.count_param};i++){{')
972+
code.append(f'{self.name}[i] = {ConvertOMPIToStandard.SOURCE}({self.tmpname}[i]);')
973+
code.append('}')
974+
code.append('}')
975+
code.append(f'free({self.tmpname});')
976+
return code
977+
978+
def type_text(self, enable_count=False):
979+
return 'int *'
980+
981+
def parameter(self, enable_count=False, **kwargs):
982+
return f'int {self.name}[]'
983+
915984
@Type.add_type('COMM', abi_type=['ompi'])
916985
class TypeCommunicator(Type):
917986

ompi/mpi/c/group_translate_ranks.c.in

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
#include "ompi/errhandler/errhandler.h"
3333
#include "ompi/group/group.h"
3434

35-
PROTOTYPE ERROR_CLASS group_translate_ranks(GROUP group1, INT n_ranks, INT_ARRAY ranks1,
36-
GROUP group2, INT_OUT ranks2)
35+
PROTOTYPE ERROR_CLASS group_translate_ranks(GROUP group1, INT n_ranks, SOURCE_ARRAY ranks1:n_ranks,
36+
GROUP group2, SOURCE_ARRAY_OUT ranks2:n_ranks)
3737
{
3838
int err;
3939

0 commit comments

Comments
 (0)