Skip to content

Commit c88645e

Browse files
Merge branch 'ipc_bringup' into ipc_atomics
[ROCm/rocshmem commit: e58077e]
2 parents e381ea5 + c5be323 commit c88645e

File tree

6 files changed

+253
-99
lines changed

6 files changed

+253
-99
lines changed

projects/rocshmem/src/ipc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ target_sources(
3030
context_ipc_host.cpp
3131
backend_ipc.cpp
3232
ipc_team.cpp
33+
context_ipc_device_coll.cpp
3334
)

projects/rocshmem/src/ipc/context_ipc_device.cpp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ __host__ IPCContext::IPCContext(Backend *b)
4444

4545
auto *bp{backend->ipc_backend_proxy.get()};
4646

47+
barrier_sync = backend->barrier_sync;
4748
g_ret = bp->g_ret;
4849
atomic_base_ptr = bp->atomic_ret->atomic_base_ptr;
4950
}
@@ -97,18 +98,6 @@ __device__ void *IPCContext::shmem_ptr(const void *dest, int pe) {
9798
return ret;
9899
}
99100

100-
__device__ void IPCContext::barrier_all() {
101-
__syncthreads();
102-
}
103-
104-
__device__ void IPCContext::sync_all() {
105-
__syncthreads();
106-
}
107-
108-
__device__ void IPCContext::sync(roc_shmem_team_t team) {
109-
__syncthreads();
110-
}
111-
112101
__device__ void IPCContext::putmem_wg(void *dest, const void *source,
113102
size_t nelems, int pe) {
114103
uint64_t L_offset =

projects/rocshmem/src/ipc/context_ipc_device.hpp

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -194,50 +194,44 @@ class IPCContext : public Context {
194194
template <typename T>
195195
__device__ void get_nbi_wave(T *dest, const T *source, size_t nelems, int pe);
196196

197-
// Wait / Test functions
198-
template <typename T>
199-
__device__ void wait_until(T* ptr, roc_shmem_cmps cmp, T val);
197+
private:
200198

201-
template <typename T>
202-
__device__ void wait_until_all(T* ptr, size_t nelems,
203-
const int *status,
204-
roc_shmem_cmps cmp, T val);
199+
//context class has IpcImpl object (ipcImpl_)
200+
IpcImpl *ipcImpl{nullptr};
205201

206-
template <typename T>
207-
__device__ size_t wait_until_any(T* ptr, size_t nelems,
208-
const int *status,
209-
roc_shmem_cmps cmp, T val);
202+
uint64_t* atomic_base_ptr{nullptr};
210203

211-
template <typename T>
212-
__device__ size_t wait_until_some(T* ptr, size_t nelems,
213-
size_t* indices,
214-
const int *status,
215-
roc_shmem_cmps cmp, T val);
204+
char* g_ret;
205+
206+
//internal functions used by collective operations
207+
template <typename T>
208+
__device__ void internal_put_broadcast(T *dst, const T *src, int nelems,
209+
int pe_root, int PE_start,
210+
int logPE_stride, int PE_size); // NOLINT(runtime/int)
216211

217212
template <typename T>
218-
__device__ void wait_until_all_vector(T* ptr, size_t nelems,
219-
const int *status,
220-
roc_shmem_cmps cmp, T* vals);
213+
__device__ void internal_get_broadcast(T *dst, const T *src, int nelems,
214+
int pe_root); // NOLINT(runtime/int)
221215

222216
template <typename T>
223-
__device__ size_t wait_until_any_vector(T* ptr, size_t nelems,
224-
const int *status,
225-
roc_shmem_cmps cmp, T* vals);
226-
template <typename T>
227-
__device__ size_t wait_until_some_vector(T* ptr, size_t nelems,
228-
size_t* indices,
229-
const int *status,
230-
roc_shmem_cmps cmp, T* vals);
217+
__device__ void fcollect_linear(roc_shmem_team_t team, T *dest,
218+
const T *source, int nelems);
231219

232220
template <typename T>
233-
__device__ int test(T* ptr, roc_shmem_cmps cmp, T val);
234-
235-
private:
221+
__device__ void alltoall_linear(roc_shmem_team_t team, T *dest,
222+
const T *source, int nelems);
223+
224+
__device__ void internal_sync(int pe, int PE_start, int stride, int PE_size,
225+
int64_t *pSync);
236226

237-
uint64_t* atomic_base_ptr{nullptr};
227+
__device__ void internal_direct_barrier(int pe, int PE_start, int stride,
228+
int n_pes, int64_t *pSync);
238229

239-
char* g_ret;
230+
__device__ void internal_atomic_barrier(int pe, int PE_start, int stride,
231+
int n_pes, int64_t *pSync);
240232

233+
//Temporary scratchpad memory used by internal barrier algorithms.
234+
int64_t *barrier_sync{nullptr};
241235
};
242236

243237
} // namespace rocshmem
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/******************************************************************************
2+
* Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
3+
*
4+
* Permission is hereby granted, free of charge, to any person obtaining a copy
5+
* of this software and associated documentation files (the "Software"), to
6+
* deal in the Software without restriction, including without limitation the
7+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8+
* sell copies of the Software, and to permit persons to whom the Software is
9+
* furnished to do so, subject to the following conditions:
10+
*
11+
* The above copyright notice and this permission notice shall be included in
12+
* all copies or substantial portions of the Software.
13+
*
14+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19+
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20+
* IN THE SOFTWARE.
21+
*****************************************************************************/
22+
23+
#include "roc_shmem/roc_shmem.hpp"
24+
#include "../context_incl.hpp"
25+
#include "context_ipc_tmpl_device.hpp"
26+
#include "../util.hpp"
27+
#include "ipc_team.hpp"
28+
29+
namespace rocshmem {
30+
31+
__device__ void IPCContext::internal_direct_barrier(int pe, int PE_start,
32+
int stride, int n_pes,
33+
int64_t *pSync) {
34+
int64_t flag_val = 1;
35+
if (pe == PE_start) {
36+
// Go through all PE offsets (except current offset = 0)
37+
// and wait until they all reach
38+
for (size_t i = 1; i < n_pes; i++) {
39+
wait_until(&pSync[i], ROC_SHMEM_CMP_EQ, flag_val);
40+
pSync[i] = ROC_SHMEM_SYNC_VALUE;
41+
}
42+
threadfence_system();
43+
44+
// Announce to other PEs that all have reached
45+
for (size_t i = 1, j = PE_start + stride; i < n_pes; ++i, j += stride) {
46+
put_nbi(&pSync[0], &flag_val, 1, j);
47+
}
48+
} else {
49+
// Mark current PE offset as reached
50+
size_t pe_offset = (pe - PE_start) / stride;
51+
put_nbi(&pSync[pe_offset], &flag_val, 1, PE_start);
52+
wait_until(&pSync[0], ROC_SHMEM_CMP_EQ, flag_val);
53+
pSync[0] = ROC_SHMEM_SYNC_VALUE;
54+
threadfence_system();
55+
}
56+
}
57+
58+
__device__ void IPCContext::internal_atomic_barrier(int pe, int PE_start,
59+
int stride, int n_pes,
60+
int64_t *pSync) {
61+
int64_t flag_val = 1;
62+
if (pe == PE_start) {
63+
wait_until(&pSync[0], ROC_SHMEM_CMP_EQ, (int64_t)(n_pes - 1));
64+
pSync[0] = ROC_SHMEM_SYNC_VALUE;
65+
threadfence_system();
66+
67+
for (size_t i = 1, j = PE_start + stride; i < n_pes; ++i, j += stride) {
68+
put_nbi(&pSync[0], &flag_val, 1, j);
69+
}
70+
} else {
71+
amo_add<int64_t>(&pSync[0], flag_val, PE_start);
72+
wait_until(&pSync[0], ROC_SHMEM_CMP_EQ, flag_val);
73+
pSync[0] = ROC_SHMEM_SYNC_VALUE;
74+
threadfence_system();
75+
}
76+
}
77+
78+
// Uses PE values that are relative to world
79+
__device__ void IPCContext::internal_sync(int pe, int PE_start, int stride,
80+
int PE_size, int64_t *pSync) {
81+
__syncthreads();
82+
if (is_thread_zero_in_block()) {
83+
if (PE_size < 64) {
84+
internal_direct_barrier(pe, PE_start, stride, PE_size, pSync);
85+
} else {
86+
internal_atomic_barrier(pe, PE_start, stride, PE_size, pSync);
87+
}
88+
}
89+
__threadfence();
90+
__syncthreads();
91+
}
92+
93+
__device__ void IPCContext::sync(roc_shmem_team_t team) {
94+
IPCTeam *team_obj = reinterpret_cast<IPCTeam *>(team);
95+
96+
/**
97+
* Ensure that the stride is a multiple of 2.
98+
*/
99+
int log_pe_stride = static_cast<int>(team_obj->tinfo_wrt_world->log_stride);
100+
int pe = team_obj->my_pe_in_world;
101+
int pe_start = team_obj->tinfo_wrt_world->pe_start;
102+
int pe_stride = (1 << log_pe_stride);
103+
int pe_size = team_obj->num_pes;
104+
105+
internal_sync(pe, pe_start, pe_stride, pe_size, barrier_sync);
106+
}
107+
108+
__device__ void IPCContext::sync_all() {
109+
internal_sync(my_pe, 0, 1, num_pes, barrier_sync);
110+
}
111+
112+
__device__ void IPCContext::barrier_all() {
113+
if (is_thread_zero_in_block()) {
114+
quiet();
115+
}
116+
sync_all();
117+
__syncthreads();
118+
}
119+
120+
} // namespace rocshmem

0 commit comments

Comments
 (0)