diff --git a/config/hip.am b/config/hip.am new file mode 100644 index 00000000000..6fa8eda168a --- /dev/null +++ b/config/hip.am @@ -0,0 +1,39 @@ +# +# Copyright (c) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED. +# See file LICENSE for terms. +# + +SUFFIXES = .hip + +HIPCC ?= hipcc + +HIPCC_CMD = $(HIPCC) -DHAVE_CONFIG_H -DUCT_DEVICE_CODE_HIP -fPIE -I$(top_srcdir)/src -I$(top_builddir)/src $(BASE_CXXFLAGS) $(CXXFLAGS) $(HIP_CPPFLAGS) $(HIP_CXXFLAGS) $(HIPCC_EXTRA_FLAGS) -Wno-c++20-extensions -c $< -MT $@ -MF $(DEPDIR)/hip/$@.d -MMD -o $@ +HIPCC_LT_CMD = $(LIBTOOL) --tag=CXX --mode=compile $(HIPCC_CMD) + +define hipcc-build + @$(MKDIR_P) $(shell dirname $(DEPDIR)/hip/$@) + @$(if $(filter false,$(AM_V_P)),echo " HIPCC $@") + @$(if $(filter .o,$(suffix $@)),$(HIPCC_CMD),$(HIPCC_LT_CMD)) $($(1)) $(if $(filter false,$(AM_V_P)), >/dev/null) +endef + +define hipcc-source + EXTRA_DIST += $(2) +$(1): $(2) + $$(call hipcc-build,$(3)) +endef + +# Default rules when no target-specific compile flags are required +.hip.o: + $(call hipcc-build) + +.hip.lo: + $(call hipcc-build) + +HIP_DEP_FILES := $(shell find $(DEPDIR)/hip/ -type f -name *.d 2>/dev/null) +-include $(HIP_DEP_FILES) + +clean-local: + -rm -rf $(DEPDIR)/hip + +distclean-local: + -rm -rf $(DEPDIR)/hip diff --git a/src/ucp/core/ucp_device.c b/src/ucp/core/ucp_device.c index 0fedc394803..88c43f02738 100644 --- a/src/ucp/core/ucp_device.c +++ b/src/ucp/core/ucp_device.c @@ -1,5 +1,6 @@ /** * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2025. ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED. * * See file LICENSE for terms. 
*/ @@ -434,7 +435,11 @@ ucs_status_t ucp_device_local_mem_list_create(const ucp_device_mem_list_params_t *params, ucp_device_local_mem_list_h *mem_list_h) { +#if HAVE_ROCM + const ucs_memory_type_t export_mem_type = UCS_MEMORY_TYPE_ROCM; +#else const ucs_memory_type_t export_mem_type = UCS_MEMORY_TYPE_CUDA; +#endif ucs_status_t status; uct_allocated_memory_t mem; ucs_sys_device_t local_sys_dev; @@ -682,7 +687,11 @@ ucs_status_t ucp_device_remote_mem_list_create(const ucp_device_mem_list_params_t *params, ucp_device_remote_mem_list_h *mem_list_h) { +#if HAVE_ROCM + const ucs_memory_type_t export_mem_type = UCS_MEMORY_TYPE_ROCM; +#else const ucs_memory_type_t export_mem_type = UCS_MEMORY_TYPE_CUDA; +#endif ucs_status_t status; uct_allocated_memory_t mem; diff --git a/src/ucp/wireup/select.c b/src/ucp/wireup/select.c index 1111ff8317b..f4e02a8918f 100644 --- a/src/ucp/wireup/select.c +++ b/src/ucp/wireup/select.c @@ -1,6 +1,7 @@ /** * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2001-2016. ALL RIGHTS RESERVED. * Copyright (C) Los Alamos National Security, LLC. 2019 ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED. * * See file LICENSE for terms. 
*/ @@ -2504,6 +2505,14 @@ ucp_wireup_add_device_lanes(const ucp_wireup_select_params_t *select_params, found_lane = ucp_wireup_add_bw_lanes(select_params, &bw_info, mem_type_tl_bitmap, UCP_NULL_LANE, select_ctx, 0); + + /* Add device lanes for ROCm memory */ + ucp_wireup_memaccess_bitmap(context, UCS_MEMORY_TYPE_ROCM, + &mem_type_tl_bitmap); + found_lane |= ucp_wireup_add_bw_lanes(select_params, &bw_info, + mem_type_tl_bitmap, UCP_NULL_LANE, + select_ctx, 0); + if (!found_lane) { ucs_debug("ep %p: could not find device lanes", select_params->ep); } diff --git a/src/ucs/sys/device_code.h b/src/ucs/sys/device_code.h index 75622d0cee8..0bd905c8caa 100644 --- a/src/ucs/sys/device_code.h +++ b/src/ucs/sys/device_code.h @@ -1,5 +1,6 @@ /** * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2025. ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED. * * See file LICENSE for terms. */ @@ -13,11 +14,11 @@ /* * Declare GPU specific functions */ -#ifdef __NVCC__ +#if defined(__NVCC__) || defined(__HIPCC__) #define UCS_F_DEVICE __device__ __forceinline__ static #else #define UCS_F_DEVICE static inline -#endif /* __NVCC__ */ +#endif /* __NVCC__ || __HIPCC__ */ #ifndef UCP_DEVICE_ENABLE_PARAMS_CHECK diff --git a/src/uct/api/device/uct_device_impl.h b/src/uct/api/device/uct_device_impl.h index ff22502c490..002e396d430 100644 --- a/src/uct/api/device/uct_device_impl.h +++ b/src/uct/api/device/uct_device_impl.h @@ -1,5 +1,6 @@ /** * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2025. ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED. * * See file LICENSE for terms. 
*/ @@ -10,10 +11,14 @@ #include "uct_device_types.h" #include +#if HAVE_ROCM +#include +#else #include +#endif #include -#if __has_include() && __has_include() +#if defined(__NVCC__) && __has_include() && __has_include() # include # define UCT_RC_MLX5_GDA_SUPPORTED 1 #else @@ -24,7 +29,11 @@ union uct_device_completion { #if UCT_RC_MLX5_GDA_SUPPORTED uct_rc_gda_completion_t rc_gda; #endif +#if HAVE_ROCM + uct_rocm_ipc_completion_t rocm_ipc; +#else uct_cuda_ipc_completion_t cuda_ipc; +#endif }; @@ -73,12 +82,20 @@ uct_device_ep_put(uct_device_ep_h device_ep, channel_id, flags, comp); } else #endif +#if HAVE_ROCM + if (device_ep->uct_tl_id == UCT_DEVICE_TL_ROCM_IPC) { + return uct_rocm_ipc_ep_put(device_ep, mem_elem, address, + remote_address, length, flags, comp); + } else +#else if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) { return uct_cuda_ipc_ep_put(device_ep, mem_elem, address, remote_address, length, flags, comp); + } else +#endif + { + return UCS_ERR_UNSUPPORTED; } - - return UCS_ERR_UNSUPPORTED; } @@ -122,12 +139,20 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_atomic_add( channel_id, flags, comp); } else #endif +#if HAVE_ROCM + if (device_ep->uct_tl_id == UCT_DEVICE_TL_ROCM_IPC) { + return uct_rocm_ipc_ep_atomic_add(device_ep, mem_elem, inc_value, + remote_address, flags, comp); + } else +#else if (device_ep->uct_tl_id == UCT_DEVICE_TL_CUDA_IPC) { return uct_cuda_ipc_ep_atomic_add(device_ep, mem_elem, inc_value, remote_address, flags, comp); + } else +#endif + { + return UCS_ERR_UNSUPPORTED; } - - return UCS_ERR_UNSUPPORTED; } @@ -149,14 +174,20 @@ UCS_F_DEVICE ucs_status_t uct_device_ep_get_ptr( uct_device_ep_h device_ep, const uct_device_mem_element_t *mem_elem, uint64_t address, void **addr_p) { - if (device_ep->uct_tl_id != UCT_DEVICE_TL_CUDA_IPC) { +#if HAVE_ROCM + if (device_ep->uct_tl_id == UCT_DEVICE_TL_ROCM_IPC) { + return uct_rocm_ipc_ep_get_ptr(device_ep, mem_elem, address, addr_p); + } else +#else + if (device_ep->uct_tl_id == 
UCT_DEVICE_TL_CUDA_IPC) { + return uct_cuda_ipc_ep_get_ptr(device_ep, mem_elem, address, addr_p); + } else +#endif + { return UCS_ERR_UNSUPPORTED; } - - return uct_cuda_ipc_ep_get_ptr(device_ep, mem_elem, address, addr_p); } - /** * @ingroup UCT_DEVICE * @brief Progress all operations on device endpoint @a device_ep. diff --git a/src/uct/api/device/uct_device_types.h b/src/uct/api/device/uct_device_types.h index c2104bd07fc..9bd2f587f96 100644 --- a/src/uct/api/device/uct_device_types.h +++ b/src/uct/api/device/uct_device_types.h @@ -1,6 +1,7 @@ /** * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2025. ALL RIGHTS RESERVED. - * + * Copyright (C) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED. + * * See file LICENSE for terms. */ @@ -36,6 +37,21 @@ typedef struct { } uct_cuda_ipc_completion_t; +/** + * @brief Device memory element for ROCm IPC. + */ +typedef struct { + ptrdiff_t mapped_offset; +} uct_rocm_ipc_device_mem_element_t; + + +/** + * @brief Completion object for device ROCm IPC operations. + */ +typedef struct { +} uct_rocm_ipc_completion_t; + + /** * @brief Device memory element for GDAKI. */ @@ -57,6 +73,7 @@ typedef enum { typedef enum { UCT_DEVICE_TL_RC_MLX5_GDA, UCT_DEVICE_TL_CUDA_IPC, + UCT_DEVICE_TL_ROCM_IPC, UCT_DEVICE_TL_LAST } uct_device_tl_id_t; @@ -75,6 +92,7 @@ typedef union uct_device_completion uct_device_completion_t; union uct_device_mem_element { uct_ib_md_device_mem_element_t ib_md_mem_element; uct_cuda_ipc_md_device_mem_element_t cuda_ipc_md_mem_element; + uct_rocm_ipc_device_mem_element_t rocm_ipc_mem_element; }; diff --git a/src/uct/rocm/Makefile.am b/src/uct/rocm/Makefile.am index 310ca2e83be..e2e83580180 100644 --- a/src/uct/rocm/Makefile.am +++ b/src/uct/rocm/Makefile.am @@ -1,5 +1,6 @@ # # Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2001-2018. ALL RIGHTS RESERVED. +# Copyright (C) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED. # See file LICENSE for terms. 
# @@ -14,6 +15,10 @@ libuct_rocm_la_LDFLAGS = $(ROCM_LDFLAGS) $(ROCM_LIBS) -version-info $(SOVERSION $(patsubst %, -Xlinker %, -L$(ROCM_ROOT)/lib -rpath $(ROCM_ROOT)/hip/lib -rpath $(ROCM_ROOT)/lib) \ $(patsubst %, -Xlinker %, --enable-new-dtags) \ $(patsubst %, -Xlinker %, -rpath $(ROCM_ROOT)/lib64) +libuct_rocm_ladir = $(includedir)/uct/rocm + +nobase_dist_libuct_rocm_la_HEADERS = \ + ipc/rocm_ipc.h noinst_HEADERS = \ base/rocm_base.h \ diff --git a/src/uct/rocm/copy/rocm_copy_md.c b/src/uct/rocm/copy/rocm_copy_md.c index 3cdab6491ea..9954b6ca257 100644 --- a/src/uct/rocm/copy/rocm_copy_md.c +++ b/src/uct/rocm/copy/rocm_copy_md.c @@ -1,5 +1,5 @@ /* - * Copyright (C) Advanced Micro Devices, Inc. 2019-2023. ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2019-2026. ALL RIGHTS RESERVED. * See file LICENSE for terms. */ @@ -57,7 +57,8 @@ uct_rocm_copy_md_query(uct_md_h uct_md, uct_md_attr_v2_t *md_attr) md_attr->cache_mem_types = UCS_BIT(UCS_MEMORY_TYPE_HOST) | UCS_BIT(UCS_MEMORY_TYPE_ROCM); md_attr->alloc_mem_types = UCS_BIT(UCS_MEMORY_TYPE_ROCM); - md_attr->access_mem_types = UCS_BIT(UCS_MEMORY_TYPE_ROCM); + md_attr->access_mem_types = UCS_BIT(UCS_MEMORY_TYPE_HOST) | + UCS_BIT(UCS_MEMORY_TYPE_ROCM); md_attr->detect_mem_types = UCS_BIT(UCS_MEMORY_TYPE_ROCM); if (md->have_dmabuf) { md_attr->dmabuf_mem_types |= UCS_BIT(UCS_MEMORY_TYPE_ROCM); diff --git a/src/uct/rocm/ipc/rocm_ipc.h b/src/uct/rocm/ipc/rocm_ipc.h new file mode 100644 index 00000000000..70d92bb3cdc --- /dev/null +++ b/src/uct/rocm/ipc/rocm_ipc.h @@ -0,0 +1,301 @@ +/** + * Copyright (c) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED. + * See file LICENSE for terms. + */ + +#ifndef UCT_ROCM_IPC_H +#define UCT_ROCM_IPC_H + +#include +#include +#include +#include + +#define UCT_ROCM_IPC_IS_ALIGNED_POW2(_n, _p) (!((_n) & ((_p)-1))) + +/* Dynamically detect wavefront size using compiler builtin. 
*/ +#if __has_builtin(__builtin_amdgcn_wavefrontsize) +#define UCT_ROCM_IPC_WAVEFRONT_SIZE __builtin_amdgcn_wavefrontsize() +#else +#define UCT_ROCM_IPC_WAVEFRONT_SIZE __AMDGCN_WAVEFRONT_SIZE +#endif + +#define UCT_ROCM_IPC_COPY_LOOP_UNROLL 8 + +/* Vectorized load with cache coherency - using direct memory access */ +__device__ static inline int4 uct_rocm_ipc_ld_global_cg(const int4 *p) +{ + return *p; +} + +__device__ static inline void uct_rocm_ipc_st_global_cg(int4 *p, const int4 &v) +{ + *p = v; +} + +__device__ static inline int2 uct_rocm_ipc_ld_global_cg(const int2 *p) +{ + return *p; +} + +__device__ static inline void uct_rocm_ipc_st_global_cg(int2 *p, const int2 &v) +{ + *p = v; +} + +/* Get lane ID and number of lanes based on parallelism level */ +template +__device__ static inline void +uct_rocm_ipc_get_lane(unsigned &lane_id, unsigned &num_lanes) +{ + switch (level) { + case UCS_DEVICE_LEVEL_THREAD: + lane_id = 0; + num_lanes = 1; + break; + case UCS_DEVICE_LEVEL_WARP: + lane_id = threadIdx.x % UCT_ROCM_IPC_WAVEFRONT_SIZE; + num_lanes = UCT_ROCM_IPC_WAVEFRONT_SIZE; + break; + case UCS_DEVICE_LEVEL_BLOCK: + lane_id = threadIdx.x; + num_lanes = blockDim.x; + break; + case UCS_DEVICE_LEVEL_GRID: + lane_id = threadIdx.x + blockIdx.x * blockDim.x; + num_lanes = blockDim.x * gridDim.x; + break; + } +} + +/* Map remote address using IPC handle */ +__device__ static inline void * +uct_rocm_ipc_map_remote(const uct_rocm_ipc_device_mem_element_t *elem, + uint64_t remote_address) +{ + return reinterpret_cast((uintptr_t)remote_address + + elem->mapped_offset); +} + +/* System-wide atomic increment */ +__device__ static inline void +uct_rocm_ipc_atomic_inc(uint64_t *dst, uint64_t inc_value) +{ + atomicAdd_system((unsigned long long*)dst, (unsigned long long)inc_value); + __threadfence_system(); +} + +/* Level-appropriate synchronization */ +template +__device__ static inline void uct_rocm_ipc_level_sync() +{ + switch (level) { + case UCS_DEVICE_LEVEL_THREAD: + 
break; + case UCS_DEVICE_LEVEL_WARP: + case UCS_DEVICE_LEVEL_BLOCK: + __syncthreads(); + break; + case UCS_DEVICE_LEVEL_GRID: + /* Not implemented */ + break; + } +} + +/* Copy routines for different parallelism levels */ +template +__device__ void uct_rocm_ipc_copy_level(void *dst, const void *src, size_t len); + +/* Thread-level copy */ +template<> +__device__ inline void +uct_rocm_ipc_copy_level(void *dst, const void *src, + size_t len) +{ + memcpy(dst, src, len); +} + +/* Wavefront-level copy (64 threads) */ +template<> +__device__ inline void +uct_rocm_ipc_copy_level(void *dst, const void *src, + size_t len) +{ + using vec4 = int4; + using vec2 = int2; + unsigned int lane_id, num_lanes; + + uct_rocm_ipc_get_lane(lane_id, num_lanes); + auto s1 = reinterpret_cast(src); + auto d1 = reinterpret_cast(dst); + + /* 16B-aligned fast path using vec4 */ + if (UCT_ROCM_IPC_IS_ALIGNED_POW2((intptr_t)s1, sizeof(vec4)) && + UCT_ROCM_IPC_IS_ALIGNED_POW2((intptr_t)d1, sizeof(vec4))) { + const vec4 *s4 = reinterpret_cast(s1); + vec4 *d4 = reinterpret_cast(d1); + size_t n4 = len / sizeof(vec4); + + for (size_t i = lane_id; i < n4; i += num_lanes) { + vec4 v = uct_rocm_ipc_ld_global_cg(s4 + i); + uct_rocm_ipc_st_global_cg(d4 + i, v); + } + + len = len - n4 * sizeof(vec4); + if (len == 0) { + return; + } + + s1 = reinterpret_cast(s4 + n4); + d1 = reinterpret_cast(d4 + n4); + } + + /* 8B-aligned fast path using vec2 */ + if (UCT_ROCM_IPC_IS_ALIGNED_POW2((intptr_t)s1, sizeof(vec2)) && + UCT_ROCM_IPC_IS_ALIGNED_POW2((intptr_t)d1, sizeof(vec2))) { + const vec2 *s2 = reinterpret_cast(s1); + vec2 *d2 = reinterpret_cast(d1); + size_t n2 = len / sizeof(vec2); + + for (size_t i = lane_id; i < n2; i += num_lanes) { + vec2 v2 = uct_rocm_ipc_ld_global_cg(s2 + i); + uct_rocm_ipc_st_global_cg(d2 + i, v2); + } + + len = len - n2 * sizeof(vec2); + if (len == 0) { + return; + } + + s1 = reinterpret_cast(s2 + n2); + d1 = reinterpret_cast(d2 + n2); + } + + /* Byte tail */ + for (size_t i = lane_id; 
i < len; i += num_lanes) { + d1[i] = s1[i]; + } +} + +template<> +__device__ inline void +uct_rocm_ipc_copy_level(void *dst, const void *src, + size_t len) +{ + using vec4 = int4; + using vec2 = int2; + auto s1 = reinterpret_cast(src); + auto d1 = reinterpret_cast(dst); + + if (UCT_ROCM_IPC_IS_ALIGNED_POW2((intptr_t)s1, sizeof(vec4)) && + UCT_ROCM_IPC_IS_ALIGNED_POW2((intptr_t)d1, sizeof(vec4))) { + const vec4 *s4 = reinterpret_cast(s1); + vec4 *d4 = reinterpret_cast(d1); + size_t num_lines = len / sizeof(vec4); + + for (size_t line = threadIdx.x; line < num_lines; line += blockDim.x) { + vec4 v = uct_rocm_ipc_ld_global_cg(s4 + line); + uct_rocm_ipc_st_global_cg(d4 + line, v); + } + + len = len - num_lines * sizeof(vec4); + if (len == 0) { + return; + } + + s1 = reinterpret_cast(s4 + num_lines); + d1 = reinterpret_cast(d4 + num_lines); + } + + /* 8B-aligned fast path using vec2 */ + if (UCT_ROCM_IPC_IS_ALIGNED_POW2((intptr_t)s1, sizeof(vec2)) && + UCT_ROCM_IPC_IS_ALIGNED_POW2((intptr_t)d1, sizeof(vec2))) { + const vec2 *s2 = reinterpret_cast(s1); + vec2 *d2 = reinterpret_cast(d1); + size_t num_lines = len / sizeof(vec2); + + for (size_t line = threadIdx.x; line < num_lines; line += blockDim.x) { + vec2 v2 = uct_rocm_ipc_ld_global_cg(s2 + line); + uct_rocm_ipc_st_global_cg(d2 + line, v2); + } + + len = len - num_lines * sizeof(vec2); + if (len == 0) { + return; + } + + s1 = reinterpret_cast(s2 + num_lines); + d1 = reinterpret_cast(d2 + num_lines); + } + + /* Byte tail */ + for (size_t line = threadIdx.x; line < len; line += blockDim.x) { + d1[line] = s1[line]; + } +} + +/* Grid-level copy - not implemented */ +template<> +__device__ inline void +uct_rocm_ipc_copy_level(void *dst, const void *src, + size_t len) +{ + /* Not implemented */ +} + +template +__device__ ucs_status_t uct_rocm_ipc_ep_put( + uct_device_ep_h device_ep, const uct_device_mem_element_t *mem_elem, + const void *address, uint64_t remote_address, size_t length, + uint64_t flags, 
uct_device_completion_t *comp) +{ + auto rocm_ipc_mem_element = + reinterpret_cast( + mem_elem); + void *mapped_rem_addr; + + mapped_rem_addr = uct_rocm_ipc_map_remote(rocm_ipc_mem_element, + remote_address); + uct_rocm_ipc_copy_level(mapped_rem_addr, address, length); + uct_rocm_ipc_level_sync(); + + return UCS_OK; +} + +/* Atomic add operation */ +template +__device__ ucs_status_t uct_rocm_ipc_ep_atomic_add( + uct_device_ep_h device_ep, const uct_device_mem_element_t *mem_elem, + uint64_t inc_value, uint64_t remote_address, uint64_t flags, + uct_device_completion_t *comp) +{ + auto rocm_ipc_mem_element = + reinterpret_cast( + mem_elem); + uint64_t *mapped_rem_addr; + unsigned int lane_id, num_lanes; + + uct_rocm_ipc_get_lane(lane_id, num_lanes); + if (lane_id == 0) { + mapped_rem_addr = reinterpret_cast( + uct_rocm_ipc_map_remote(rocm_ipc_mem_element, remote_address)); + uct_rocm_ipc_atomic_inc(mapped_rem_addr, inc_value); + } + + uct_rocm_ipc_level_sync(); + return UCS_OK; +} + +__device__ static inline ucs_status_t +uct_rocm_ipc_ep_get_ptr(uct_device_ep_h device_ep, + const uct_device_mem_element_t *mem_elem, + uint64_t address, void **addr_p) +{ + auto rocm_ipc_mem_element = + reinterpret_cast( + mem_elem); + *addr_p = uct_rocm_ipc_map_remote(rocm_ipc_mem_element, address); + return UCS_OK; +} + +#endif /* UCT_ROCM_IPC_H */ diff --git a/src/uct/rocm/ipc/rocm_ipc_cache.c b/src/uct/rocm/ipc/rocm_ipc_cache.c index 93ccb2e2d84..77361adcdad 100644 --- a/src/uct/rocm/ipc/rocm_ipc_cache.c +++ b/src/uct/rocm/ipc/rocm_ipc_cache.c @@ -233,6 +233,29 @@ ucs_status_t uct_rocm_ipc_create_cache(uct_rocm_ipc_cache_t **cache, return status; } +ucs_status_t uct_rocm_ipc_component_init_cache(void) +{ + ucs_status_t status; + + pthread_mutex_lock(&uct_rocm_ipc_component.lock); + + if (uct_rocm_ipc_component.ipc_cache == NULL) { + status = uct_rocm_ipc_create_cache(&uct_rocm_ipc_component.ipc_cache, + "rocm_ipc_component"); + if (status != UCS_OK) { + ucs_error("Failed to create 
ROCm IPC component cache: %s", + ucs_status_string(status)); + pthread_mutex_unlock(&uct_rocm_ipc_component.lock); + return status; + } + + ucs_debug("ROCm IPC component cache initialized"); + } + + pthread_mutex_unlock(&uct_rocm_ipc_component.lock); + return UCS_OK; +} + void uct_rocm_ipc_destroy_cache(uct_rocm_ipc_cache_t *cache) { uct_rocm_ipc_cache_purge(cache); diff --git a/src/uct/rocm/ipc/rocm_ipc_cache.h b/src/uct/rocm/ipc/rocm_ipc_cache.h index 7982d743c8e..6852828bea9 100644 --- a/src/uct/rocm/ipc/rocm_ipc_cache.h +++ b/src/uct/rocm/ipc/rocm_ipc_cache.h @@ -29,6 +29,8 @@ typedef struct uct_rocm_ipc_cache { ucs_status_t uct_rocm_ipc_create_cache(uct_rocm_ipc_cache_t **cache, const char *name); +ucs_status_t uct_rocm_ipc_component_init_cache(void); + void uct_rocm_ipc_destroy_cache(uct_rocm_ipc_cache_t *cache); ucs_status_t uct_rocm_ipc_cache_map_memhandle(void *arg, uct_rocm_ipc_key_t *key, diff --git a/src/uct/rocm/ipc/rocm_ipc_ep.c b/src/uct/rocm/ipc/rocm_ipc_ep.c index d1dbafb5158..797aa669018 100644 --- a/src/uct/rocm/ipc/rocm_ipc_ep.c +++ b/src/uct/rocm/ipc/rocm_ipc_ep.c @@ -1,5 +1,5 @@ /* - * Copyright (C) Advanced Micro Devices, Inc. 2019-2023. ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2019-2026. ALL RIGHTS RESERVED. * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2020. ALL RIGHTS RESERVED. * See file LICENSE for terms. 
*/ @@ -10,7 +10,6 @@ #include "rocm_ipc_ep.h" #include "rocm_ipc_iface.h" -#include "rocm_ipc_md.h" #include #include @@ -20,27 +19,20 @@ static UCS_CLASS_INIT_FUNC(uct_rocm_ipc_ep_t, const uct_ep_params_t *params) { uct_rocm_ipc_iface_t *iface = ucs_derived_of(params->iface, uct_rocm_ipc_iface_t); - char target_name[64]; - ucs_status_t status; UCS_CLASS_CALL_SUPER_INIT(uct_base_ep_t, &iface->super); self->remote_pid = *(const pid_t*)params->iface_addr; - - snprintf(target_name, sizeof(target_name), "dest:%d", *(pid_t*)params->iface_addr); - status = uct_rocm_ipc_create_cache(&self->remote_memh_cache, target_name); - if (status != UCS_OK) { - ucs_error("could not create create rocm ipc cache: %s", - ucs_status_string(status)); - return status; - } + self->device_ep = NULL; return UCS_OK; } static UCS_CLASS_CLEANUP_FUNC(uct_rocm_ipc_ep_t) { - uct_rocm_ipc_destroy_cache(self->remote_memh_cache); + if (self->device_ep != NULL) { + hsa_amd_memory_pool_free(self->device_ep); + } } UCS_CLASS_DEFINE(uct_rocm_ipc_ep_t, uct_base_ep_t); @@ -58,7 +50,6 @@ ucs_status_t uct_rocm_ipc_ep_zcopy(uct_ep_h tl_ep, uct_completion_t *comp, int is_put) { - uct_rocm_ipc_ep_t *ep = ucs_derived_of(tl_ep, uct_rocm_ipc_ep_t); hsa_status_t status; hsa_agent_t local_agent, remote_agent; hsa_agent_t dst_agent, src_agent; @@ -95,7 +86,13 @@ ucs_status_t uct_rocm_ipc_ep_zcopy(uct_ep_h tl_ep, } if (iface->config.enable_ipc_handle_cache) { - ret = uct_rocm_ipc_cache_map_memhandle((void*)ep->remote_memh_cache, + /* Ensure component cache is initialized */ + ret = uct_rocm_ipc_component_init_cache(); + if (ucs_unlikely(ret != UCS_OK)) { + ucs_error("failed to initialize rocm_ipc component cache\n"); + return ret; + } + ret = uct_rocm_ipc_cache_map_memhandle((void*)uct_rocm_ipc_component.ipc_cache, key, &remote_base_addr); if (ucs_unlikely(ret != UCS_OK)) { ucs_error("fail to attach ipc mem %p %d\n", (void*)key->address, @@ -221,3 +218,52 @@ ucs_status_t uct_rocm_ipc_ep_get_zcopy(uct_ep_h tl_ep, const 
uct_iov_t *iov, siz return ret; } + +ucs_status_t +uct_rocm_ipc_ep_get_device_ep(uct_ep_h tl_ep, uct_device_ep_h *device_ep_p) +{ + uct_rocm_ipc_ep_t *ep = ucs_derived_of(tl_ep, uct_rocm_ipc_ep_t); + uct_device_ep_t device_ep; + ucs_status_t status; + hsa_status_t hsa_status; + hsa_amd_memory_pool_t pool; + + if (ep->device_ep != NULL) { + goto out; + } + device_ep.uct_tl_id = UCT_DEVICE_TL_ROCM_IPC; + + /* Get memory pool for device allocation */ + status = uct_rocm_base_get_last_device_pool(&pool); + if (status != UCS_OK) { + ucs_error("Failed to get ROCm memory pool"); + goto err; + } + /* Allocate device memory for endpoint structure */ + hsa_status = hsa_amd_memory_pool_allocate(pool, sizeof(uct_device_ep_t), 0, + (void**)&ep->device_ep); + + if (hsa_status != HSA_STATUS_SUCCESS) { + ucs_error("Failed to allocate device endpoint memory: %d", hsa_status); + status = UCS_ERR_NO_MEMORY; + goto err; + } + /* Copy endpoint structure to device */ + hsa_status = hsa_memory_copy(ep->device_ep, &device_ep, + sizeof(uct_device_ep_t)); + if (hsa_status != HSA_STATUS_SUCCESS) { + ucs_error("Failed to copy endpoint to device: %d", hsa_status); + status = UCS_ERR_IO_ERROR; + goto err_free_mem; + } + +out: + *device_ep_p = ep->device_ep; + return UCS_OK; + +err_free_mem: + hsa_amd_memory_pool_free(ep->device_ep); + ep->device_ep = NULL; +err: + return status; +} diff --git a/src/uct/rocm/ipc/rocm_ipc_ep.h b/src/uct/rocm/ipc/rocm_ipc_ep.h index 0accb543121..295c71a2540 100644 --- a/src/uct/rocm/ipc/rocm_ipc_ep.h +++ b/src/uct/rocm/ipc/rocm_ipc_ep.h @@ -1,5 +1,5 @@ /* - * Copyright (C) Advanced Micro Devices, Inc. 2019. ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2019-2026. ALL RIGHTS RESERVED. * See file LICENSE for terms. 
*/ @@ -9,13 +9,14 @@ #include #include #include +#include #include "rocm_ipc_cache.h" typedef struct uct_rocm_ipc_ep { uct_base_ep_t super; pid_t remote_pid; - uct_rocm_ipc_cache_t *remote_memh_cache; + uct_device_ep_h device_ep; } uct_rocm_ipc_ep_t; UCS_CLASS_DECLARE_NEW_FUNC(uct_rocm_ipc_ep_t, uct_ep_t, const uct_ep_params_t *); @@ -27,5 +28,7 @@ ucs_status_t uct_rocm_ipc_ep_put_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, siz ucs_status_t uct_rocm_ipc_ep_get_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, size_t iovcnt, uint64_t remote_addr, uct_rkey_t rkey, uct_completion_t *comp); +ucs_status_t +uct_rocm_ipc_ep_get_device_ep(uct_ep_h tl_ep, uct_device_ep_h *device_ep_p); #endif diff --git a/src/uct/rocm/ipc/rocm_ipc_iface.c b/src/uct/rocm/ipc/rocm_ipc_iface.c index ddd68f0d712..df0df626762 100644 --- a/src/uct/rocm/ipc/rocm_ipc_iface.c +++ b/src/uct/rocm/ipc/rocm_ipc_iface.c @@ -1,5 +1,5 @@ /* - * Copyright (C) Advanced Micro Devices, Inc. 2019-2023. ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2019-2026. ALL RIGHTS RESERVED. * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2020. ALL RIGHTS RESERVED. * See file LICENSE for terms. 
*/ @@ -28,7 +28,8 @@ static ucs_config_field_t uct_rocm_ipc_iface_config_table[] = { ucs_offsetof(uct_rocm_ipc_iface_config_t, params.latency), UCS_CONFIG_TYPE_TIME}, - {"CACHE_IPC_HANDLES", "y", "Enable caching IPC handles", + {"CACHE_IPC_HANDLES", "y", + "Enable caching IPC handles (Note: caching is always enabled for device initiated communication)", ucs_offsetof(uct_rocm_ipc_iface_config_t, params.enable_ipc_handle_cache), UCS_CONFIG_TYPE_BOOL}, @@ -121,9 +122,9 @@ static ucs_status_t uct_rocm_ipc_iface_query(uct_iface_h tl_iface, iface_attr->ep_addr_len = 0; iface_attr->max_conn_priv = 0; iface_attr->cap.flags = UCT_IFACE_FLAG_GET_ZCOPY | - UCT_IFACE_FLAG_PUT_ZCOPY | - UCT_IFACE_FLAG_PENDING | - UCT_IFACE_FLAG_CONNECT_TO_IFACE; + UCT_IFACE_FLAG_PUT_ZCOPY | UCT_IFACE_FLAG_PENDING | + UCT_IFACE_FLAG_CONNECT_TO_IFACE | + UCT_IFACE_FLAG_DEVICE_EP; iface_attr->latency = ucs_linear_func_make(iface->config.latency, 0); iface_attr->bandwidth.dedicated = 0; @@ -171,7 +172,7 @@ static uct_iface_internal_ops_t uct_rocm_ipc_iface_internal_ops = { .ep_connect_to_ep_v2 = (uct_ep_connect_to_ep_v2_func_t)ucs_empty_function_return_unsupported, .iface_is_reachable_v2 = uct_rocm_ipc_iface_is_reachable_v2, .ep_is_connected = uct_base_ep_is_connected, - .ep_get_device_ep = (uct_ep_get_device_ep_func_t)ucs_empty_function_return_unsupported + .ep_get_device_ep = uct_rocm_ipc_ep_get_device_ep }; static uct_iface_ops_t uct_rocm_ipc_iface_ops = { @@ -245,6 +246,6 @@ static UCS_CLASS_DEFINE_NEW_FUNC(uct_rocm_ipc_iface_t, uct_iface_t, uct_md_h, const uct_iface_config_t *); static UCS_CLASS_DEFINE_DELETE_FUNC(uct_rocm_ipc_iface_t, uct_iface_t); -UCT_TL_DEFINE(&uct_rocm_ipc_component, rocm_ipc, uct_rocm_base_query_devices, +UCT_TL_DEFINE(&uct_rocm_ipc_component.super, rocm_ipc, uct_rocm_base_query_devices, uct_rocm_ipc_iface_t, "ROCM_IPC_", uct_rocm_ipc_iface_config_table, uct_rocm_ipc_iface_config_t); diff --git a/src/uct/rocm/ipc/rocm_ipc_md.c b/src/uct/rocm/ipc/rocm_ipc_md.c index 
c9c3f87b045..4b00ad98590 100644 --- a/src/uct/rocm/ipc/rocm_ipc_md.c +++ b/src/uct/rocm/ipc/rocm_ipc_md.c @@ -1,5 +1,5 @@ /* - * Copyright (C) Advanced Micro Devices, Inc. 2019. ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2019-2026. ALL RIGHTS RESERVED. * Copyright (c) NVIDIA CORPORATION & AFFILIATES, 2020. ALL RIGHTS RESERVED. * See file LICENSE for terms. */ @@ -9,9 +9,11 @@ #endif #include "rocm_ipc_md.h" +#include "rocm_ipc_cache.h" #include #include +#include static ucs_config_field_t uct_rocm_ipc_md_config_table[] = { @@ -26,7 +28,8 @@ static ucs_status_t uct_rocm_ipc_md_query(uct_md_h md, uct_md_attr_v2_t *md_attr { uct_md_base_md_query(md_attr); md_attr->rkey_packed_size = sizeof(uct_rocm_ipc_key_t); - md_attr->flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_NEED_RKEY; + md_attr->flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_NEED_RKEY | + UCT_MD_FLAG_MEMTYPE_COPY; md_attr->reg_mem_types = UCS_BIT(UCS_MEMORY_TYPE_ROCM); md_attr->cache_mem_types = UCS_BIT(UCS_MEMORY_TYPE_ROCM); md_attr->access_mem_types = UCS_BIT(UCS_MEMORY_TYPE_ROCM); @@ -117,6 +120,44 @@ uct_rocm_ipc_mem_dereg(uct_md_h md, return UCS_OK; } +static ucs_status_t +uct_rocm_ipc_md_mem_elem_pack(uct_md_h md_h, uct_mem_h memh, uct_rkey_t rkey, + uct_device_mem_element_t *mem_elem_p) +{ + uct_md_t *md = (uct_md_t*)md_h; + uct_rocm_ipc_component_t *rocm_comp = + ucs_derived_of(md->component, uct_rocm_ipc_component_t); + uct_rocm_ipc_key_t *key = (uct_rocm_ipc_key_t*)rkey; + uct_rocm_ipc_device_mem_element_t *rocm_ipc_mem_element = + (uct_rocm_ipc_device_mem_element_t*)mem_elem_p; + void *mapped_addr; + ucs_status_t status; + + /* Ensure cache is initialized */ + status = uct_rocm_ipc_component_init_cache(); + if (status != UCS_OK) { + return status; + } + + /* Use cache instead of direct attach */ + status = uct_rocm_ipc_cache_map_memhandle(rocm_comp->ipc_cache, key, + &mapped_addr); + if (status != UCS_OK) { + ucs_error("Failed to map IPC handle: %s", ucs_status_string(status)); + return 
status; + } + + /* Calculate offset: base_address - mapped_address */ + rocm_ipc_mem_element->mapped_offset = UCS_PTR_BYTE_DIFF(key->address, + mapped_addr); + + ucs_trace("rocm_ipc mem_elem_pack: key_addr=%p mapped=%p offset=%ld", + (void*)key->address, mapped_addr, + rocm_ipc_mem_element->mapped_offset); + + return UCS_OK; +} + static ucs_status_t uct_rocm_ipc_md_open(uct_component_h component, const char *md_name, const uct_md_config_t *uct_md_config, uct_md_h *md_p) @@ -124,20 +165,26 @@ uct_rocm_ipc_md_open(uct_component_h component, const char *md_name, static uct_md_ops_t md_ops = { .close = (uct_md_close_func_t)ucs_empty_function, .query = uct_rocm_ipc_md_query, - .mem_alloc = (uct_md_mem_alloc_func_t)ucs_empty_function_return_unsupported, - .mem_free = (uct_md_mem_free_func_t)ucs_empty_function_return_unsupported, - .mem_advise = (uct_md_mem_advise_func_t)ucs_empty_function_return_unsupported, + .mem_alloc = (uct_md_mem_alloc_func_t) + ucs_empty_function_return_unsupported, + .mem_free = (uct_md_mem_free_func_t) + ucs_empty_function_return_unsupported, + .mem_advise = (uct_md_mem_advise_func_t) + ucs_empty_function_return_unsupported, .mem_reg = uct_rocm_ipc_mem_reg, .mem_dereg = uct_rocm_ipc_mem_dereg, - .mem_query = (uct_md_mem_query_func_t)ucs_empty_function_return_unsupported, + .mem_query = (uct_md_mem_query_func_t) + ucs_empty_function_return_unsupported, .mkey_pack = uct_rocm_ipc_mkey_pack, - .mem_attach = (uct_md_mem_attach_func_t)ucs_empty_function_return_unsupported, - .detect_memory_type = (uct_md_detect_memory_type_func_t)ucs_empty_function_return_unsupported, - .mem_elem_pack = (uct_md_mem_elem_pack_func_t)ucs_empty_function_return_unsupported + .mem_elem_pack = uct_rocm_ipc_md_mem_elem_pack, + .mem_attach = (uct_md_mem_attach_func_t) + ucs_empty_function_return_unsupported, + .detect_memory_type = (uct_md_detect_memory_type_func_t) + ucs_empty_function_return_unsupported }; static uct_md_t md = { .ops = &md_ops, - .component = 
&uct_rocm_ipc_component, + .component = &uct_rocm_ipc_component.super, }; *md_p = &md; @@ -173,25 +220,73 @@ static ucs_status_t uct_rocm_ipc_rkey_release(uct_component_t *component, return UCS_OK; } -uct_component_t uct_rocm_ipc_component = { - .query_md_resources = uct_rocm_base_query_md_resources, - .md_open = uct_rocm_ipc_md_open, - .cm_open = (uct_component_cm_open_func_t)ucs_empty_function_return_unsupported, - .rkey_unpack = uct_rocm_ipc_rkey_unpack, - .rkey_ptr = (uct_component_rkey_ptr_func_t)ucs_empty_function_return_unsupported, - .rkey_release = uct_rocm_ipc_rkey_release, - .rkey_compare = uct_base_rkey_compare, - .name = "rocm_ipc", - .md_config = { - .name = "ROCm-IPC memory domain", - .prefix = "ROCM_IPC_MD_", - .table = uct_rocm_ipc_md_config_table, - .size = sizeof(uct_rocm_ipc_md_config_t), +ucs_status_t uct_rocm_ipc_rkey_ptr(uct_component_t *component, uct_rkey_t rkey, + void *handle, uint64_t raddr, void **laddr_p) +{ + uct_rocm_ipc_component_t *rocm_comp = + ucs_derived_of(component, uct_rocm_ipc_component_t); + uct_rocm_ipc_key_t *key = (uct_rocm_ipc_key_t*)rkey; + void *mapped_addr; + ptrdiff_t offset; + ucs_status_t status; + + /* Ensure cache is initialized */ + status = uct_rocm_ipc_component_init_cache(); + if (status != UCS_OK) { + return status; + } + + /* Use cache instead of direct attach */ + status = uct_rocm_ipc_cache_map_memhandle(rocm_comp->ipc_cache, key, + &mapped_addr); + if (status != UCS_OK) { + ucs_error("Failed to map IPC handle: %s", ucs_status_string(status)); + return status; + } + + /* Calculate offset from base address */ + offset = UCS_PTR_BYTE_DIFF(key->address, raddr); + *laddr_p = UCS_PTR_BYTE_OFFSET(mapped_addr, offset); + + ucs_trace("rocm_ipc rkey_ptr: raddr=%p mapped=%p offset=%ld laddr=%p", + (void*)raddr, mapped_addr, offset, *laddr_p); + + return UCS_OK; +} + +uct_rocm_ipc_component_t uct_rocm_ipc_component = { + .super = { + .query_md_resources = uct_rocm_base_query_md_resources, + .md_open = 
uct_rocm_ipc_md_open, + .cm_open = (uct_component_cm_open_func_t) + ucs_empty_function_return_unsupported, + .rkey_unpack = uct_rocm_ipc_rkey_unpack, + .rkey_ptr = uct_rocm_ipc_rkey_ptr, + .rkey_release = uct_rocm_ipc_rkey_release, + .rkey_compare = uct_base_rkey_compare, + .name = "rocm_ipc", + .md_config = + { + .name = "ROCm-IPC memory domain", + .prefix = "ROCM_IPC_MD_", + .table = uct_rocm_ipc_md_config_table, + .size = sizeof(uct_rocm_ipc_md_config_t), + }, + .cm_config = UCS_CONFIG_EMPTY_GLOBAL_LIST_ENTRY, + .tl_list = UCT_COMPONENT_TL_LIST_INITIALIZER(&uct_rocm_ipc_component.super), + .flags = 0, + .md_vfs_init = (uct_component_md_vfs_init_func_t)ucs_empty_function }, - .cm_config = UCS_CONFIG_EMPTY_GLOBAL_LIST_ENTRY, - .tl_list = UCT_COMPONENT_TL_LIST_INITIALIZER(&uct_rocm_ipc_component), - .flags = 0, - .md_vfs_init = (uct_component_md_vfs_init_func_t)ucs_empty_function + .ipc_cache = NULL, + .lock = PTHREAD_MUTEX_INITIALIZER }; -UCT_COMPONENT_REGISTER(&uct_rocm_ipc_component); +UCT_COMPONENT_REGISTER(&uct_rocm_ipc_component.super); +UCS_STATIC_CLEANUP { + if (uct_rocm_ipc_component.ipc_cache != NULL) { + ucs_debug("Destroying ROCm IPC component cache"); + uct_rocm_ipc_destroy_cache(uct_rocm_ipc_component.ipc_cache); + uct_rocm_ipc_component.ipc_cache = NULL; + } + pthread_mutex_destroy(&uct_rocm_ipc_component.lock); +} diff --git a/src/uct/rocm/ipc/rocm_ipc_md.h b/src/uct/rocm/ipc/rocm_ipc_md.h index ebe46985493..54aa89b11da 100644 --- a/src/uct/rocm/ipc/rocm_ipc_md.h +++ b/src/uct/rocm/ipc/rocm_ipc_md.h @@ -1,5 +1,5 @@ /* - * Copyright (C) Advanced Micro Devices, Inc. 2019. ALL RIGHTS RESERVED. + * Copyright (C) Advanced Micro Devices, Inc. 2019-2026. ALL RIGHTS RESERVED. * See file LICENSE for terms. 
*/ @@ -8,9 +8,18 @@ #include #include +#include -extern uct_component_t uct_rocm_ipc_component; +typedef struct uct_rocm_ipc_cache uct_rocm_ipc_cache_t; + +typedef struct uct_rocm_ipc_component { + uct_component_t super; + uct_rocm_ipc_cache_t *ipc_cache; + pthread_mutex_t lock; +} uct_rocm_ipc_component_t; + +extern uct_rocm_ipc_component_t uct_rocm_ipc_component; typedef struct uct_rocm_ipc_md { struct uct_md super; diff --git a/test/gtest/Makefile.am b/test/gtest/Makefile.am index f30fc8d9a0b..d5c06786f3f 100644 --- a/test/gtest/Makefile.am +++ b/test/gtest/Makefile.am @@ -299,8 +299,19 @@ endif if HAVE_HIP if HAVE_GNUXX11 +HIPCC_EXTRA_FLAGS = \ + $(HIP_CPPFLAGS) \ + -DHAVE_ROCM=1 \ + -I$(top_builddir) \ + -I$(top_srcdir)/test \ + -I$(top_srcdir)/test/gtest +include $(top_srcdir)/config/hip.am + gtest_SOURCES += \ - ucm/rocm_hooks.cc + ucm/rocm_hooks.cc \ + uct/rocm/test_kernels_uct.hip \ + uct/rocm/test_rocm_ipc_device.hip + gtest_CPPFLAGS += \ $(HIP_CPPFLAGS) gtest_CXXFLAGS += \ @@ -309,8 +320,8 @@ gtest_LDADD += \ $(HIP_LDFLAGS) \ $(HIP_LIBS) \ $(top_builddir)/src/uct/rocm/libuct_rocm.la -endif -endif +endif # HAVE_GNUXX11 +endif # HAVE_HIP noinst_HEADERS = \ common/mem_buffer.h \ diff --git a/test/gtest/uct/rocm/test_kernels_uct.h b/test/gtest/uct/rocm/test_kernels_uct.h new file mode 100644 index 00000000000..15b7f4a8b6a --- /dev/null +++ b/test/gtest/uct/rocm/test_kernels_uct.h @@ -0,0 +1,43 @@ +/** + * Copyright (c) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED. + * + * See file LICENSE for terms. 
+ */
+
+#ifndef ROCM_TEST_KERNELS_H_
+#define ROCM_TEST_KERNELS_H_
+
+#include <uct/api/uct.h>
+#include <uct/api/device/uct_device_types.h>
+#include <ucs/type/status.h>
+#include <hip/hip_runtime.h>
+#include <cstdint>
+
+namespace rocm_uct {
+
+/**
+ * Kernel that calls the generic uct_device_ep_put API
+ */
+template <ucs_device_level_t level>
+__global__ void test_put_kernel(
+    uct_device_ep_h ep,
+    const uct_device_local_mem_list_elem_t *src_elem,
+    const uct_device_mem_element_t *mem_elem,
+    const void *va,
+    uint64_t rva,
+    size_t length,
+    ucs_status_t *status_p);
+
+/**
+ * Host function to launch the PUT kernel
+ */
+ucs_status_t launch_uct_put(uct_device_ep_h device_ep,
+                            const uct_device_local_mem_list_elem_t *src_elem,
+                            const uct_device_mem_element_t *mem_elem,
+                            const void *va, uint64_t rva, size_t length,
+                            ucs_device_level_t level, unsigned num_threads,
+                            unsigned num_blocks);
+
+} // namespace rocm_uct
+
+#endif
diff --git a/test/gtest/uct/rocm/test_kernels_uct.hip b/test/gtest/uct/rocm/test_kernels_uct.hip
new file mode 100644
index 00000000000..0f20706b5e8
--- /dev/null
+++ b/test/gtest/uct/rocm/test_kernels_uct.hip
@@ -0,0 +1,185 @@
+/**
+ * Copyright (c) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include "test_kernels_uct.h"
+
+#include <uct/api/device/uct_device_impl.h>
+#include <hip/hip_runtime.h>
+#include <memory>
+#include <new>
+#include <stdexcept>
+
+#define UCS_DEVICE_LEVEL_EXEC_ID 1
+
+#define UCS_DEVICE_LEVEL_EXEC_SELECT(scope_ok, count, id) \
+    ((scope_ok) ? (((count) > UCS_DEVICE_LEVEL_EXEC_ID) ? \
+                   ((id) == UCS_DEVICE_LEVEL_EXEC_ID) : true) : false)
+
+namespace rocm_uct {
+
+/**
+ * Wrapper class for a host memory result variable, that can be mapped to device
+ * memory and passed to a HIP kernel.
+ */
+template <typename T> class device_result_ptr {
+    public:
+    device_result_ptr() : m_ptr(allocate(), release)
+    {
+    }
+
+    device_result_ptr(const T &value) : m_ptr(allocate(), release)
+    {
+        *m_ptr.get() = value;
+    }
+
+    T &operator*()
+    {
+        return *m_ptr.get();
+    }
+
+    T *device_ptr()
+    {
+        void *device_ptr_void;
+        if (hipHostGetDevicePointer(&device_ptr_void, m_ptr.get(), 0) !=
+            hipSuccess) {
+            throw std::runtime_error("hipHostGetDevicePointer() failure");
+        }
+        return static_cast<T*>(device_ptr_void);
+    }
+
+    private:
+    static T *allocate()
+    {
+        T *ptr = nullptr;
+        if (hipHostMalloc(&ptr, sizeof(T), hipHostMallocMapped) !=
+            hipSuccess) {
+            throw std::bad_alloc();
+        }
+        return ptr;
+    }
+
+    static void release(T *ptr)
+    {
+        hipHostFree(ptr);
+    }
+
+    std::unique_ptr<T, void (*)(T*)> m_ptr;
+};
+
+static void synchronize()
+{
+    if (hipDeviceSynchronize() != hipSuccess) {
+        throw std::runtime_error("hipDeviceSynchronize() failure");
+    }
+}
+
+static __device__ bool is_op_enabled(ucs_device_level_t level)
+{
+    unsigned int thread_id   = threadIdx.x;
+    unsigned int num_threads = blockDim.x;
+    unsigned int warp_id     = thread_id / 64; // ROCm wavefront size
+    unsigned int num_warps   = num_threads / 64;
+    unsigned int block_id    = blockIdx.x;
+    unsigned int num_blocks  = gridDim.x;
+
+    switch (level) {
+    case UCS_DEVICE_LEVEL_THREAD:
+        return UCS_DEVICE_LEVEL_EXEC_SELECT(block_id == 0, num_threads, thread_id);
+    case UCS_DEVICE_LEVEL_WARP:
+        return UCS_DEVICE_LEVEL_EXEC_SELECT(block_id == 0, num_warps, warp_id);
+    case UCS_DEVICE_LEVEL_BLOCK:
+        return UCS_DEVICE_LEVEL_EXEC_SELECT(true, num_blocks, block_id);
+    case UCS_DEVICE_LEVEL_GRID:
+        return true;
+    }
+    return false;
+}
+
+template <ucs_device_level_t level>
+__global__ void
+test_put_kernel(uct_device_ep_h ep,
+                const uct_device_local_mem_list_elem_t *src_elem,
+                const uct_device_mem_element_t *mem_elem, const void *va,
+                uint64_t rva, size_t length, ucs_status_t *status_p)
+{
+    uct_device_completion_t comp;
+
+    if (is_op_enabled(level)) {
+        ucs_status_t status = uct_device_ep_put(ep, src_elem, mem_elem,
+                                                va, rva, length, 0,
+                                                UCT_DEVICE_FLAG_NODELAY,
+                                                &comp);
+        while (status == UCS_INPROGRESS) {
+            uct_device_ep_progress(ep);
+            status = uct_device_ep_check_completion(ep, &comp);
+        }
+        *status_p = status;
+    }
+}
+
+ucs_status_t launch_uct_put(uct_device_ep_h device_ep,
+                            const uct_device_local_mem_list_elem_t *src_elem,
+                            const uct_device_mem_element_t *mem_elem,
+                            const void *va, uint64_t rva, size_t length,
+                            ucs_device_level_t level, unsigned num_threads,
+                            unsigned num_blocks)
+{
+    device_result_ptr<ucs_status_t> status = UCS_ERR_NOT_IMPLEMENTED;
+    hipError_t st;
+
+    switch (level) {
+    case UCS_DEVICE_LEVEL_THREAD:
+        hipLaunchKernelGGL(test_put_kernel<UCS_DEVICE_LEVEL_THREAD>,
+                           dim3(num_blocks), dim3(num_threads), 0, 0,
+                           device_ep, src_elem, mem_elem, va, rva, length,
+                           status.device_ptr());
+        break;
+    case UCS_DEVICE_LEVEL_WARP:
+        hipLaunchKernelGGL(test_put_kernel<UCS_DEVICE_LEVEL_WARP>,
+                           dim3(num_blocks), dim3(num_threads), 0, 0,
+                           device_ep, src_elem, mem_elem, va, rva, length,
+                           status.device_ptr());
+        break;
+    case UCS_DEVICE_LEVEL_BLOCK:
+        hipLaunchKernelGGL(test_put_kernel<UCS_DEVICE_LEVEL_BLOCK>,
+                           dim3(num_blocks), dim3(num_threads), 0, 0,
+                           device_ep, src_elem, mem_elem, va, rva, length,
+                           status.device_ptr());
+        break;
+    case UCS_DEVICE_LEVEL_GRID:
+        hipLaunchKernelGGL(test_put_kernel<UCS_DEVICE_LEVEL_GRID>,
+                           dim3(num_blocks), dim3(num_threads), 0, 0,
+                           device_ep, src_elem, mem_elem, va, rva, length,
+                           status.device_ptr());
+        break;
+    default:
+        throw std::runtime_error("Unsupported level");
+    }
+
+    st = hipGetLastError();
+    if (st != hipSuccess) {
+        throw std::runtime_error(hipGetErrorString(st));
+    }
+
+    synchronize();
+    return *status;
+}
+
+// Explicit template instantiations for all device levels
+template __global__ void test_put_kernel<UCS_DEVICE_LEVEL_THREAD>(
+    uct_device_ep_h, const uct_device_local_mem_list_elem_t*,
+    const uct_device_mem_element_t*, const void*, uint64_t, size_t, ucs_status_t*);
+template __global__ void test_put_kernel<UCS_DEVICE_LEVEL_WARP>(
+    uct_device_ep_h, const uct_device_local_mem_list_elem_t*,
+    const uct_device_mem_element_t*, const void*, uint64_t, size_t, ucs_status_t*);
+template __global__ void test_put_kernel<UCS_DEVICE_LEVEL_BLOCK>(
+    uct_device_ep_h, const uct_device_local_mem_list_elem_t*,
+    const uct_device_mem_element_t*, const void*, uint64_t, size_t, ucs_status_t*);
+template __global__ void test_put_kernel<UCS_DEVICE_LEVEL_GRID>(
+    uct_device_ep_h, const uct_device_local_mem_list_elem_t*,
+    const uct_device_mem_element_t*, const void*, uint64_t, size_t, ucs_status_t*);
+
+} // namespace rocm_uct
diff --git a/test/gtest/uct/rocm/test_rocm_ipc_device.hip b/test/gtest/uct/rocm/test_rocm_ipc_device.hip
new file mode 100644
index 00000000000..35a6988773e
--- /dev/null
+++ b/test/gtest/uct/rocm/test_rocm_ipc_device.hip
@@ -0,0 +1,249 @@
+/**
+ * Copyright (c) Advanced Micro Devices, Inc. 2026. ALL RIGHTS RESERVED.
+ *
+ * See file LICENSE for terms.
+ */
+
+#include <uct/uct_test.h>
+#include <common/mem_buffer.h>
+#include <hip/hip_runtime.h>
+#include "test_kernels_uct.h"
+#include <uct/api/device/uct_device_types.h>
+#include <uct/rocm/ipc/rocm_ipc_md.h>
+
+class test_rocm_ipc_rma : public uct_test {
+protected:
+    void init() {
+        hipError_t res_drv;
+        int rocm_id;
+
+        uct_test::init();
+
+        rocm_id = 0;
+        res_drv = hipSetDevice(rocm_id);
+        if (res_drv != hipSuccess) {
+            ucs_error("hipSetDevice returned %d.", res_drv);
+            return;
+        }
+
+        res_drv = hipGetDevice(&m_rocm_dev);
+        if (res_drv != hipSuccess) {
+            ucs_error("hipGetDevice returned %d.", res_drv);
+            return;
+        }
+
+        m_receiver = uct_test::create_entity(0);
+        m_entities.push_back(m_receiver);
+
+        m_sender = uct_test::create_entity(0);
+        m_entities.push_back(m_sender);
+
+        m_sender->connect(0, *m_receiver, 0);
+    }
+
+    void cleanup() {
+        uct_test::cleanup();
+    }
+
+    size_t get_mem_elem_size() {
+        return sizeof(uct_rocm_ipc_device_mem_element_t);
+    }
+
+    entity * m_sender;
+    entity * m_receiver;
+
+    int m_rocm_dev;
+    static const uint64_t SEED1 = 0xABClu;
+    static const uint64_t SEED2 = 0xDEFlu;
+    static const unsigned WAVEFRONT_SIZE = 64;
+};
+
+UCS_TEST_P(test_rocm_ipc_rma, has_device_ep_capability)
+{
+    uct_iface_attr_t iface_attr;
+
+    ASSERT_UCS_OK(uct_iface_query(m_sender->iface(), &iface_attr));
+    EXPECT_EQ(iface_attr.cap.flags & UCT_IFACE_FLAG_DEVICE_EP,
+              UCT_IFACE_FLAG_DEVICE_EP);
+}
+
+class test_rocm_ipc_rma_device : public test_rocm_ipc_rma {
+    protected:
+    void init() {
+        test_rocm_ipc_rma::init();
+    }
+
+    void cleanup() {
+        test_rocm_ipc_rma::cleanup();
+    }
+    ucs_device_level_t get_device_level() const {
+        return static_cast<ucs_device_level_t>((GetParam()->variant >> 28) & 0xF);
+    }
+
+    int get_num_blocks() const {
+        return (GetParam()->variant >> 24) & 0xF;
+    }
+
+    int get_num_threads() const {
+        return (GetParam()->variant >> 12) & 0xFFF;
+    }
+    int get_offset() const {
+        return GetParam()->variant & 0xFFF;
+    }
+
+    static const unsigned base_length = 1024;
+
+    public:
+    static std::vector<const resource*> enum_resources(const std::string& tl_name) {
+/*
+Parameter packing in resource.variant (uint32_t):
+  [31:28] device_level (uct_device_level_t, 0..15)
+  [27:24] num_blocks (int, 0..15) used: 1, 2
+  [23:12] num_threads (int, 0..4095) used: 1, 64, 128, 256 (threads per block)
+  [11:0] offset (int, 0..4095) used: 0, 1, 4, 8 (send buffer offset)
+*/
+        static std::vector<std::unique_ptr<resource>> storage;
+        static std::vector<const resource*> out;
+        if (!out.empty()) {
+            return out;
+        }
+
+        std::vector<const resource*> base = uct_test::enum_resources(tl_name);
+        const ucs_device_level_t levels[] = {UCS_DEVICE_LEVEL_THREAD,
+                                             UCS_DEVICE_LEVEL_WARP,
+                                             UCS_DEVICE_LEVEL_BLOCK,
+                                             UCS_DEVICE_LEVEL_GRID};
+        const int num_threads[] = {1, 64, 128, 256};
+        const int num_blocks[] = {1, 2};
+        const int offsets[] = {0, 1, 4, 8};
+
+        const size_t total = base.size() * (sizeof(num_blocks) / sizeof(num_blocks[0])) *
+                             (sizeof(levels) / sizeof(levels[0])) *
+                             (sizeof(num_threads) / sizeof(num_threads[0])) *
+                             (sizeof(offsets) / sizeof(offsets[0]));
+        storage.reserve(total);
+        out.reserve(total);
+
+        for (const resource* r : base) {
+            for (ucs_device_level_t dl : levels) {
+                for (int nt : num_threads) {
+                    for (int off : offsets) {
+                        for (int nb: num_blocks) {
+                            std::unique_ptr<resource> up(new resource(*r));
+                            up->variant = ((static_cast<uint32_t>(dl) & 0xF) << 28) 
| + ((nb & 0xF) << 24) | + ((nt & 0xFFF) << 12) | + (off & 0xFFF); + switch (dl) { + case UCS_DEVICE_LEVEL_THREAD: + up->variant_name = "thread"; + break; + case UCS_DEVICE_LEVEL_WARP: + up->variant_name = "warp"; + break; + case UCS_DEVICE_LEVEL_BLOCK: + up->variant_name = "block"; + break; + case UCS_DEVICE_LEVEL_GRID: + up->variant_name = "grid"; + break; + default: + break; + } + up->variant_name += "- nt" + std::to_string(nt) + + "- nb" + std::to_string(nb) + + "- offset" + std::to_string(off); + out.push_back(up.get()); + storage.emplace_back(std::move(up)); + } + } + } + } + } + return out; + } +}; + +UCS_TEST_P(test_rocm_ipc_rma, mem_elem_size) +{ + EXPECT_EQ(get_mem_elem_size(), sizeof(uct_rocm_ipc_device_mem_element_t)); +} + +UCS_TEST_P(test_rocm_ipc_rma, get_mem_elem_pack) +{ + static const uint64_t SEED1 = 0xABClu; + static const uint64_t SEED2 = 0xDEFlu; + size_t length = 1024; + size_t mem_elem_size = get_mem_elem_size(); + uct_device_mem_element_t *mem_elem; + + mapped_buffer sendbuf(length, SEED1, *m_sender, 0, UCS_MEMORY_TYPE_ROCM); + mapped_buffer recvbuf(length, SEED2, *m_receiver, 0, UCS_MEMORY_TYPE_ROCM); + + ASSERT_EQ(hipSuccess, hipMalloc((void **)&mem_elem, mem_elem_size)); + EXPECT_UCS_OK(uct_md_mem_elem_pack(m_sender->md(), sendbuf.memh(), + recvbuf.rkey(), mem_elem)); + hipFree(mem_elem); +} + +UCS_TEST_P(test_rocm_ipc_rma, get_device_ep) +{ + uct_device_ep_h device_ep; + + ASSERT_UCS_OK(uct_ep_get_device_ep(m_sender->ep(0), &device_ep)); +} + +_UCT_INSTANTIATE_TEST_CASE(test_rocm_ipc_rma, rocm_ipc) + +UCS_TEST_P(test_rocm_ipc_rma_device, device_put) +{ + size_t length = base_length + get_offset(); + ucs_device_level_t device_level = get_device_level(); + unsigned num_threads = get_num_threads(); + unsigned num_blocks = get_num_blocks(); + + if (device_level == UCS_DEVICE_LEVEL_GRID) { + GTEST_SKIP() << "Grid level is not supported"; + } + + if ((device_level == UCS_DEVICE_LEVEL_WARP) && (num_threads < 64)) { + GTEST_SKIP() << "Warp 
level is not supported for less than 64 threads"; + } + + mapped_buffer sendbuf(length, SEED1, *m_sender, 0, UCS_MEMORY_TYPE_ROCM); + mapped_buffer recvbuf(length, SEED2, *m_receiver, 0, UCS_MEMORY_TYPE_ROCM); + + uct_device_local_mem_list_elem_t src_elem_host; + ASSERT_UCS_OK(uct_md_mem_elem_pack(m_sender->md(), sendbuf.memh(), + recvbuf.rkey(), + &src_elem_host.uct_mem_element)); + + uct_device_local_mem_list_elem_t *src_elem; + ASSERT_EQ(hipSuccess, + hipMalloc((void**)&src_elem, + sizeof(uct_device_local_mem_list_elem_t))); + ASSERT_EQ(hipSuccess, + hipMemcpy(src_elem, &src_elem_host, + sizeof(uct_device_local_mem_list_elem_t), + hipMemcpyHostToDevice)); + + uct_device_mem_element_t *mem_elem; + ASSERT_EQ(hipSuccess, hipMalloc((void**)&mem_elem, + sizeof(uct_device_mem_element_t))); + ASSERT_EQ(hipSuccess, hipMemcpy(mem_elem, &src_elem_host.uct_mem_element, + sizeof(uct_device_mem_element_t), + hipMemcpyHostToDevice)); + + uct_device_ep_h device_ep; + ASSERT_UCS_OK(uct_ep_get_device_ep(m_sender->ep(0), &device_ep)); + ASSERT_UCS_OK(rocm_uct::launch_uct_put(device_ep, src_elem, mem_elem, + sendbuf.ptr(), + (uintptr_t)recvbuf.ptr(), length, + device_level, num_threads, + num_blocks)); + recvbuf.pattern_check(SEED1); + hipFree(src_elem); + hipFree(mem_elem); +} + +_UCT_INSTANTIATE_TEST_CASE(test_rocm_ipc_rma_device, rocm_ipc)