Skip to content

Commit a77375f

Browse files
Merge pull request #21 from avinashkethineedi/ipc_atomics
IPC atomics [ROCm/rocshmem commit: fc45d7a]
2 parents c5be323 + c88645e commit a77375f

File tree

2 files changed

+41
-37
lines changed

2 files changed

+41
-37
lines changed

projects/rocshmem/src/ipc/context_ipc_device.cpp

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ namespace rocshmem {
3939
__host__ IPCContext::IPCContext(Backend *b)
4040
: Context(b, false) {
4141
IPCBackend *backend{static_cast<IPCBackend *>(b)};
42-
ipcImpl = &backend->ipcImpl;
42+
ipcImpl_.ipc_bases = b->ipcImpl.ipc_bases;
43+
ipcImpl_.shm_size = b->ipcImpl.shm_size;
4344

4445
auto *bp{backend->ipc_backend_proxy.get()};
4546

@@ -59,22 +60,18 @@ __device__ void IPCContext::ctx_destroy(){
5960

6061
__device__ void IPCContext::putmem(void *dest, const void *source, size_t nelems,
6162
int pe) {
62-
// TODO (Avinash) check if PE is available for IPC using (isIpcAvailable)
63-
int local_pe = pe % ipcImpl->shm_size;
6463
uint64_t L_offset =
65-
reinterpret_cast<char *>(dest) - ipcImpl->ipc_bases[my_pe];
66-
ipcImpl->ipcCopy(ipcImpl->ipc_bases[local_pe] + L_offset,
64+
reinterpret_cast<char *>(dest) - ipcImpl_.ipc_bases[my_pe];
65+
ipcImpl_.ipcCopy(ipcImpl_.ipc_bases[pe] + L_offset,
6766
const_cast<void *>(source), nelems);
6867
}
6968

7069
__device__ void IPCContext::getmem(void *dest, const void *source, size_t nelems,
7170
int pe) {
72-
// TODO (Avinash) check if PE is available for IPC using (isIpcAvailable)
73-
int local_pe = pe % ipcImpl->shm_size;
7471
const char *src_typed = reinterpret_cast<const char *>(source);
7572
uint64_t L_offset =
76-
const_cast<char *>(src_typed) - ipcImpl->ipc_bases[my_pe];
77-
ipcImpl->ipcCopy(dest, ipcImpl->ipc_bases[local_pe] + L_offset, nelems);
73+
const_cast<char *>(src_typed) - ipcImpl_.ipc_bases[my_pe];
74+
ipcImpl_.ipcCopy(dest, ipcImpl_.ipc_bases[pe] + L_offset, nelems);
7875
}
7976

8077
__device__ void IPCContext::putmem_nbi(void *dest, const void *source,
@@ -103,23 +100,19 @@ __device__ void *IPCContext::shmem_ptr(const void *dest, int pe) {
103100

104101
__device__ void IPCContext::putmem_wg(void *dest, const void *source,
105102
size_t nelems, int pe) {
106-
// TODO (Avinash) check if PE is available for IPC using (isIpcAvailable)
107-
int local_pe = pe % ipcImpl->shm_size;
108103
uint64_t L_offset =
109-
reinterpret_cast<char *>(dest) - ipcImpl->ipc_bases[my_pe];
110-
ipcImpl->ipcCopy_wg(ipcImpl->ipc_bases[local_pe] + L_offset,
104+
reinterpret_cast<char *>(dest) - ipcImpl_.ipc_bases[my_pe];
105+
ipcImpl_.ipcCopy_wg(ipcImpl_.ipc_bases[pe] + L_offset,
111106
const_cast<void *>(source), nelems);
112107
__syncthreads();
113108
}
114109

115110
__device__ void IPCContext::getmem_wg(void *dest, const void *source,
116111
size_t nelems, int pe) {
117-
// TODO (Avinash) check if PE is available for IPC using (isIpcAvailable)
118-
int local_pe = pe % ipcImpl->shm_size;
119112
const char *src_typed = reinterpret_cast<const char *>(source);
120113
uint64_t L_offset =
121-
const_cast<char *>(src_typed) - ipcImpl->ipc_bases[my_pe];
122-
ipcImpl->ipcCopy_wg(dest, ipcImpl->ipc_bases[local_pe] + L_offset, nelems);
114+
const_cast<char *>(src_typed) - ipcImpl_.ipc_bases[my_pe];
115+
ipcImpl_.ipcCopy_wg(dest, ipcImpl_.ipc_bases[pe] + L_offset, nelems);
123116
__syncthreads();
124117
}
125118

@@ -135,22 +128,18 @@ __device__ void IPCContext::getmem_nbi_wg(void *dest, const void *source,
135128

136129
__device__ void IPCContext::putmem_wave(void *dest, const void *source,
137130
size_t nelems, int pe) {
138-
// TODO (Avinash) check if PE is available for IPC using (isIpcAvailable)
139-
int local_pe = pe % ipcImpl->shm_size;
140131
uint64_t L_offset =
141-
reinterpret_cast<char *>(dest) - ipcImpl->ipc_bases[my_pe];
142-
ipcImpl->ipcCopy_wave(ipcImpl->ipc_bases[local_pe] + L_offset,
132+
reinterpret_cast<char *>(dest) - ipcImpl_.ipc_bases[my_pe];
133+
ipcImpl_.ipcCopy_wave(ipcImpl_.ipc_bases[pe] + L_offset,
143134
const_cast<void *>(source), nelems);
144135
}
145136

146137
__device__ void IPCContext::getmem_wave(void *dest, const void *source,
147138
size_t nelems, int pe) {
148-
// TODO (Avinash) check if PE is available for IPC using (isIpcAvailable)
149-
int local_pe = pe % ipcImpl->shm_size;
150139
const char *src_typed = reinterpret_cast<const char *>(source);
151140
uint64_t L_offset =
152-
const_cast<char *>(src_typed) - ipcImpl->ipc_bases[my_pe];
153-
ipcImpl->ipcCopy_wave(dest, ipcImpl->ipc_bases[local_pe] + L_offset,
141+
const_cast<char *>(src_typed) - ipcImpl_.ipc_bases[my_pe];
142+
ipcImpl_.ipcCopy_wave(dest, ipcImpl_.ipc_bases[pe] + L_offset,
154143
nelems);
155144
}
156145

projects/rocshmem/src/ipc/context_ipc_tmpl_device.hpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,19 @@ __device__ void IPCContext::get_nbi(T *dest, const T *source, size_t nelems,
7171

7272
// Atomics
7373
template <typename T>
74-
__device__ void IPCContext::amo_add(void *dst, T value, int pe) {
75-
assert(false);
74+
__device__ void IPCContext::amo_add(void *dest, T value, int pe) {
75+
uint64_t L_offset =
76+
reinterpret_cast<char *>(dest) - ipcImpl_.ipc_bases[my_pe];
77+
ipcImpl_.ipcAMOAdd(
78+
reinterpret_cast<T *>(ipcImpl_.ipc_bases[pe] + L_offset), value);
7679
}
7780

7881
template <typename T>
79-
__device__ void IPCContext::amo_set(void *dst, T value, int pe) {
80-
assert(false);
82+
__device__ void IPCContext::amo_set(void *dest, T value, int pe) {
83+
uint64_t L_offset =
84+
reinterpret_cast<char *>(dest) - ipcImpl_.ipc_bases[my_pe];
85+
ipcImpl_.ipcAMOSet(
86+
reinterpret_cast<T *>(ipcImpl_.ipc_bases[pe] + L_offset), value);
8187
}
8288

8389
template <typename T>
@@ -120,20 +126,29 @@ __device__ void IPCContext::amo_xor(void *dst, T value, int pe) {
120126
}
121127

122128
template <typename T>
123-
__device__ void IPCContext::amo_cas(void *dst, T value, T cond, int pe) {
124-
assert(false);
129+
__device__ void IPCContext::amo_cas(void *dest, T value, T cond, int pe) {
130+
uint64_t L_offset =
131+
reinterpret_cast<char *>(dest) - ipcImpl_.ipc_bases[my_pe];
132+
ipcImpl_.ipcAMOCas(
133+
reinterpret_cast<T *>(ipcImpl_.ipc_bases[pe] + L_offset), cond,
134+
value);
125135
}
126136

127137
template <typename T>
128-
__device__ T IPCContext::amo_fetch_add(void *dst, T value, int pe) {
129-
assert(false);
130-
return 0;
138+
__device__ T IPCContext::amo_fetch_add(void *dest, T value, int pe) {
139+
uint64_t L_offset =
140+
reinterpret_cast<char *>(dest) - ipcImpl_.ipc_bases[my_pe];
141+
return ipcImpl_.ipcAMOFetchAdd(
142+
reinterpret_cast<T *>(ipcImpl_.ipc_bases[pe] + L_offset), value);
131143
}
132144

133145
template <typename T>
134-
__device__ T IPCContext::amo_fetch_cas(void *dst, T value, T cond, int pe) {
135-
assert(false);
136-
return 0;
146+
__device__ T IPCContext::amo_fetch_cas(void *dest, T value, T cond, int pe) {
147+
uint64_t L_offset =
148+
reinterpret_cast<char *>(dest) - ipcImpl_.ipc_bases[my_pe];
149+
return ipcImpl_.ipcAMOFetchCas(
150+
reinterpret_cast<T *>(ipcImpl_.ipc_bases[pe] + L_offset), cond,
151+
value);
137152
}
138153

139154
// Collectives

0 commit comments

Comments
 (0)