diff --git a/apps/nccl/include/mscclpp/nccl.h b/apps/nccl/include/mscclpp/nccl.h index f71da4a62..2e5a48cd8 100644 --- a/apps/nccl/include/mscclpp/nccl.h +++ b/apps/nccl/include/mscclpp/nccl.h @@ -248,17 +248,10 @@ typedef enum { ncclFloat = 7, ncclFloat64 = 8, ncclDouble = 8, -#if defined(__CUDA_BF16_TYPES_EXIST__) && defined(__CUDA_FP8_TYPES_EXIST__) ncclBfloat16 = 9, - ncclFp8E4M3 = 10, - ncclFp8E5M2 = 11, + ncclFloat8e4m3 = 10, + ncclFloat8e5m2 = 11, ncclNumTypes = 12 -#elif defined(__CUDA_BF16_TYPES_EXIST__) - ncclBfloat16 = 9, - ncclNumTypes = 10 -#else - ncclNumTypes = 9 -#endif } ncclDataType_t; static inline size_t ncclTypeSize(ncclDataType_t type) { @@ -278,15 +271,11 @@ static inline size_t ncclTypeSize(ncclDataType_t type) { return 4; case ncclFloat64: return 8; -#if defined(__CUDA_BF16_TYPES_EXIST__) case ncclBfloat16: return 2; -#endif // defined(__CUDA_BF16_TYPES_EXIST__) -#if defined(__CUDA_FP8_TYPES_EXIST__) - case ncclFp8E4M3: - case ncclFp8E5M2: + case ncclFloat8e4m3: + case ncclFloat8e5m2: return 1; -#endif // defined(__CUDA_FP8_TYPES_EXIST__) case ncclNumTypes: return 0; } diff --git a/apps/nccl/src/allreduce.cu b/apps/nccl/src/allreduce.cu index bac933a01..d5b8200a4 100644 --- a/apps/nccl/src/allreduce.cu +++ b/apps/nccl/src/allreduce.cu @@ -71,13 +71,20 @@ struct NvlsAdapter { mscclpp::DeviceHandle* nvlsOutChannels, size_t channelInOffset, size_t channelOutOffset, size_t, int rank, int nRanksPerNode, int, size_t nelems, cudaStream_t stream, uint32_t*, uint32_t*, uint32_t*, uint32_t) { - using ChannelType = mscclpp::DeviceHandle; - int nBlocks = nRanksPerNode; - int nThreadsPerBlock = 1024; - allreduce9<<>>((ChannelType*)memoryChannels, nvlsChannels, nvlsOutChannels, - channelInOffset, channelOutOffset, nelems * sizeof(T), rank, - nRanksPerNode); - return cudaGetLastError(); +#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS + if constexpr (std::is_same_v || std::is_same_v) { + return cudaErrorNotSupported; + } else +#endif + { + using ChannelType = mscclpp::DeviceHandle; + int nBlocks = nRanksPerNode; + int nThreadsPerBlock = 1024; + allreduce9<<>>((ChannelType*)memoryChannels, nvlsChannels, + nvlsOutChannels, channelInOffset, channelOutOffset, + nelems * sizeof(T), rank, nRanksPerNode); + return cudaGetLastError(); + } } }; @@ -88,21 +95,28 @@ struct NvlsWithCopyAdapter { mscclpp::DeviceHandle*, size_t, size_t, size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t nelems, cudaStream_t stream, uint32_t*, uint32_t*, uint32_t*, uint32_t) { - using ChannelType = mscclpp::DeviceHandle; - if (sizeof(T) * nelems < (1 << 24)) { - int nBlocks = nRanksPerNode * 4; - int nThreadsPerBlock = 1024; - allreduce10<<>>(input, scratch, output, (ChannelType*)memoryChannels, - nvlsChannels, nelems * sizeof(T), scratchBufferSize, - rank, nRanksPerNode); - } else { - int nBlocks = nRanksPerNode * 5; - int nThreadsPerBlock = 1024; - allreduce11<<>>(input, scratch, output, (ChannelType*)memoryChannels, - nvlsChannels, nelems * sizeof(T), scratchBufferSize, - rank, nRanksPerNode); +#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS + if constexpr (std::is_same_v || std::is_same_v) { + return cudaErrorNotSupported; + } else +#endif + { + using ChannelType = mscclpp::DeviceHandle; + if (sizeof(T) * nelems < (1 << 24)) { + int nBlocks = nRanksPerNode * 4; + int nThreadsPerBlock = 1024; + allreduce10<<>>(input, scratch, output, (ChannelType*)memoryChannels, + nvlsChannels, nelems * sizeof(T), scratchBufferSize, + rank, nRanksPerNode); + } else { + int nBlocks = nRanksPerNode * 5; + int nThreadsPerBlock = 1024; + allreduce11<<>>(input, scratch, output, (ChannelType*)memoryChannels, + nvlsChannels, nelems * sizeof(T), scratchBufferSize, + rank, nRanksPerNode); + } + return cudaGetLastError(); } - return cudaGetLastError(); } }; @@ -154,6 +168,12 @@ AllreduceFunc dispatch(ncclRedOp_t op, ncclDataType_t dtype) { #if defined(__CUDA_BF16_TYPES_EXIST__) } else if (dtype == ncclBfloat16) { return Adapter::call; +#endif +#if defined(__FP8_TYPES_EXIST__) + } else if (dtype == ncclFloat8e4m3) { + return Adapter::call; + } else if (dtype == ncclFloat8e5m2) { + return Adapter::call; #endif } else if (dtype == ncclInt32 || dtype == ncclUint32) { return Adapter::call; @@ -168,6 +188,12 @@ AllreduceFunc dispatch(ncclRedOp_t op, ncclDataType_t dtype) { #if defined(__CUDA_BF16_TYPES_EXIST__) } else if (dtype == ncclBfloat16) { return Adapter::call; +#endif +#if defined(__FP8_TYPES_EXIST__) + } else if (dtype == ncclFloat8e4m3) { + return Adapter::call; + } else if (dtype == ncclFloat8e5m2) { + return Adapter::call; #endif } else if (dtype == ncclInt32 || dtype == ncclUint32) { return Adapter::call; diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index 4194cb19f..84e904eee 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -127,6 +127,208 @@ __forceinline__ __device__ __bfloat162 min_elements(__bfloat162 a, __bfloat162 b return __hmin2(a, b); } +#if defined(__FP8_TYPES_EXIST__) +// FP8 E4M3 clipping function +template <> +__forceinline__ __device__ __fp8_e4m3 clip(__fp8_e4m3 val) { + // FP8 E4M3 has range [-448, 448], no infinities + // Built-in saturation in FP8 arithmetic + return val; +} + +// FP8 E5M2 clipping function - prevent infinities by clamping to max finite value +template <> +__forceinline__ __device__ __fp8_e5m2 clip(__fp8_e5m2 val) { + // FP8 E5M2 has infinities - clamp to max finite value to prevent overflow + // Max finite value for E5M2 is 57344.0f (0x7B), min is -57344.0f (0xFB) + float fval = float(val); + fval = fmaxf(fval, -57344.0f); + fval = fminf(fval, 57344.0f); + return __fp8_e5m2(fval); +} + +// FP8 E4M3 addition using __hadd for efficiency (single element) +template +__forceinline__ __device__ __fp8_e4m3 add_elements(__fp8_e4m3 a, __fp8_e4m3 b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + // Optimized assembly for gfx942 + float2 v; + uint32_t ival = 0; + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v) + : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(a.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(b.__x, 0))); + return __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.x, ival, false); +#elif !defined(__HIP_PLATFORM_AMD__) + // NVIDIA CUDA FP8 addition (CUDA 11.8+) + __fp8_e4m3 result = __fp8_e4m3(__hadd(__half(a), __half(b))); + return UseClip ? clip(result) : result; +#else + // Fallback for non-gfx942 HIP platforms + __fp8_e4m3 result = __fp8_e4m3(float(a) + float(b)); + return UseClip ? clip(result) : result; +#endif +} + +// FP8 E4M3 vectorized addition for 2 elements +template +__forceinline__ __device__ __fp8x2_e4m3 add_elements(__fp8x2_e4m3 a, __fp8x2_e4m3 b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + float2 v; + uint32_t ival = 0; + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v) + : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(a, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(b, 0))); + return __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, ival, false); +#elif !defined(__HIP_PLATFORM_AMD__) + // CUDA: Convert to half2, add using optimized __hadd2, convert back + __fp8x2_e4m3 result = __fp8x2_e4m3(__hadd2(__half2(a), __half2(b))); + return result; +#else + // Fallback for non-gfx942 HIP: element-wise using single-element operations + union { + __fp8_e4m3 fp8[2]; + __fp8x2_e4m3 fp8x2; + } ua, ub, result; + ua.fp8x2 = a; + ub.fp8x2 = b; + result.fp8[0] = add_elements(ua.fp8[0], ub.fp8[0]); + result.fp8[1] = add_elements(ua.fp8[1], ub.fp8[1]); + return result.fp8x2; +#endif +} + +// FP8 E4M3 vectorized addition for 4 elements (via 2x __fp8x2_e4m3) +template +__forceinline__ __device__ __fp8x4_e4m3 add_elements(__fp8x4_e4m3 a, __fp8x4_e4m3 b) { + // Process as two __fp8x2_e4m3 using add_elements for 2 elements + __fp8x2_e4m3* a_pair = reinterpret_cast<__fp8x2_e4m3*>(&a); + __fp8x2_e4m3* b_pair = reinterpret_cast<__fp8x2_e4m3*>(&b); + + __fp8x2_e4m3 result[2]; + result[0] = add_elements(a_pair[0], b_pair[0]); + result[1] = add_elements(a_pair[1], b_pair[1]); + + return *reinterpret_cast<__fp8x4_e4m3*>(result); +} + +// FP8 E5M2 addition using __hadd for efficiency (single element) +template +__forceinline__ __device__ __fp8_e5m2 add_elements(__fp8_e5m2 a, __fp8_e5m2 b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + // Optimized assembly for gfx942 (bfloat8) + float2 v; + uint32_t ival = 0; + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v) + : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(a.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(b.__x, 0))); + return __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.x, ival, false); +#elif !defined(__HIP_PLATFORM_AMD__) + // NVIDIA CUDA FP8 addition + __fp8_e5m2 result = __fp8_e5m2(__hadd(__half(a), __half(b))); + return UseClip ? clip(result) : result; +#else + // Fallback for non-gfx942 HIP platforms + __fp8_e5m2 result = __fp8_e5m2(float(a) + float(b)); + return UseClip ? clip(result) : result; +#endif +} + +#if !defined(__HIP_PLATFORM_AMD__) +// FP8 E5M2 vectorized addition for 2 elements (CUDA only) +template +__forceinline__ __device__ __fp8x2_e5m2 add_elements(__fp8x2_e5m2 a, __fp8x2_e5m2 b) { + // CUDA: Convert to half2, add using optimized __hadd2, convert back + __fp8x2_e5m2 result = __fp8x2_e5m2(__hadd2(__half2(a), __half2(b))); + return result; +} + +// FP8 E5M2 vectorized addition for 4 elements (CUDA only - via 2x __fp8x2_e5m2) +template +__forceinline__ __device__ __fp8x4_e5m2 add_elements(__fp8x4_e5m2 a, __fp8x4_e5m2 b) { + // Process as two __fp8x2_e5m2 using add_elements for 2 elements + __fp8x2_e5m2* a_pair = reinterpret_cast<__fp8x2_e5m2*>(&a); + __fp8x2_e5m2* b_pair = reinterpret_cast<__fp8x2_e5m2*>(&b); + + __fp8x2_e5m2 result[2]; + result[0] = add_elements(a_pair[0], b_pair[0]); + result[1] = add_elements(a_pair[1], b_pair[1]); + + return *reinterpret_cast<__fp8x4_e5m2*>(result); +} +#endif // !defined(__HIP_PLATFORM_AMD__) + +// FP8 E4M3 min operation (single element) +template <> +__forceinline__ __device__ __fp8_e4m3 min_elements(__fp8_e4m3 a, __fp8_e4m3 b) { +#if defined(__HIP_PLATFORM_AMD__) + return __fp8_e4m3(fminf(float(a), float(b))); +#else + return __fp8_e4m3(__hmin(__half(a), __half(b))); +#endif +} + +// FP8 E4M3 vectorized min for 2 elements +__forceinline__ __device__ __fp8x2_e4m3 min_elements(__fp8x2_e4m3 a, __fp8x2_e4m3 b) { +#if defined(__HIP_PLATFORM_AMD__) + // HIP implementation: use union and process element-wise + union { + __fp8_e4m3 fp8[2]; + __fp8x2_e4m3 fp8x2; + } ua, ub, result; + ua.fp8x2 = a; + ub.fp8x2 = b; + result.fp8[0] = min_elements(ua.fp8[0], ub.fp8[0]); + result.fp8[1] = min_elements(ua.fp8[1], ub.fp8[1]); + return result.fp8x2; +#else + return __fp8x2_e4m3(__hmin2(__half2(a), __half2(b))); +#endif +} + +// FP8 E4M3 vectorized min for 4 elements +__forceinline__ __device__ __fp8x4_e4m3 min_elements(__fp8x4_e4m3 a, __fp8x4_e4m3 b) { + // Process as two __fp8x2_e4m3 using min_elements for 2 elements + __fp8x2_e4m3* a_pair = reinterpret_cast<__fp8x2_e4m3*>(&a); + __fp8x2_e4m3* b_pair = reinterpret_cast<__fp8x2_e4m3*>(&b); + + __fp8x2_e4m3 result[2]; + result[0] = min_elements(a_pair[0], b_pair[0]); + result[1] = min_elements(a_pair[1], b_pair[1]); + + return *reinterpret_cast<__fp8x4_e4m3*>(result); +} + +// FP8 E5M2 min operation (single element) +template <> +__forceinline__ __device__ __fp8_e5m2 min_elements(__fp8_e5m2 a, __fp8_e5m2 b) { +#if defined(__HIP_PLATFORM_AMD__) + return __fp8_e5m2(fminf(float(a), float(b))); +#else + return __fp8_e5m2(__hmin(__half(a), __half(b))); +#endif +} + +#if !defined(__HIP_PLATFORM_AMD__) +// FP8 E5M2 vectorized min for 2 elements (CUDA only) +__forceinline__ __device__ __fp8x2_e5m2 min_elements(__fp8x2_e5m2 a, __fp8x2_e5m2 b) { + return __fp8x2_e5m2(__hmin2(__half2(a), __half2(b))); +} + +// FP8 E5M2 vectorized min for 4 elements (CUDA only) +__forceinline__ __device__ __fp8x4_e5m2 min_elements(__fp8x4_e5m2 a, __fp8x4_e5m2 b) { + // Process as two __fp8x2_e5m2 using min_elements for 2 elements + __fp8x2_e5m2* a_pair = reinterpret_cast<__fp8x2_e5m2*>(&a); + __fp8x2_e5m2* b_pair = reinterpret_cast<__fp8x2_e5m2*>(&b); + + __fp8x2_e5m2 result[2]; + result[0] = min_elements(a_pair[0], b_pair[0]); + result[1] = min_elements(a_pair[1], b_pair[1]); + + return *reinterpret_cast<__fp8x4_e5m2*>(result); +} +#endif // !defined(__HIP_PLATFORM_AMD__) +#endif // __FP8_TYPES_EXIST__ + template __forceinline__ __device__ T cal_elements(T a, T b) { if constexpr (OpType == SUM) { @@ -161,10 +363,94 @@ __forceinline__ __device__ int cal_vectors_helper(int a, int b) { return bit_cast(cal_elements(bit_cast(a), bit_cast(b))); } +#if defined(__HIP_PLATFORM_AMD__) && defined(__FP8_TYPES_EXIST__) && defined(__gfx942__) +// Helper function to perform FP8 vector addition - dispatches based on scalar type +// Uses AMD builtins from hip/amd_detail/amd_hip_fp8.h: +// - __builtin_amdgcn_cvt_pk_f32_fp8/bf8: Convert 2 FP8 values to 2 floats +// - __builtin_amdgcn_cvt_pk_fp8/bf8_f32: Convert 2 floats to 2 FP8 values +// The 'word' parameter (false/true) selects low/high 16-bit word from uint32_t +template +__forceinline__ __device__ int add_fp8x4_hip(int a, int b) { + uint32_t a32 = static_cast(a); + uint32_t b32 = static_cast(b); + + float2 v_low, v_high; + uint32_t ival = 0; + + if constexpr (std::is_same_v) { + // E4M3 using fp8 conversion - process low word (false) and high word (true) + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v_low) + : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(a32, false)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(b32, false))); + uint16_t result_low = __builtin_amdgcn_cvt_pk_fp8_f32(v_low.x, v_low.y, ival, false); + + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v_high) + : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(a32, true)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(b32, true))); + uint16_t result_high = __builtin_amdgcn_cvt_pk_fp8_f32(v_high.x, v_high.y, ival, false); + + uint32_t result = (static_cast(result_high) << 16) | result_low; + return static_cast(result); + } else { // __fp8_e5m2 + // E5M2 using bf8 conversion - process low word (false) and high word (true) + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v_low) + : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(a32, false)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(b32, false))); + uint16_t result_low = __builtin_amdgcn_cvt_pk_bf8_f32(v_low.x, v_low.y, ival, false); + + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v_high) + : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(a32, true)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(b32, true))); + uint16_t result_high = __builtin_amdgcn_cvt_pk_bf8_f32(v_high.x, v_high.y, ival, false); + + uint32_t result = (static_cast(result_high) << 16) | result_low; + return static_cast(result); + } +} +#endif + template __forceinline__ __device__ DataType cal_vectors(DataType a, DataType b) { - using CompType = typename std::conditional_t, __half2, - std::conditional_t, __bfloat162, T>>; +#if defined(__HIP_PLATFORM_AMD__) && defined(__FP8_TYPES_EXIST__) && defined(__gfx942__) + // For FP8 types on HIP gfx942, use specialized helper that dispatches based on scalar type + if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (OpType == SUM) { + if constexpr (std::is_same_v || std::is_same_v) { + // Handle int/uint32_t (4 FP8 elements) + return add_fp8x4_hip(a, b); + } else if constexpr (std::is_same_v) { + // Handle int4 (16 FP8 elements) - process as 4 ints + int4 ret; + ret.w = add_fp8x4_hip(a.w, b.w); + ret.x = add_fp8x4_hip(a.x, b.x); + ret.y = add_fp8x4_hip(a.y, b.y); + ret.z = add_fp8x4_hip(a.z, b.z); + return ret; + } else if constexpr (std::is_same_v) { + // Handle uint2 (8 FP8 elements) - process as 2 ints + uint2 ret; + ret.x = add_fp8x4_hip(a.x, b.x); + ret.y = add_fp8x4_hip(a.y, b.y); + return ret; + } + } + } +#endif + + // Define the vectorized computation type based on the element type + using CompType = typename std::conditional_t< + std::is_same_v, __half2, + std::conditional_t, __bfloat162, +#if defined(__FP8_TYPES_EXIST__) + std::conditional_t, __fp8x4_e4m3, + std::conditional_t, __fp8x4_e5m2, +#endif + T +#if defined(__FP8_TYPES_EXIST__) + >>>>; +#else + >>; +#endif return cal_vectors_helper(a, b); } @@ -175,7 +461,7 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, int worldSize, size_t nelems, uint32_t* deviceFlag, uint32_t numScratchBuff) { // This version of allreduce only works for single nodes if (worldSize != nRanksPerNode) return; - if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); + if (sizeof(T) == 2 || sizeof(T) == 1) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); const int nPeers = nRanksPerNode - 1; uint32_t flag = deviceFlag[blockIdx.x]; @@ -253,7 +539,7 @@ __global__ void __launch_bounds__(1024, 1) &event_buffer_head); #endif - if (sizeof(T) == 2) + if (sizeof(T) == 2 || sizeof(T) == 1) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); else nelems = nelems / (sizeof(int) / sizeof(T)); diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 4ec8a1785..7f52107f9 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -212,6 +212,16 @@ static ncclResult_t executeWithPlan(std::shared_ptr executor, executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, sendBytes, recvBytes, mscclpp::DataType::BFLOAT16, *plan, stream); break; +#if defined(__FP8_TYPES_EXIST__) + case ncclFloat8e4m3: + executor->execute(rank, (__fp8_e4m3*)sendbuff, (__fp8_e4m3*)recvbuff, sendBytes, recvBytes, + mscclpp::DataType::FP8_E4M3, *plan, stream); + break; + case ncclFloat8e5m2: + executor->execute(rank, (__fp8_e5m2*)sendbuff, (__fp8_e5m2*)recvbuff, sendBytes, recvBytes, + mscclpp::DataType::FP8_E5M2, *plan, stream); + break; +#endif case ncclInt32: case ncclUint32: executor->execute(rank, (int*)sendbuff, (int*)recvbuff, sendBytes, recvBytes, mscclpp::DataType::UINT32, *plan, @@ -273,15 +283,26 @@ static void registerCustomizedAlgo() { collectionBuilder->addAlgorithmBuilder(allreduceNvlsPacketAlgo); } +static std::pair getDeviceComputeCapability() { + int device; + CUDACHECK(cudaGetDevice(&device)); + int major = 0, minor = 0; + CUDACHECK(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device)); + CUDACHECK(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device)); + return std::make_pair(major, minor); +} + static mscclpp::Algorithm algoSelector( const std::unordered_map>& algoMapByCollective, - std::string collective, const void* input, void* output, size_t messageSize, int nRanksPerNode, int worldSize) { + std::string collective, const void* input, void* output, size_t messageSize, int dtype, int nRanksPerNode, + int worldSize) { if (nRanksPerNode != worldSize) { // Fallback to nccl/rccl when multi-node return mscclpp::Algorithm(); } - static bool mscclppDisableChannelCache = mscclpp::env()->disableChannelCache; - static bool isNvlsSupported = mscclpp::isNvlsSupported(); + static const bool mscclppDisableChannelCache = mscclpp::env()->disableChannelCache; + static const bool isNvlsSupported = mscclpp::isNvlsSupported(); + static const std::pair deviceComputeCapability = getDeviceComputeCapability(); bool isCuMemMapAllocated = mscclpp::isCuMemMapAllocated(const_cast(input)) && mscclpp::isCuMemMapAllocated(output); bool useNvlsWithZeroCopy = isNvlsSupported && !mscclppDisableChannelCache && isCuMemMapAllocated; @@ -299,16 +320,24 @@ static mscclpp::Algorithm algoSelector( } } if (collective == "allreduce") { - if (messageSize <= (1 << 15) && isNvlsSupported) { + bool useNvls = isNvlsSupported; + bool isFp8 = dtype == ncclFloat8e4m3 || dtype == ncclFloat8e5m2; +#if !defined(__HIP_PLATFORM_AMD__) + if (isFp8 && deviceComputeCapability.first < 10) { + // NVLS does not support FP8 on devices with compute capability < 10 + useNvls = false; + } +#endif + if (messageSize <= (1 << 15) && useNvls) { return algoMapByCollective.at(collective).at("default_allreduce_nvls_packet"); } if (messageSize <= (1 << 16) || (messageSize <= (1 << 20) && !useNvlsWithZeroCopy)) { return algoMapByCollective.at(collective).at("default_allreduce_packet"); } - if (useNvlsWithZeroCopy) { + if (useNvls && useNvlsWithZeroCopy) { return algoMapByCollective.at(collective).at("default_allreduce_nvls"); } - if (mscclpp::isNvlsSupported()) { + if (useNvls) { return algoMapByCollective.at(collective).at("default_allreduce_nvls_with_copy"); } #if defined(__HIP_PLATFORM_AMD__) @@ -631,8 +660,8 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t return executeWithPlan(comm->executor, rank, datatype, sendbuff, recvbuff, bytes, bytes, plan, stream); } auto algo = comm->algorithmCollection->selectAlgorithm( - "broadcast", sendbuff, recvbuff, count * ncclTypeSize(datatype), comm->comm->bootstrap()->getNranksPerNode(), - comm->comm->bootstrap()->getNranks()); + "broadcast", sendbuff, recvbuff, count * ncclTypeSize(datatype), datatype, + comm->comm->bootstrap()->getNranksPerNode(), comm->comm->bootstrap()->getNranks()); if (!algo.isEmpty()) { std::unordered_map> extras{ {"root", std::make_shared(root)}, @@ -692,8 +721,8 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t } auto algo = comm->algorithmCollection->selectAlgorithm( - "allreduce", sendbuff, recvbuff, count * ncclTypeSize(datatype), comm->comm->bootstrap()->getNranksPerNode(), - comm->comm->bootstrap()->getNranks()); + "allreduce", sendbuff, recvbuff, count * ncclTypeSize(datatype), datatype, + comm->comm->bootstrap()->getNranksPerNode(), comm->comm->bootstrap()->getNranks()); if (!algo.isEmpty()) { std::unordered_map> extras{ {"op", std::make_shared(reductionOperation)}, @@ -809,7 +838,7 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t } auto algo = comm->algorithmCollection->selectAlgorithm( - "allgather", sendbuff, recvbuff, nRank * sendcount * ncclTypeSize(datatype), + "allgather", sendbuff, recvbuff, nRank * sendcount * ncclTypeSize(datatype), datatype, comm->comm->bootstrap()->getNranksPerNode(), comm->comm->bootstrap()->getNranks()); if (!algo.isEmpty()) { std::unordered_map> extras = { diff --git a/examples/customized-collective-algorithm/customized_allgather.cu b/examples/customized-collective-algorithm/customized_allgather.cu index 798489245..a2f6abaca 100644 --- a/examples/customized-collective-algorithm/customized_allgather.cu +++ b/examples/customized-collective-algorithm/customized_allgather.cu @@ -201,7 +201,7 @@ void worker(int rank, int worldSize, ncclUniqueId id) { mscclpp::AlgorithmCollectionBuilder::getInstance()->setAlgorithmSelector( [](const std::unordered_map>& algoMapByCollective, - std::string collective, const void* input, void* output, size_t messageSize, int nRanksPerNode, + std::string collective, const void* input, void* output, size_t messageSize, int dtype, int nRanksPerNode, int worldSize) { if (collective != "allgather") { return mscclpp::Algorithm(); diff --git a/include/mscclpp/algorithm.hpp b/include/mscclpp/algorithm.hpp index 34e10f8f9..1784a2664 100644 --- a/include/mscclpp/algorithm.hpp +++ b/include/mscclpp/algorithm.hpp @@ -115,7 +115,8 @@ class AlgorithmBuilder { using AlgoSelectFunc = std::function>& algoMapByCollective, - std::string collective, const void* input, void* output, size_t messageSize, int nRanksPerNode, int worldSize)>; + std::string collective, const void* input, void* output, size_t messageSize, int dtype, int nRanksPerNode, + int worldSize)>; class AlgorithmCollection { public: @@ -126,11 +127,12 @@ class AlgorithmCollection { /// @param input The input buffer. /// @param output The output buffer. /// @param messageSize The message size. + /// @param dtype The data type. Please refer to ncclDataType_t for the definition. /// @param nRanksPerNode The number of ranks per node. /// @param worldSize The total number of ranks. /// @return The selected algorithm. If no suitable algorithm is found, an empty Algorithm object is returned. Algorithm selectAlgorithm(const std::string& collective, const void* input, void* output, size_t messageSize, - int nRanksPerNode, int worldSize); + int dtype, int nRanksPerNode, int worldSize); /// @brief Register a new algorithm. /// @param collective The collective operation name. diff --git a/include/mscclpp/executor.hpp b/include/mscclpp/executor.hpp index 5b6b9c922..f408c7448 100644 --- a/include/mscclpp/executor.hpp +++ b/include/mscclpp/executor.hpp @@ -17,6 +17,8 @@ enum class DataType { FLOAT16, FLOAT32, BFLOAT16, + FP8_E4M3, // Add FP8 E4M3 type + FP8_E5M2, // Add FP8 E5M2 type }; enum class PacketType { diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp index 97b278e38..fa822fc7c 100644 --- a/include/mscclpp/gpu.hpp +++ b/include/mscclpp/gpu.hpp @@ -41,6 +41,9 @@ constexpr auto cudaMemcpyHostToDevice = hipMemcpyHostToDevice; constexpr auto cudaMemcpyDeviceToHost = hipMemcpyDeviceToHost; constexpr auto cudaIpcMemLazyEnablePeerAccess = hipIpcMemLazyEnablePeerAccess; +constexpr auto cudaDevAttrComputeCapabilityMajor = hipDeviceAttributeComputeCapabilityMajor; +constexpr auto cudaDevAttrComputeCapabilityMinor = hipDeviceAttributeComputeCapabilityMinor; + constexpr auto CU_MEM_ALLOCATION_TYPE_PINNED = hipMemAllocationTypePinned; constexpr auto CU_MEM_LOCATION_TYPE_DEVICE = hipMemLocationTypeDevice; constexpr auto CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR = hipMemHandleTypePosixFileDescriptor; @@ -64,6 +67,7 @@ constexpr auto CU_MEM_ALLOC_GRANULARITY_MINIMUM = hipMemAllocationGranularityMin #define cudaGetDevice(...) hipGetDevice(__VA_ARGS__) #define cudaGetDeviceCount(...) hipGetDeviceCount(__VA_ARGS__) #define cudaGetDeviceProperties(...) hipGetDeviceProperties(__VA_ARGS__) +#define cudaDeviceGetAttribute(...) hipDeviceGetAttribute(__VA_ARGS__) #define cudaGetLastError(...) hipGetLastError(__VA_ARGS__) #define cudaSetDevice(...) hipSetDevice(__VA_ARGS__) #define cudaDeviceSynchronize(...) hipDeviceSynchronize(__VA_ARGS__) diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp index 1b4bc8a51..3b2aba84e 100644 --- a/include/mscclpp/gpu_data_types.hpp +++ b/include/mscclpp/gpu_data_types.hpp @@ -8,12 +8,31 @@ #include #include +#include using __bfloat16 = __hip_bfloat16; using __bfloat162 = __hip_bfloat162; #define __CUDA_BF16_TYPES_EXIST__ -#else +// AMD FP8 support - hip_fp8.h provides __hip_fp8_e4m3_fnuz and __hip_fp8_e5m2_fnuz +// Only available on gfx942 and newer architectures (ROCm 6.0+) +#if defined(HIP_VERSION_MAJOR) && (HIP_VERSION_MAJOR >= 6) +#include + +// Create aliases matching CUDA naming convention for cross-platform compatibility +using __fp8_e4m3 = __hip_fp8_e4m3_fnuz; +using __fp8_e5m2 = __hip_fp8_e5m2_fnuz; + +// HIP FP8 vector types use storage types (from hip/amd_detail/amd_hip_fp8.h): +using __fp8x2_e4m3 = __hip_fp8x2_storage_t; // uint16_t +using __fp8x2_e5m2 = __hip_fp8x2_storage_t; // uint16_t +using __fp8x4_e4m3 = __hip_fp8x4_storage_t; // uint32_t +using __fp8x4_e5m2 = __hip_fp8x4_storage_t; // uint32_t + +#define __FP8_TYPES_EXIST__ +#endif // HIP_VERSION_MAJOR >= 6 + +#else // NVIDIA #include #include @@ -22,6 +41,13 @@ using __bfloat162 = __hip_bfloat162; #endif #if (CUDART_VERSION >= 11080) #include +using __fp8_e4m3 = __nv_fp8_e4m3; +using __fp8_e5m2 = __nv_fp8_e5m2; +using __fp8x2_e4m3 = __nv_fp8x2_e4m3; +using __fp8x2_e5m2 = __nv_fp8x2_e5m2; +using __fp8x4_e4m3 = __nv_fp8x4_e4m3; +using __fp8x4_e5m2 = __nv_fp8x4_e5m2; +#define __FP8_TYPES_EXIST__ #endif using __bfloat16 = __nv_bfloat16; @@ -89,6 +115,16 @@ using bf16x4 = VectorType<__bfloat16, 4>; using f16x8 = VectorType<__half, 8>; using bf16x8 = VectorType<__bfloat16, 8>; +#if defined(__FP8_TYPES_EXIST__) +// FP8 vector types +using fp8_e4m3x2 = VectorType<__fp8_e4m3, 2>; +using fp8_e4m3x4 = VectorType<__fp8_e4m3, 4>; +using fp8_e4m3x8 = VectorType<__fp8_e4m3, 8>; +using fp8_e5m2x2 = VectorType<__fp8_e5m2, 2>; +using fp8_e5m2x4 = VectorType<__fp8_e5m2, 4>; +using fp8_e5m2x8 = VectorType<__fp8_e5m2, 8>; +#endif + } // namespace mscclpp #endif // MSCCLPP_GPU_DATA_TYPES_HPP_ diff --git a/src/algorithm.cc b/src/algorithm.cc index 54c1fb276..e8966df15 100644 --- a/src/algorithm.cc +++ b/src/algorithm.cc @@ -66,14 +66,14 @@ void AlgorithmCollection::registerAlgorithm(const std::string collective, const } Algorithm AlgorithmCollection::selectAlgorithm(const std::string& collective, const void* input, void* output, - size_t messageSize, int nRanksPerNode, int worldSize) { + size_t messageSize, int dtype, int nRanksPerNode, int worldSize) { Algorithm algo; if (algoSelector_) { - algo = algoSelector_(algoMapByCollective_, collective, input, output, messageSize, nRanksPerNode, worldSize); + algo = algoSelector_(algoMapByCollective_, collective, input, output, messageSize, dtype, nRanksPerNode, worldSize); } if (algo.isEmpty()) { - algo = - fallbackAlgoSelector_(algoMapByCollective_, collective, input, output, messageSize, nRanksPerNode, worldSize); + algo = fallbackAlgoSelector_(algoMapByCollective_, collective, input, output, messageSize, dtype, nRanksPerNode, + worldSize); } return algo; } diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp index 9ccc6b9be..0bc4c1a9c 100644 --- a/src/include/execution_kernel.hpp +++ b/src/include/execution_kernel.hpp @@ -51,6 +51,129 @@ MSCCLPP_DEVICE_INLINE __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) { return __hadd2(a, b); } +#if defined(__FP8_TYPES_EXIST__) +// FP8 E4M3 addition using __hadd (single element) +template <> +MSCCLPP_DEVICE_INLINE __fp8_e4m3 add_elements(__fp8_e4m3 a, __fp8_e4m3 b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + // Optimized assembly for gfx942 + float2 v; + uint32_t ival = 0; + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v) + : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(a.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(b.__x, 0))); + return __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.x, ival, false); +#else + return __fp8_e4m3(__hadd(__half(a), __half(b))); +#endif +} + +// FP8 E5M2 addition using __hadd (single element) - must come before helper functions +template <> +MSCCLPP_DEVICE_INLINE __fp8_e5m2 add_elements(__fp8_e5m2 a, __fp8_e5m2 b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + // Optimized assembly for gfx942 (bfloat8) + float2 v; + uint32_t ival = 0; + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v) + : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(a.__x, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(b.__x, 0))); + return __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.x, ival, false); +#else + return __fp8_e5m2(__hadd(__half(a), __half(b))); +#endif +} + +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) +// HIP gfx942 platform: Helper functions for vectorized FP8 operations +// We use separate function names because __fp8x2_e4m3 and __fp8x2_e5m2 are both uint16_t + +// E4M3 vectorized addition for 2 elements +MSCCLPP_DEVICE_INLINE uint16_t add_fp8x2_e4m3(uint16_t a, uint16_t b) { + float2 v; + uint32_t ival = 0; + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v) + : "v"(__builtin_amdgcn_cvt_pk_f32_fp8(a, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_fp8(b, 0))); + return __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, ival, false); +} + +// E4M3 vectorized addition for 4 elements +MSCCLPP_DEVICE_INLINE uint32_t add_fp8x4_e4m3(uint32_t a, uint32_t b) { + uint16_t a_low = a & 0xFFFF; + uint16_t a_high = (a >> 16) & 0xFFFF; + uint16_t b_low = b & 0xFFFF; + uint16_t b_high = (b >> 16) & 0xFFFF; + uint16_t result_low = add_fp8x2_e4m3(a_low, b_low); + uint16_t result_high = add_fp8x2_e4m3(a_high, b_high); + return (static_cast(result_high) << 16) | result_low; +} + +// E5M2 vectorized addition for 2 elements +MSCCLPP_DEVICE_INLINE uint16_t add_fp8x2_e5m2(uint16_t a, uint16_t b) { + float2 v; + uint32_t ival = 0; + asm volatile("v_pk_add_f32 %0, %1, %2" + : "=v"(v) + : "v"(__builtin_amdgcn_cvt_pk_f32_bf8(a, 0)), "v"(__builtin_amdgcn_cvt_pk_f32_bf8(b, 0))); + return __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.y, ival, false); +} + +// E5M2 vectorized addition for 4 elements +MSCCLPP_DEVICE_INLINE uint32_t add_fp8x4_e5m2(uint32_t a, uint32_t b) { + uint16_t a_low = a & 0xFFFF; + uint16_t a_high = (a >> 16) & 0xFFFF; + uint16_t b_low = b & 0xFFFF; + uint16_t b_high = (b >> 16) & 0xFFFF; + uint16_t result_low = add_fp8x2_e5m2(a_low, b_low); + uint16_t result_high = add_fp8x2_e5m2(a_high, b_high); + return (static_cast(result_high) << 16) | result_low; +} +#endif + +#if !defined(__HIP_PLATFORM_AMD__) +// CUDA platform: Template specializations for vectorized FP8 operations + +// FP8 E4M3 vectorized addition using __hadd2 for 2 elements (CUDA only) +template <> +MSCCLPP_DEVICE_INLINE __fp8x2_e4m3 add_elements(__fp8x2_e4m3 a, __fp8x2_e4m3 b) { + return __fp8x2_e4m3(__hadd2(__half2(a), __half2(b))); +} + +// FP8 E4M3 vectorized addition for 4 elements (CUDA only - via 2x __fp8x2_e4m3) +template <> +MSCCLPP_DEVICE_INLINE __fp8x4_e4m3 add_elements(__fp8x4_e4m3 a, __fp8x4_e4m3 b) { + __fp8x2_e4m3* a_pair = reinterpret_cast<__fp8x2_e4m3*>(&a); + __fp8x2_e4m3* b_pair = reinterpret_cast<__fp8x2_e4m3*>(&b); + + __fp8x2_e4m3 result[2]; + result[0] = add_elements(a_pair[0], b_pair[0]); + result[1] = add_elements(a_pair[1], b_pair[1]); + + return *reinterpret_cast<__fp8x4_e4m3*>(result); +} + +// FP8 E5M2 vectorized addition for 2 elements (CUDA only) +template <> +MSCCLPP_DEVICE_INLINE __fp8x2_e5m2 add_elements(__fp8x2_e5m2 a, __fp8x2_e5m2 b) { + return __fp8x2_e5m2(__hadd2(__half2(a), __half2(b))); +} + +// FP8 E5M2 vectorized addition for 4 elements (CUDA only - via 2x __fp8x2_e5m2) +template <> +MSCCLPP_DEVICE_INLINE __fp8x4_e5m2 add_elements(__fp8x4_e5m2 a, __fp8x4_e5m2 b) { + __fp8x2_e5m2* a_pair = reinterpret_cast<__fp8x2_e5m2*>(&a); + __fp8x2_e5m2* b_pair = reinterpret_cast<__fp8x2_e5m2*>(&b); + + __fp8x2_e5m2 result[2]; + result[0] = add_elements(a_pair[0], b_pair[0]); + result[1] = add_elements(a_pair[1], b_pair[1]); + + return *reinterpret_cast<__fp8x4_e5m2*>(result); +} +#endif +#endif // __FP8_TYPES_EXIST__ + template MSCCLPP_DEVICE_INLINE int4 add_vectors_helper(int4 a, int4 b) { int4 ret; @@ -76,6 +199,38 @@ MSCCLPP_DEVICE_INLINE int4 add_vectors<__bfloat16>(int4 a, int4 b) { return add_vectors_helper<__bfloat162>(a, b); } +#if defined(__FP8_TYPES_EXIST__) +template <> +MSCCLPP_DEVICE_INLINE int4 add_vectors<__fp8_e4m3>(int4 a, int4 b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + // HIP gfx942: Use helper functions that work with storage types + int4 ret; + ret.w = add_fp8x4_e4m3(a.w, b.w); + ret.x = add_fp8x4_e4m3(a.x, b.x); + ret.y = add_fp8x4_e4m3(a.y, b.y); + ret.z = add_fp8x4_e4m3(a.z, b.z); + return ret; +#else + return add_vectors_helper<__fp8x4_e4m3>(a, b); +#endif +} + +template <> +MSCCLPP_DEVICE_INLINE int4 add_vectors<__fp8_e5m2>(int4 a, int4 b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + // HIP gfx942: Use helper functions that work with storage types + int4 ret; + ret.w = add_fp8x4_e5m2(a.w, b.w); + ret.x = add_fp8x4_e5m2(a.x, b.x); + ret.y = add_fp8x4_e5m2(a.y, b.y); + ret.z = add_fp8x4_e5m2(a.z, b.z); + return ret; +#else + return add_vectors_helper<__fp8x4_e5m2>(a, b); +#endif +} +#endif // __FP8_TYPES_EXIST__ + template MSCCLPP_DEVICE_INLINE uint2 add_vectors_helper(uint2 a, uint2 b) { uint2 ret; @@ -99,6 +254,34 @@ MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__bfloat16>(uint return add_vectors_helper<__bfloat162>(a, b); } +#if defined(__FP8_TYPES_EXIST__) +template <> +MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__fp8_e4m3>(uint2 a, uint2 b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + // HIP gfx942: Use helper functions that work with storage types + uint2 ret; + ret.x = add_fp8x4_e4m3(a.x, b.x); + ret.y = add_fp8x4_e4m3(a.y, b.y); + return ret; +#else + return add_vectors_helper<__fp8x4_e4m3>(a, b); +#endif +} + +template <> +MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__fp8_e5m2>(uint2 a, uint2 b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + // HIP gfx942: Use helper functions that work with storage types + uint2 ret; + ret.x = add_fp8x4_e5m2(a.x, b.x); + ret.y = add_fp8x4_e5m2(a.y, b.y); + return ret; +#else + return add_vectors_helper<__fp8x4_e5m2>(a, b); +#endif +} +#endif // __FP8_TYPES_EXIST__ + template MSCCLPP_DEVICE_INLINE int add_vectors_helper(int a, int b) { return bit_cast(add_elements(bit_cast(a), bit_cast(b))); @@ -119,6 +302,26 @@ MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__bfloat16>(int a, return add_vectors_helper<__bfloat162>(a, b); } +#if defined(__FP8_TYPES_EXIST__) +template <> +MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__fp8_e4m3>(int a, int b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + return add_fp8x4_e4m3(a, b); +#else + return add_vectors_helper<__fp8x4_e4m3>(a, b); +#endif +} + +template <> +MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__fp8_e5m2>(int a, int b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + return add_fp8x4_e5m2(a, b); +#else + return add_vectors_helper<__fp8x4_e5m2>(a, b); +#endif +} +#endif // __FP8_TYPES_EXIST__ + template MSCCLPP_DEVICE_INLINE uint32_t add_vectors_helper(uint32_t a, uint32_t b) { return bit_cast(add_elements(bit_cast(a), bit_cast(b))); @@ -139,6 +342,26 @@ MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) { return add_vectors_helper<__bfloat162>(a, b); } +#if defined(__FP8_TYPES_EXIST__) +template <> +MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__fp8_e4m3>(uint32_t a, uint32_t b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + return add_fp8x4_e4m3(a, b); +#else + return add_vectors_helper<__fp8x4_e4m3>(a, b); +#endif +} + +template <> +MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__fp8_e5m2>(uint32_t a, uint32_t b) { +#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__) + return add_fp8x4_e5m2(a, b); +#else + return add_vectors_helper<__fp8x4_e5m2>(a, b); +#endif +} +#endif // __FP8_TYPES_EXIST__ + #endif // MSCCLPP_DEVICE_COMPILE } // namespace @@ -945,6 +1168,30 @@ class ExecutionKernel { ); #endif break; +#if defined(__FP8_TYPES_EXIST__) + case DataType::FP8_E4M3: + executionKernel<__fp8_e4m3, PacketType, ReuseScratch><<>>( + rank, (__fp8_e4m3*)src, (__fp8_e4m3*)dst, (__fp8_e4m3*)scratch, scratchOffset, scratchChunkSize, plan, + semaphores, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif + break; + case DataType::FP8_E5M2: + executionKernel<__fp8_e5m2, PacketType, ReuseScratch><<>>( + rank, (__fp8_e5m2*)src, (__fp8_e5m2*)dst, (__fp8_e5m2*)scratch, scratchOffset, scratchChunkSize, plan, + semaphores, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif + break; +#endif // __FP8_TYPES_EXIST__ } } #else // !defined(MSCCLPP_DEVICE_HIP)