Skip to content

Commit 3986f25

Browse files
authored
[RO] Fix MPI Communicator Check (#203)
* Fix MPI Communicator casting [ROCm/rocshmem commit: b341285]
1 parent 2a7416d commit 3986f25

File tree

5 files changed

+48
-48
lines changed

5 files changed

+48
-48
lines changed

projects/rocshmem/src/reverse_offload/context_ro_device.cpp

Lines changed: 29 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ __device__ void ROContext::putmem(void *dest, const void *source, size_t nelems,
7878
return;
7979
}
8080
build_queue_element(RO_NET_PUT, dest, const_cast<void *>(source), nelems,
81-
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
81+
pe, 0, 0, 0, nullptr, nullptr, NULL,
8282
ro_net_win_id, block_handle, true, get_status_flag(),
8383
is_default_ctx);
8484
}
@@ -98,7 +98,7 @@ __device__ void ROContext::getmem(void *dest, const void *source, size_t nelems,
9898
return;
9999
}
100100
build_queue_element(RO_NET_GET, dest, const_cast<void *>(source), nelems,
101-
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
101+
pe, 0, 0, 0, nullptr, nullptr, NULL,
102102
ro_net_win_id, block_handle, true, get_status_flag(),
103103
is_default_ctx);
104104
}
@@ -118,7 +118,7 @@ __device__ void ROContext::putmem_nbi(void *dest, const void *source,
118118
return;
119119
}
120120
build_queue_element(RO_NET_PUT_NBI, dest, const_cast<void *>(source),
121-
nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
121+
nelems, pe, 0, 0, 0, nullptr, nullptr, NULL,
122122
ro_net_win_id, block_handle, false);
123123
}
124124
}
@@ -137,27 +137,27 @@ __device__ void ROContext::getmem_nbi(void *dest, const void *source,
137137
return;
138138
}
139139
build_queue_element(RO_NET_GET_NBI, dest, const_cast<void *>(source),
140-
nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
140+
nelems, pe, 0, 0, 0, nullptr, nullptr, NULL,
141141
ro_net_win_id, block_handle, false);
142142
}
143143
}
144144

145145
__device__ void ROContext::fence() {
146146
build_queue_element(RO_NET_FENCE, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
147-
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
147+
nullptr, NULL, ro_net_win_id, block_handle,
148148
true, get_status_flag(), is_default_ctx);
149149
}
150150

151151
__device__ void ROContext::fence(int pe) {
152152
// TODO(khamidou): need to check if per pe has any special handling
153153
build_queue_element(RO_NET_FENCE, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
154-
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
154+
nullptr, NULL, ro_net_win_id, block_handle,
155155
true, get_status_flag(), is_default_ctx);
156156
}
157157

158158
__device__ void ROContext::quiet() {
159159
build_queue_element(RO_NET_QUIET, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
160-
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
160+
nullptr, NULL, ro_net_win_id, block_handle,
161161
true, get_status_flag(), is_default_ctx);
162162
}
163163

@@ -175,22 +175,22 @@ __device__ void *ROContext::shmem_ptr(const void *dest, int pe) {
175175

176176
__device__ void ROContext::barrier_all() {
177177
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
178-
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
178+
nullptr, nullptr, NULL, ro_net_win_id,
179179
block_handle, true, get_status_flag(), is_default_ctx);
180180
}
181181

182182
__device__ void ROContext::barrier_all_wave() {
183183
if (is_thread_zero_in_wave()) {
184184
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
185-
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
185+
nullptr, nullptr, NULL, ro_net_win_id,
186186
block_handle, true, get_status_flag(), is_default_ctx);
187187
}
188188
}
189189

190190
__device__ void ROContext::barrier_all_wg() {
191191
if (is_thread_zero_in_block()) {
192192
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0,
193-
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
193+
nullptr, nullptr, NULL, ro_net_win_id,
194194
block_handle, true, get_status_flag(), is_default_ctx);
195195
}
196196
__syncthreads();
@@ -199,15 +199,15 @@ __device__ void ROContext::barrier_all_wg() {
199199
__device__ void ROContext::barrier(rocshmem_team_t team) {
200200
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
201201
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
202-
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
202+
nullptr, (intptr_t)team_obj->mpi_comm, ro_net_win_id, block_handle,
203203
true, get_status_flag(), is_default_ctx);
204204
}
205205

206206
__device__ void ROContext::barrier_wave(rocshmem_team_t team) {
207207
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
208208
if (is_thread_zero_in_wave()) {
209209
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
210-
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
210+
nullptr, (intptr_t)team_obj->mpi_comm, ro_net_win_id, block_handle,
211211
true, get_status_flag(), is_default_ctx);
212212
}
213213
}
@@ -216,30 +216,30 @@ __device__ void ROContext::barrier_wg(rocshmem_team_t team) {
216216
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
217217
if (is_thread_zero_in_block()) {
218218
build_queue_element(RO_NET_BARRIER, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
219-
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
219+
nullptr, (intptr_t)team_obj->mpi_comm, ro_net_win_id, block_handle,
220220
true, get_status_flag(), is_default_ctx);
221221
}
222222
__syncthreads();
223223
}
224224

225225
__device__ void ROContext::sync_all() {
226226
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
227-
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
227+
nullptr, nullptr, NULL, ro_net_win_id,
228228
block_handle, true, get_status_flag(), is_default_ctx);
229229
}
230230

231231
__device__ void ROContext::sync_all_wave() {
232232
if (is_thread_zero_in_wave()) {
233233
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
234-
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
234+
nullptr, nullptr, NULL, ro_net_win_id,
235235
block_handle, true, get_status_flag(), is_default_ctx);
236236
}
237237
}
238238

239239
__device__ void ROContext::sync_all_wg() {
240240
if (is_thread_zero_in_block()) {
241241
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0,
242-
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
242+
nullptr, nullptr, NULL, ro_net_win_id,
243243
block_handle, true, get_status_flag(), is_default_ctx);
244244
}
245245
__syncthreads();
@@ -248,15 +248,15 @@ __device__ void ROContext::sync_all_wg() {
248248
__device__ void ROContext::sync(rocshmem_team_t team) {
249249
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
250250
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
251-
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
251+
nullptr, (intptr_t)team_obj->mpi_comm, ro_net_win_id, block_handle,
252252
true, get_status_flag(), is_default_ctx);
253253
}
254254

255255
__device__ void ROContext::sync_wave(rocshmem_team_t team) {
256256
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
257257
if (is_thread_zero_in_wave()) {
258258
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
259-
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
259+
nullptr, (intptr_t)team_obj->mpi_comm, ro_net_win_id, block_handle,
260260
true, get_status_flag(), is_default_ctx);
261261
}
262262
}
@@ -265,7 +265,7 @@ __device__ void ROContext::sync_wg(rocshmem_team_t team) {
265265
ROTeam *team_obj = reinterpret_cast<ROTeam *>(team);
266266
if (is_thread_zero_in_block()) {
267267
build_queue_element(RO_NET_SYNC, nullptr, nullptr, 0, 0, 0, 0, 0, nullptr,
268-
nullptr, team_obj->mpi_comm, ro_net_win_id, block_handle,
268+
nullptr, (intptr_t)team_obj->mpi_comm, ro_net_win_id, block_handle,
269269
true, get_status_flag(), is_default_ctx);
270270
}
271271
__syncthreads();
@@ -278,7 +278,7 @@ __device__ void ROContext::ctx_destroy() {
278278
auto *proxy{backend_proxy.get()};
279279

280280
build_queue_element(RO_NET_FINALIZE, nullptr, nullptr, 0, 0, 0, 0, 0,
281-
nullptr, nullptr, (MPI_Comm)NULL, ro_net_win_id,
281+
nullptr, nullptr, NULL, ro_net_win_id,
282282
block_handle, true, get_status_flag(), is_default_ctx);
283283

284284
int buffer_id = ro_net_win_id;
@@ -302,7 +302,7 @@ __device__ void ROContext::putmem_wg(void *dest, const void *source,
302302
} else {
303303
if (is_thread_zero_in_block()) {
304304
build_queue_element(RO_NET_PUT, dest, const_cast<void *>(source), nelems,
305-
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
305+
pe, 0, 0, 0, nullptr, nullptr, NULL,
306306
ro_net_win_id, block_handle, true, get_status_flag(),
307307
is_default_ctx);
308308
}
@@ -321,7 +321,7 @@ __device__ void ROContext::getmem_wg(void *dest, const void *source,
321321
} else {
322322
if (is_thread_zero_in_block()) {
323323
build_queue_element(RO_NET_GET, dest, const_cast<void *>(source), nelems,
324-
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
324+
pe, 0, 0, 0, nullptr, nullptr, NULL,
325325
ro_net_win_id, block_handle, true, get_status_flag(),
326326
is_default_ctx);
327327
}
@@ -340,7 +340,7 @@ __device__ void ROContext::putmem_nbi_wg(void *dest, const void *source,
340340
} else {
341341
if (is_thread_zero_in_block()) {
342342
build_queue_element(RO_NET_PUT_NBI, dest, const_cast<void *>(source),
343-
nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
343+
nelems, pe, 0, 0, 0, nullptr, nullptr, NULL,
344344
ro_net_win_id, block_handle, false);
345345
}
346346
}
@@ -358,7 +358,7 @@ __device__ void ROContext::getmem_nbi_wg(void *dest, const void *source,
358358
} else {
359359
if (is_thread_zero_in_block()) {
360360
build_queue_element(RO_NET_GET_NBI, dest, const_cast<void *>(source),
361-
nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
361+
nelems, pe, 0, 0, 0, nullptr, nullptr, NULL,
362362
ro_net_win_id, block_handle, false);
363363
}
364364
}
@@ -376,7 +376,7 @@ __device__ void ROContext::putmem_wave(void *dest, const void *source,
376376
} else {
377377
if (is_thread_zero_in_wave()) {
378378
build_queue_element(RO_NET_PUT, dest, const_cast<void *>(source), nelems,
379-
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
379+
pe, 0, 0, 0, nullptr, nullptr, NULL,
380380
ro_net_win_id, block_handle, true, get_status_flag(),
381381
is_default_ctx);
382382
}
@@ -395,7 +395,7 @@ __device__ void ROContext::getmem_wave(void *dest, const void *source,
395395
} else {
396396
if (is_thread_zero_in_wave()) {
397397
build_queue_element(RO_NET_GET, dest, const_cast<void *>(source), nelems,
398-
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
398+
pe, 0, 0, 0, nullptr, nullptr, NULL,
399399
ro_net_win_id, block_handle, true, get_status_flag(),
400400
is_default_ctx);
401401
}
@@ -413,7 +413,7 @@ __device__ void ROContext::putmem_nbi_wave(void *dest, const void *source,
413413
} else {
414414
if (is_thread_zero_in_wave()) {
415415
build_queue_element(RO_NET_PUT_NBI, dest, const_cast<void *>(source),
416-
nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
416+
nelems, pe, 0, 0, 0, nullptr, nullptr, NULL,
417417
ro_net_win_id, block_handle, false);
418418
}
419419
}
@@ -431,7 +431,7 @@ __device__ void ROContext::getmem_nbi_wave(void *dest, const void *source,
431431
} else {
432432
if (is_thread_zero_in_wave()) {
433433
build_queue_element(RO_NET_GET_NBI, dest, const_cast<void *>(source),
434-
nelems, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
434+
nelems, pe, 0, 0, 0, nullptr, nullptr, NULL,
435435
ro_net_win_id, block_handle, false);
436436
}
437437
}
@@ -665,7 +665,7 @@ __device__ uint64_t next_write_slot(BlockHandle *handle) {
665665
__device__ void build_queue_element(
666666
ro_net_cmds type, void *dst, void *src, size_t size, int pe,
667667
int logPE_stride, int PE_size, int PE_root, void *pWrk, long *pSync,
668-
MPI_Comm team_comm, int ro_net_win_id, BlockHandle *handle,
668+
intptr_t team_comm, int ro_net_win_id, BlockHandle *handle,
669669
bool blocking, volatile char *status, bool default_ctx, ROCSHMEM_OP op,
670670
ro_net_types datatype) {
671671
auto write_slot{next_write_slot(handle)};

projects/rocshmem/src/reverse_offload/context_ro_device.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ namespace rocshmem {
3535
__device__ void build_queue_element(
3636
ro_net_cmds type, void *dst, void *src, size_t size, int pe,
3737
int logPE_stride, int PE_size, int PE_root, void *pWrk, long *pSync,
38-
MPI_Comm team_comm, int ro_net_win_id, BlockHandle *handle,
38+
intptr_t team_comm, int ro_net_win_id, BlockHandle *handle,
3939
bool blocking, volatile char *status = nullptr, bool default_ctx = false,
4040
ROCSHMEM_OP op = ROCSHMEM_SUM, ro_net_types datatype = RO_NET_INT);
4141

projects/rocshmem/src/reverse_offload/context_ro_tmpl_device.hpp

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ __device__ int ROContext::reduce(rocshmem_team_t team, T *dest,
121121
ROTeam *team_obj{reinterpret_cast<ROTeam *>(team)};
122122

123123
build_queue_element(RO_NET_TEAM_REDUCE, dest, const_cast<T *>(source),
124-
nreduce, 0, 0, 0, 0, nullptr, nullptr, team_obj->mpi_comm,
124+
nreduce, 0, 0, 0, 0, nullptr, nullptr, (intptr_t)team_obj->mpi_comm,
125125
ro_net_win_id, block_handle, true, get_status_flag(),
126126
is_default_ctx, Op, GetROType<T>::Type);
127127

@@ -152,7 +152,7 @@ __device__ void ROContext::p(T *dest, T value, int pe) {
152152
reinterpret_cast<void *>(&value), sizeof(T));
153153
} else {
154154
build_queue_element(RO_NET_P, dest, &value, sizeof(T), pe, 0, 0, 0, nullptr,
155-
nullptr, (MPI_Comm)NULL, ro_net_win_id,
155+
nullptr, NULL, ro_net_win_id,
156156
block_handle, true, get_status_flag(), is_default_ctx);
157157
}
158158
}
@@ -196,7 +196,7 @@ __device__ T ROContext::amo_fetch_cas(void *dst, T value, T cond, int pe) {
196196
build_queue_element(RO_NET_AMO_FCAS, dst, reinterpret_cast<T *>(source),
197197
value, pe, 0, 0, 0,
198198
reinterpret_cast<void *>(static_cast<long long>(cond)),
199-
nullptr, (MPI_Comm)NULL, ro_net_win_id, block_handle,
199+
nullptr, NULL, ro_net_win_id, block_handle,
200200
true, get_status_flag(), is_default_ctx, ROCSHMEM_SUM,
201201
GetROType<T>::Type);
202202
__threadfence();
@@ -215,7 +215,7 @@ template <typename T>
215215
__device__ T ROContext::amo_fetch_add(void *dst, T value, int pe) {
216216
auto source{get_atomic_ret_buf()};
217217
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<T *>(source), value,
218-
pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
218+
pe, 0, 0, 0, nullptr, nullptr, NULL,
219219
ro_net_win_id, block_handle, true, get_status_flag(),
220220
is_default_ctx, ROCSHMEM_SUM, GetROType<T>::Type);
221221
__threadfence();
@@ -234,7 +234,7 @@ template <typename T>
234234
__device__ T ROContext::amo_swap(void *dst, T value, int pe) {
235235
auto source{get_atomic_ret_buf()};
236236
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
237-
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
237+
value, pe, 0, 0, 0, nullptr, nullptr, NULL,
238238
ro_net_win_id, block_handle, true, get_status_flag(),
239239
is_default_ctx, ROCSHMEM_REPLACE, GetROType<T>::Type);
240240
__threadfence();
@@ -253,7 +253,7 @@ template <typename T>
253253
__device__ T ROContext::amo_fetch_and(void *dst, T value, int pe) {
254254
auto source{get_atomic_ret_buf()};
255255
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
256-
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
256+
value, pe, 0, 0, 0, nullptr, nullptr, NULL,
257257
ro_net_win_id, block_handle, true, get_status_flag(),
258258
is_default_ctx, ROCSHMEM_AND, GetROType<T>::Type);
259259
__threadfence();
@@ -272,7 +272,7 @@ template <typename T>
272272
__device__ T ROContext::amo_fetch_or(void *dst, T value, int pe) {
273273
auto source{get_atomic_ret_buf()};
274274
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
275-
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
275+
value, pe, 0, 0, 0, nullptr, nullptr, NULL,
276276
ro_net_win_id, block_handle, true, get_status_flag(),
277277
is_default_ctx, ROCSHMEM_OR, GetROType<T>::Type);
278278
__threadfence();
@@ -291,7 +291,7 @@ template <typename T>
291291
__device__ T ROContext::amo_fetch_xor(void *dst, T value, int pe) {
292292
auto source{get_atomic_ret_buf()};
293293
build_queue_element(RO_NET_AMO_FOP, dst, reinterpret_cast<void *>(source),
294-
value, pe, 0, 0, 0, nullptr, nullptr, (MPI_Comm)NULL,
294+
value, pe, 0, 0, 0, nullptr, nullptr, NULL,
295295
ro_net_win_id, block_handle, true, get_status_flag(),
296296
is_default_ctx, ROCSHMEM_XOR, GetROType<T>::Type);
297297
__threadfence();
@@ -318,7 +318,7 @@ __device__ void ROContext::broadcast(rocshmem_team_t team, T *dest,
318318

319319
build_queue_element(RO_NET_TEAM_BROADCAST, dest, const_cast<T *>(source),
320320
nelems, 0, 0, 0, pe_root, nullptr, nullptr,
321-
team_obj->mpi_comm, ro_net_win_id, block_handle, true,
321+
(intptr_t)team_obj->mpi_comm, ro_net_win_id, block_handle, true,
322322
get_status_flag(), is_default_ctx, ROCSHMEM_SUM,
323323
GetROType<T>::Type);
324324

@@ -337,7 +337,7 @@ __device__ void ROContext::alltoall(rocshmem_team_t team, T *dest,
337337

338338
build_queue_element(RO_NET_ALLTOALL, dest, const_cast<T *>(source), nelems, 0,
339339
0, 0, 0, team_obj->ata_buffer, nullptr,
340-
team_obj->mpi_comm, ro_net_win_id, block_handle, true,
340+
(intptr_t)team_obj->mpi_comm, ro_net_win_id, block_handle, true,
341341
get_status_flag(), is_default_ctx, ROCSHMEM_SUM,
342342
GetROType<T>::Type);
343343

@@ -356,7 +356,7 @@ __device__ void ROContext::fcollect(rocshmem_team_t team, T *dest,
356356

357357
build_queue_element(RO_NET_FCOLLECT, dest, const_cast<T *>(source), nelems, 0,
358358
0, 0, 0, team_obj->ata_buffer, nullptr,
359-
team_obj->mpi_comm, ro_net_win_id, block_handle, true,
359+
(intptr_t)team_obj->mpi_comm, ro_net_win_id, block_handle, true,
360360
get_status_flag(), is_default_ctx, ROCSHMEM_SUM,
361361
GetROType<T>::Type);
362362

0 commit comments

Comments
 (0)