@@ -16,20 +16,32 @@ typedef enum {
16
16
_MPI_INVALID ,
17
17
_MPI_Send_idx ,
18
18
_MPI_Recv_idx ,
19
+ _MPI_Allreduce_idx ,
20
+ _MPI_Reduce_idx ,
19
21
} offload_emis_mpi_t ;
20
22
21
23
/// Device stubs that forward to _emissary_exec using the same interface as the host MPI API
22
24
#if defined(__NVPTX__ ) || defined(__AMDGCN__ )
23
25
extern "C" int MPI_Send (const void * buf , int count , MPI_Datatype datatype ,
24
26
int dest , int tag , MPI_Comm comm ) {
25
- return (int )_emissary_exec (_PACK_EMIS_IDS (EMIS_ID_MPI , _MPI_Send_idx ), buf ,
26
- count , datatype , dest , tag , comm );
27
+ return (int )_emissary_exec (_PACK_EMIS_IDS (EMIS_ID_MPI , _MPI_Send_idx ),
28
+ buf , count , datatype , dest , tag , comm );
27
29
}
28
30
extern "C" int MPI_Recv (void * buf , int count , MPI_Datatype datatype , int source ,
29
31
int tag , MPI_Comm comm , MPI_Status * st ) {
30
32
return (int )_emissary_exec (_PACK_EMIS_IDS (EMIS_ID_MPI , _MPI_Recv_idx ), buf ,
31
33
count , datatype , source , tag , comm , st );
32
34
}
35
+ extern "C" int MPI_Allreduce (const void * sendbuf , void * recvbuf , int count ,
36
+ MPI_Datatype datatype , MPI_Op op , MPI_Comm comm ) {
37
+ return (int )_emissary_exec (_PACK_EMIS_IDS (EMIS_ID_MPI , _MPI_Allreduce_idx ),
38
+ sendbuf , recvbuf , count , datatype , op , comm );
39
+ }
40
+ extern "C" int MPI_Reduce (const void * sendbuf , void * recvbuf , int count , MPI_Datatype datatype ,
41
+ MPI_Op op , int root , MPI_Comm comm ) {
42
+ return (int )_emissary_exec (_PACK_EMIS_IDS (EMIS_ID_MPI , _MPI_Reduce_idx ),
43
+ sendbuf , recvbuf , count , datatype , op , root , comm );
44
+ }
33
45
#endif
34
46
35
47
/// Host variadic wrapper functions.
@@ -61,6 +73,35 @@ extern int V_MPI_Recv(void *fnptr, ...) {
61
73
int rval = MPI_Recv (v0 , v1 , v2 , v3 , v4 , v5 , v6 );
62
74
return rval ;
63
75
}
76
+ extern int V_MPI_Allreduce (void * fnptr , ...) {
77
+ va_list args ;
78
+ va_start (args , fnptr );
79
+ void * buf = va_arg (args , void * );
80
+ void * recvbuf = va_arg (args , void * );
81
+ int count = va_arg (args , int );
82
+ MPI_Datatype datatype = va_arg (args , MPI_Datatype );
83
+ MPI_Op op = va_arg (args , MPI_Op );
84
+ MPI_Comm comm = va_arg (args , MPI_Comm );
85
+ va_end (args );
86
+ int rval = MPI_Allreduce (
87
+ buf , recvbuf , count , datatype , op , comm );
88
+ return rval ;
89
+ }
90
+ extern int V_MPI_Reduce (void * fnptr , ...) {
91
+ va_list args ;
92
+ va_start (args , fnptr );
93
+ void * buf = va_arg (args , void * );
94
+ void * recvbuf = va_arg (args , void * );
95
+ int count = va_arg (args , int );
96
+ MPI_Datatype datatype = va_arg (args , MPI_Datatype );
97
+ MPI_Op op = va_arg (args , MPI_Op );
98
+ int root = va_arg (args , int );
99
+ MPI_Comm comm = va_arg (args , MPI_Comm );
100
+ va_end (args );
101
+ int rval = MPI_Reduce (
102
+ buf , recvbuf , count , datatype , op , root , comm );
103
+ return rval ;
104
+ }
64
105
65
106
/// EmissaryMPI function selector
66
107
emis_return_t EmissaryMPI (char * data , emisArgBuf_t * ab , emis_argptr_t * a []) {
@@ -78,6 +119,18 @@ emis_return_t EmissaryMPI(char *data, emisArgBuf_t *ab, emis_argptr_t *a[]) {
78
119
V_MPI_Recv (fnptr , a [0 ], a [1 ], a [2 ], a [3 ], a [4 ], a [5 ], a [6 ]);
79
120
return (emis_return_t )return_value_int ;
80
121
}
122
+ case _MPI_Allreduce_idx : {
123
+ void * fnptr = (void * )V_MPI_Allreduce ;
124
+ int return_value_int =
125
+ V_MPI_Allreduce (fnptr , a [0 ], a [1 ], a [2 ], a [3 ], a [4 ], a [5 ]);
126
+ return (emis_return_t ) return_value_int ;
127
+ }
128
+ case _MPI_Reduce_idx : {
129
+ void * fnptr = (void * )V_MPI_Reduce ;
130
+ int return_value_int =
131
+ V_MPI_Reduce (fnptr , a [0 ], a [1 ], a [2 ], a [3 ], a [4 ], a [5 ], a [6 ]);
132
+ return (emis_return_t ) return_value_int ;
133
+ }
81
134
}
82
135
return (emis_return_t )0 ;
83
136
}
0 commit comments