Skip to content

Commit 459c3c5

Browse files
committed
add better support for MPI_ROOT and source
since MPI_ROOT for ABI is not equal to ompi. MPI_ANY_SOURCE is but still generate conversion code rather than just beining lucky. Signed-off-by: Howard Pritchard <[email protected]>
1 parent d60f88d commit 459c3c5

37 files changed

+88
-36
lines changed

ompi/mpi/bindings/ompi_bindings/c_type.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2024 Triad National Security, LLC. All rights
1+
# Copyright (c) 2024-2025 Triad National Security, LLC. All rights
22
# reserved.
33
#
44
# $COPYRIGHT$
@@ -580,7 +580,7 @@ def type_text(self, enable_count=False):
580580
return 'int'
581581

582582
@Type.add_type('TAG', abi_type=['standard'])
583-
class TypeRank(StandardABIType):
583+
class TypeRankStandard(StandardABIType):
584584

585585
@property
586586
def init_code(self):
@@ -589,6 +589,38 @@ def init_code(self):
589589
def type_text(self, enable_count=False):
590590
return 'int'
591591

592+
@Type.add_type('ROOT', abi_type=['ompi'])
593+
class TypeRoot(Type):
594+
595+
def type_text(self, enable_count=False):
596+
return 'int'
597+
598+
@Type.add_type('ROOT', abi_type=['standard'])
599+
class TypeRootStandard(StandardABIType):
600+
601+
@property
602+
def init_code(self):
603+
return [f'int {self.tmpname} = {ConvertFuncs.ROOT}({self.name});']
604+
605+
def type_text(self, enable_count=False):
606+
return 'int'
607+
608+
@Type.add_type('SOURCE', abi_type=['ompi'])
609+
class TypeSource(Type):
610+
611+
def type_text(self, enable_count=False):
612+
return 'int'
613+
614+
@Type.add_type('SOURCE', abi_type=['standard'])
615+
class TypeSourceStandard(StandardABIType):
616+
617+
@property
618+
def init_code(self):
619+
return [f'int {self.tmpname} = {ConvertFuncs.SOURCE}({self.name});']
620+
621+
def type_text(self, enable_count=False):
622+
return 'int'
623+
592624
@Type.add_type('COMM', abi_type=['ompi'])
593625
class TypeCommunicator(Type):
594626

ompi/mpi/bindings/ompi_bindings/consts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ class ConvertFuncs:
280280
FILE = 'ompi_convert_abi_file_intern_file'
281281
TS_LEVEL = 'ompi_convert_abi_ts_level_intern_ts_level'
282282
TAG = 'ompi_convert_abi_tag_intern_tag'
283+
ROOT = 'ompi_convert_abi_root_intern_root'
284+
SOURCE = 'ompi_convert_abi_source_intern_source'
283285

284286

285287
class ConvertOMPIToStandard:

ompi/mpi/c/abi_converters.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,24 @@ __opal_attribute_always_inline__ static inline int ompi_convert_tag_ompi_to_star
999999
}
10001000
}
10011001

1002+
__opal_attribute_always_inline__ static inline int ompi_convert_abi_root_intern_root(int root)
1003+
{
1004+
if (MPI_ROOT_ABI_INTERNAL == root) {
1005+
return MPI_ROOT;
1006+
} else {
1007+
return root;
1008+
}
1009+
}
1010+
1011+
__opal_attribute_always_inline__ static inline int ompi_convert_abi_source_intern_source(int source)
1012+
{
1013+
if (MPI_ANY_SOURCE_ABI_INTERNAL == source) {
1014+
return MPI_ANY_SOURCE;
1015+
} else {
1016+
return source;
1017+
}
1018+
}
1019+
10021020
#if defined(c_plusplus) || defined(__cplusplus)
10031021
}
10041022
#endif

ompi/mpi/c/bcast.c.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
#include "ompi/runtime/ompi_spc.h"
3333

3434
PROTOTYPE ERROR_CLASS bcast(BUFFER_OUT buffer, COUNT count, DATATYPE datatype,
35-
INT root, COMM comm)
35+
ROOT root, COMM comm)
3636
{
3737
int err;
3838

ompi/mpi/c/bcast_init.c.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#include "ompi/runtime/ompi_spc.h"
2727

2828
PROTOTYPE ERROR_CLASS bcast_init(BUFFER_OUT buffer, COUNT count, DATATYPE datatype,
29-
INT root, COMM comm, INFO info, REQUEST_INOUT request)
29+
ROOT root, COMM comm, INFO info, REQUEST_INOUT request)
3030
{
3131
int err;
3232

ompi/mpi/c/comm_accept.c.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
#include "ompi/dpm/dpm.h"
4040
#include "ompi/memchecker.h"
4141

42-
PROTOTYPE ERROR_CLASS comm_accept(STRING port_name, INFO info, INT root,
42+
PROTOTYPE ERROR_CLASS comm_accept(STRING port_name, INFO info, ROOT root,
4343
COMM comm, COMM_OUT newcomm)
4444
{
4545
int rank, rc;

ompi/mpi/c/comm_connect.c.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
#include "ompi/dpm/dpm.h"
4040
#include "ompi/memchecker.h"
4141

42-
PROTOTYPE ERROR_CLASS comm_connect(STRING port_name, INFO info, INT root,
42+
PROTOTYPE ERROR_CLASS comm_connect(STRING port_name, INFO info, ROOT root,
4343
COMM comm, COMM_OUT newcomm)
4444
{
4545
int rank, rc;

ompi/mpi/c/comm_spawn.c.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
#include "ompi/memchecker.h"
4343

4444
PROTOTYPE ERROR_CLASS comm_spawn(STRING command, STRING_ARRAY argv, INT maxprocs, INFO info,
45-
INT root, COMM comm, COMM_OUT intercomm,
45+
ROOT root, COMM comm, COMM_OUT intercomm,
4646
INT_OUT array_of_errcodes)
4747
{
4848
int rank, rc=OMPI_SUCCESS, i, flag;

ompi/mpi/c/comm_spawn_multiple.c.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
PROTOTYPE ERROR_CLASS comm_spawn_multiple(INT count, STRING_ARRAY array_of_commands, ARGV array_of_argv,
4545
INT_ARRAY array_of_maxprocs, INFO_ARRAY array_of_info:count,
46-
INT root, COMM comm, COMM_OUT intercomm,
46+
ROOT root, COMM comm, COMM_OUT intercomm,
4747
INT_OUT array_of_errcodes)
4848
{
4949
int i=0, rc=0, rank=0, size=0, flag;

ompi/mpi/c/gather.c.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
PROTOTYPE ERROR_CLASS gather(BUFFER sendbuf, COUNT sendcount, DATATYPE sendtype,
4141
BUFFER_OUT recvbuf, COUNT recvcount, DATATYPE recvtype,
42-
INT root, COMM comm)
42+
ROOT root, COMM comm)
4343
{
4444
int err;
4545

0 commit comments

Comments
 (0)