11/* *
22 * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2025. ALL RIGHTS RESERVED.
3+ * Copyright (C) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED.
34 *
45 * See file LICENSE for terms.
56 */
1011#include " uct_device_types.h"
1112
1213#include < uct/api/uct_def.h>
14+ #if HAVE_ROCM
15+ #include < uct/rocm/ipc/rocm_ipc.h>
16+ #else
1317#include < uct/cuda/cuda_ipc/cuda_ipc.cuh>
18+ #endif
1419#include < ucs/sys/device_code.h>
1520
16- #if __has_include(<uct/ib/mlx5/gdaki/gdaki.cuh>) && __has_include(<infiniband/mlx5dv.h>)
21+ #if defined(__NVCC__) && __has_include(<uct/ib/mlx5/gdaki/gdaki.cuh>) && __has_include(<infiniband/mlx5dv.h>)
1722# include < uct/ib/mlx5/gdaki/gdaki.cuh>
1823# define UCT_RC_MLX5_GDA_SUPPORTED 1
1924#else
@@ -24,7 +29,11 @@ union uct_device_completion {
2429#if UCT_RC_MLX5_GDA_SUPPORTED
2530 uct_rc_gda_completion_t rc_gda;
2631#endif
32+ #if HAVE_ROCM
33+ uct_rocm_ipc_completion_t rocm_ipc;
34+ #else
2735 uct_cuda_ipc_completion_t cuda_ipc;
36+ #endif
2837};
2938
3039
@@ -73,12 +82,20 @@ uct_device_ep_put(uct_device_ep_h device_ep,
7382 channel_id, flags, comp);
7483 } else
7584#endif
85+ #if HAVE_ROCM
86+ if (device_ep->uct_tl_id == UCT_DEVICE_TL_ROCM_IPC) {
87+ return uct_rocm_ipc_ep_put<level>(device_ep, mem_elem, address,
88+ remote_address, length, flags, comp);
89+ } else
90+ #else
7691 if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) {
7792 return uct_cuda_ipc_ep_put<level>(device_ep, mem_elem, address,
7893 remote_address, length, flags, comp);
94+ } else
95+ #endif
96+ {
97+ return UCS_ERR_UNSUPPORTED;
7998 }
80-
81- return UCS_ERR_UNSUPPORTED;
8299}
83100
84101
@@ -122,12 +139,20 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_atomic_add(
122139 channel_id, flags, comp);
123140 } else
124141#endif
142+ #if HAVE_ROCM
143+ if (device_ep->uct_tl_id == UCT_DEVICE_TL_ROCM_IPC) {
144+ return uct_rocm_ipc_ep_atomic_add<level>(device_ep, mem_elem, inc_value,
145+ remote_address, flags, comp);
146+ } else
147+ #else
125148 if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) {
126149 return uct_cuda_ipc_ep_atomic_add<level>(device_ep, mem_elem, inc_value,
127150 remote_address, flags, comp);
151+ } else
152+ #endif
153+ {
154+ return UCS_ERR_UNSUPPORTED;
128155 }
129-
130- return UCS_ERR_UNSUPPORTED;
131156}
132157
133158
@@ -149,14 +174,20 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_get_ptr(
149174 uct_device_ep_h device_ep, const uct_device_mem_element_t *mem_elem,
150175 uint64_t address, void **addr_p)
151176{
152- if (device_ep->uct_tl_id != UCT_DEVICE_TL_CUDA_IPC) {
177+ #if HAVE_ROCM
178+ if (device_ep->uct_tl_id == UCT_DEVICE_TL_ROCM_IPC) {
179+ return uct_rocm_ipc_ep_get_ptr (device_ep, mem_elem, address, addr_p);
180+ } else
181+ #else
182+ if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) {
183+ return uct_cuda_ipc_ep_get_ptr (device_ep, mem_elem, address, addr_p);
184+ } else
185+ #endif
186+ {
153187 return UCS_ERR_UNSUPPORTED;
154188 }
155-
156- return uct_cuda_ipc_ep_get_ptr (device_ep, mem_elem, address, addr_p);
157189}
158190
159-
160191/* *
161192 * @ingroup UCT_DEVICE
162193 * @brief Progress all operations on device endpoint @a device_ep.
0 commit comments