Skip to content

Commit 309a88a

Browse files
committed
emissary-mpi reduce operation
1 parent d4a3b12 commit 309a88a

File tree

1 file changed

+55
-2
lines changed

1 file changed

+55
-2
lines changed

offload/DeviceRTL/include/EmissaryMPI.h

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,32 @@ typedef enum {
1616
_MPI_INVALID,
1717
_MPI_Send_idx,
1818
_MPI_Recv_idx,
19+
_MPI_Allreduce_idx,
20+
_MPI_Reduce_idx,
1921
} offload_emis_mpi_t;
2022

2123
/// Device stubs that call _emissary_exec using identical host API interface
2224
#if defined(__NVPTX__) || defined(__AMDGCN__)
2325
extern "C" int MPI_Send(const void *buf, int count, MPI_Datatype datatype,
2426
int dest, int tag, MPI_Comm comm) {
25-
return (int)_emissary_exec(_PACK_EMIS_IDS(EMIS_ID_MPI, _MPI_Send_idx), buf,
26-
count, datatype, dest, tag, comm);
27+
return (int)_emissary_exec(_PACK_EMIS_IDS(EMIS_ID_MPI, _MPI_Send_idx),
28+
buf, count, datatype, dest, tag, comm);
2729
}
2830
extern "C" int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source,
2931
int tag, MPI_Comm comm, MPI_Status *st) {
3032
return (int)_emissary_exec(_PACK_EMIS_IDS(EMIS_ID_MPI, _MPI_Recv_idx), buf,
3133
count, datatype, source, tag, comm, st);
3234
}
35+
extern "C" int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
36+
MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) {
37+
return (int)_emissary_exec(_PACK_EMIS_IDS(EMIS_ID_MPI, _MPI_Allreduce_idx),
38+
sendbuf, recvbuf, count, datatype, op, comm);
39+
}
40+
extern "C" int MPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
41+
MPI_Op op, int root, MPI_Comm comm) {
42+
return (int)_emissary_exec(_PACK_EMIS_IDS(EMIS_ID_MPI, _MPI_Reduce_idx),
43+
sendbuf, recvbuf, count, datatype, op, root, comm);
44+
}
3345
#endif
3446

3547
/// Host variadic wrapper functions.
@@ -61,6 +73,35 @@ extern int V_MPI_Recv(void *fnptr, ...) {
6173
int rval = MPI_Recv(v0, v1, v2, v3, v4, v5, v6);
6274
return rval;
6375
}
76+
extern int V_MPI_Allreduce(void *fnptr, ...) {
77+
va_list args;
78+
va_start(args, fnptr);
79+
void *buf = va_arg(args, void *);
80+
void *recvbuf = va_arg(args, void *);
81+
int count = va_arg(args, int);
82+
MPI_Datatype datatype = va_arg(args, MPI_Datatype);
83+
MPI_Op op = va_arg(args, MPI_Op);
84+
MPI_Comm comm = va_arg(args, MPI_Comm);
85+
va_end(args);
86+
int rval = MPI_Allreduce(
87+
buf, recvbuf, count, datatype, op, comm);
88+
return rval;
89+
}
90+
extern int V_MPI_Reduce(void *fnptr, ...) {
91+
va_list args;
92+
va_start(args, fnptr);
93+
void *buf = va_arg(args, void *);
94+
void *recvbuf = va_arg(args, void *);
95+
int count = va_arg(args, int);
96+
MPI_Datatype datatype = va_arg(args, MPI_Datatype);
97+
MPI_Op op = va_arg(args, MPI_Op);
98+
int root = va_arg(args, int);
99+
MPI_Comm comm = va_arg(args, MPI_Comm);
100+
va_end(args);
101+
int rval = MPI_Reduce(
102+
buf, recvbuf, count, datatype, op, root, comm);
103+
return rval;
104+
}
64105

65106
/// EmissaryMPI function selector
66107
emis_return_t EmissaryMPI(char *data, emisArgBuf_t *ab, emis_argptr_t *a[]) {
@@ -78,6 +119,18 @@ emis_return_t EmissaryMPI(char *data, emisArgBuf_t *ab, emis_argptr_t *a[]) {
78119
V_MPI_Recv(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6]);
79120
return (emis_return_t)return_value_int;
80121
}
122+
case _MPI_Allreduce_idx: {
123+
void *fnptr = (void *)V_MPI_Allreduce;
124+
int return_value_int =
125+
V_MPI_Allreduce(fnptr, a[0], a[1], a[2], a[3], a[4], a[5]);
126+
return (emis_return_t) return_value_int;
127+
}
128+
case _MPI_Reduce_idx: {
129+
void *fnptr = (void *)V_MPI_Reduce;
130+
int return_value_int =
131+
V_MPI_Reduce(fnptr, a[0], a[1], a[2], a[3], a[4], a[5], a[6]);
132+
return (emis_return_t) return_value_int;
133+
}
81134
}
82135
return (emis_return_t)0;
83136
}

0 commit comments

Comments
 (0)