Skip to content

Commit 10bc185

Browse files
committed
refactor and remove duplication
1 parent 812770c commit 10bc185

File tree

4 files changed

+95
-256
lines changed

4 files changed

+95
-256
lines changed

cpp/tensorrt_llm/kernels/helixAllToAll.cu

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,32 @@ namespace tensorrt_llm
3636
{
3737
namespace kernels
3838
{
39+
40+
namespace
41+
{
42+
// ============================================================================
43+
// Helix-specific FIFO constants
44+
// Note: Helix uses 128KB FIFO entries vs 256KB in FusedMoe
45+
// ============================================================================
46+
47+
// Number of entries held by one Helix FIFO.
constexpr int HELIX_FIFO_DEPTH = 4;
// Size of a single FIFO entry in bytes (128 KiB).
constexpr int HELIX_FIFO_ENTRY_BYTES = 128 * 1024;
// Size of one complete FIFO (all entries) in bytes.
constexpr int HELIX_FIFO_TOTAL_BYTES = HELIX_FIFO_DEPTH * HELIX_FIFO_ENTRY_BYTES;
// Number of 128-byte blocks that make up one FIFO entry.
constexpr int HELIX_FIFO_ENTRY_128B_COUNT = HELIX_FIFO_ENTRY_BYTES / BYTES_PER_128B_BLOCK;
// Size of one complete FIFO expressed in 64-bit words.
constexpr int HELIX_FIFO_TOTAL_U64 = static_cast<int>(HELIX_FIFO_TOTAL_BYTES / sizeof(uint64_t));
52+
53+
// ============================================================================
54+
// Implementation-only structures
55+
// ============================================================================
56+
57+
// Describes one sender/receiver pairing together with the channel it is served on.
// Plain aggregate: member order and types are part of the layout and must not change.
struct HelixPairInfo
{
    // Rank on the sending side of the pair.
    int senderRank;
    // Rank on the receiving side of the pair.
    int receiverRank;
    // Channel index; used to index the per-channel FIFO info records.
    int channel;
    // NOTE(review): presumably the number of channels active in this launch -- confirm against the kernel.
    int runChannelCount;
};
64+
3965
// WARP_SIZE, WARP_MASK, and other constants are defined in moeCommKernelsCommon.h
4066

4167
// ============================================================================
@@ -170,36 +196,36 @@ __device__ __forceinline__ uint64_t* getFifoBasePtr(HelixAllToAllParams const& p
170196
return mappedMemory + fifoOffset;
171197
}
172198

173-
__device__ __forceinline__ FifoInfo* getSenderFifoInfo(HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
199+
/**
 * Locate the sender-side HelixFifoInfo record for one (sender, receiver, channel) triple.
 *
 * The record is stored in the workspace slice of the sender rank, immediately after that
 * slice's FIFO data region, and is indexed as [receiverRank][channel].
 *
 * @param params   Kernel parameters (workspace base, per-rank stride, cpSize, maxChannelCount)
 * @param pairInfo Sender/receiver/channel selection
 * @return Pointer to the matching HelixFifoInfo inside the sender's workspace slice
 */
__device__ __forceinline__ HelixFifoInfo* getSenderHelixFifoInfo(
    HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
{
    // Sender-side info lives in the sender's slice of the workspace.
    auto* rankBase
        = reinterpret_cast<uint8_t*>(params.workspace + pairInfo.senderRank * params.workspaceStrideInU64);

    // Skip the FIFO data region that precedes the info records.
    size_t const fifoRegionBytes
        = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount;

    // Inside the sender's slice, records are grouped by peer (receiver) rank, then by channel.
    size_t const peerRowBytes = pairInfo.receiverRank * params.maxChannelCount * sizeof(HelixFifoInfo);
    size_t const channelBytes = pairInfo.channel * sizeof(HelixFifoInfo);

    return reinterpret_cast<HelixFifoInfo*>(rankBase + fifoRegionBytes + peerRowBytes + channelBytes);
}
187213

188-
__device__ __forceinline__ FifoInfo* getReceiverFifoInfo(
214+
/**
 * Locate the receiver-side HelixFifoInfo record for one (sender, receiver, channel) triple.
 *
 * The record is stored in the workspace slice of the receiver rank, after that slice's
 * FIFO data region and after the block of sender-side info records, and is indexed as
 * [senderRank][channel].
 *
 * @param params   Kernel parameters (workspace base, per-rank stride, cpSize, maxChannelCount)
 * @param pairInfo Sender/receiver/channel selection
 * @return Pointer to the matching HelixFifoInfo inside the receiver's workspace slice
 */
__device__ __forceinline__ HelixFifoInfo* getReceiverHelixFifoInfo(
    HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
{
    // Receiver-side info lives in the receiver's slice of the workspace.
    auto* rankBase
        = reinterpret_cast<uint8_t*>(params.workspace + pairInfo.receiverRank * params.workspaceStrideInU64);

    // Skip the FIFO data region ...
    size_t const fifoRegionBytes
        = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount;
    // ... and the sender-side info records that are laid out before the receiver-side ones.
    size_t const senderInfoBytes = sizeof(HelixFifoInfo) * params.cpSize * params.maxChannelCount;

    // Inside the receiver's slice, records are grouped by peer (sender) rank, then by channel.
    size_t const peerRowBytes = pairInfo.senderRank * params.maxChannelCount * sizeof(HelixFifoInfo);
    size_t const channelBytes = pairInfo.channel * sizeof(HelixFifoInfo);

    return reinterpret_cast<HelixFifoInfo*>(
        rankBase + (fifoRegionBytes + senderInfoBytes) + peerRowBytes + channelBytes);
}
204230

205231
__device__ __forceinline__ void startWorkspaceS2G(
@@ -315,8 +341,8 @@ __global__ void helixAllToAllKernel(HelixAllToAllParams params)
315341

316342
// Get FIFO pointers
317343
uint64_t* fifoBase = getFifoBasePtr(params, pairInfo);
318-
FifoInfo* senderFifo = getSenderFifoInfo(params, pairInfo);
319-
FifoInfo* receiverFifo = getReceiverFifoInfo(params, pairInfo);
344+
HelixFifoInfo* senderFifo = getSenderHelixFifoInfo(params, pairInfo);
345+
HelixFifoInfo* receiverFifo = getReceiverHelixFifoInfo(params, pairInfo);
320346

321347
int fifoEntry128ByteIndexBase = HELIX_FIFO_ENTRY_128B_COUNT;
322348
int fifoEntryIndex = -1;
@@ -572,6 +598,46 @@ void launchHelixAllToAllImpl(HelixAllToAllParams const& params, cudaStream_t str
572598
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params));
573599
}
574600

601+
} // anonymous namespace
602+
603+
// ============================================================================
604+
// Public API Functions
605+
// ============================================================================
606+
607+
/**
 * Compute the number of channels to use for the Helix all-to-all.
 *
 * Each channel needs ceil(cpSize / MAX_GROUP_COUNT_PER_BLOCK) blocks for sending and the
 * same number for receiving; the result is how many such channels fit into smCount SMs,
 * but never less than one.
 *
 * @param cpSize  Number of context parallel ranks
 * @param smCount Number of SMs available (0 = query the current device)
 * @return Number of channels to use (always >= 1)
 */
int computeHelixMaxChannelCount(int cpSize, int smCount)
{
    // Blocks required by a single channel: one set for send, one for recv.
    int const blockCountPerChannel = 2 * ceil_div(cpSize, MAX_GROUP_COUNT_PER_BLOCK);

    // A non-positive cpSize yields a non-positive block count; dividing by it below would be
    // a division by zero (or give a meaningless negative result). Fall back to the documented
    // minimum of one channel instead.
    if (blockCountPerChannel <= 0)
    {
        return 1;
    }

    if (smCount == 0)
    {
        // Auto-detect: use the SM count of the device that is current on this thread.
        int deviceId = 0;
        TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
        TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
    }

    int const preferredChannel = smCount / blockCountPerChannel;
    return std::max(preferredChannel, 1); // at least one channel
}
622+
623+
/**
 * Compute the workspace size required per rank for the all-to-all operation.
 *
 * The workspace of one rank consists of the FIFO data buffers followed by the sender-side
 * and receiver-side HelixFifoInfo records, each with one element per (rank, channel) pair.
 *
 * @param cpSize Number of context parallel ranks
 * @return Size in bytes
 */
size_t computeHelixWorkspaceSizePerRank(int cpSize)
{
    // Deliberately not cached in a function-local static: the channel count depends on cpSize
    // (and on the current device), so a value remembered from an earlier call with a different
    // cpSize would produce a size that disagrees with initializeHelixWorkspace, which
    // recomputes the channel count on every call.
    int const maxChannelCount = computeHelixMaxChannelCount(cpSize);

    // FIFO buffers: cpSize * channelCount pairs
    size_t const fifoSize = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;

    // Sender and receiver FIFO info structures: one record per pair on each side
    size_t const senderInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
    size_t const receiverInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;

    return fifoSize + senderInfoSize + receiverInfoSize;
}
640+
575641
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream)
576642
{
577643
if (allowVariableField1)
@@ -593,8 +659,8 @@ void initializeHelixWorkspace(uint64_t* local_workspace_ptr, int cpSize, cudaStr
593659
int maxChannelCount = computeHelixMaxChannelCount(cpSize);
594660
// Calculate sizes with channel dimension
595661
size_t fifoSize = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;
596-
size_t senderInfoSize = sizeof(FifoInfo) * cpSize * maxChannelCount;
597-
size_t receiverInfoSize = sizeof(FifoInfo) * cpSize * maxChannelCount;
662+
size_t senderInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
663+
size_t receiverInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
598664

599665
// Fill the FIFO buffers with 0xFF bytes so every word reads as all-ones (-1 for signed integer types)
600666
TLLM_CUDA_CHECK(cudaMemsetAsync(local_workspace_ptr, 0xFF, fifoSize, stream));

cpp/tensorrt_llm/kernels/helixAllToAll.h

Lines changed: 3 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,6 @@ namespace tensorrt_llm
2929
namespace kernels
3030
{
3131

32-
// ============================================================================
33-
// Helix-specific FIFO constants
34-
// Note: Helix uses 128KB FIFO entries vs 256KB in FusedMoe
35-
// ============================================================================
36-
37-
constexpr int HELIX_FIFO_DEPTH = 4;
38-
constexpr int HELIX_FIFO_ENTRY_BYTES = 128 * 1024;
39-
constexpr int HELIX_FIFO_ENTRY_128B_COUNT = HELIX_FIFO_ENTRY_BYTES / BYTES_PER_128B_BLOCK;
40-
constexpr int HELIX_FIFO_TOTAL_BYTES = HELIX_FIFO_ENTRY_BYTES * HELIX_FIFO_DEPTH;
41-
constexpr int HELIX_FIFO_TOTAL_U64 = HELIX_FIFO_TOTAL_BYTES / sizeof(uint64_t);
42-
4332
// Backward compatibility aliases (reference common constants from moeCommKernelsCommon.h)
4433
// WARP_SIZE, WARP_MASK, BYTES_PER_128B_BLOCK, UINT64_PER_128B_BLOCK, MAX_GROUP_COUNT_PER_BLOCK
4534
// are now defined in moeCommKernelsCommon.h
@@ -48,17 +37,9 @@ constexpr int HELIX_FIFO_TOTAL_U64 = HELIX_FIFO_TOTAL_BYTES / sizeof(uint64_t);
4837
// Structure declarations and definitions
4938
// ============================================================================
5039

51-
struct HelixPairInfo
52-
{
53-
int senderRank;
54-
int receiverRank;
55-
int channel;
56-
int runChannelCount;
57-
};
58-
5940
// ALIGN_256 is defined in moeCommKernelsCommon.h
6041

61-
struct ALIGN_256 FifoInfo
42+
struct ALIGN_256 HelixFifoInfo
6243
{
6344
volatile int64_t head;
6445
volatile int64_t tail;
@@ -102,45 +83,15 @@ struct HelixAllToAllParams
10283
* @param smCount Number of SMs available (0 = auto-detect)
10384
* @return Number of channels to use
10485
*/
105-
inline int computeHelixMaxChannelCount(int cpSize, int smCount = 0)
106-
{
107-
if (smCount == 0)
108-
{
109-
int deviceId = 0;
110-
TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
111-
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
112-
}
113-
114-
int blockCountPerChannel = ceil_div(cpSize, MAX_GROUP_COUNT_PER_BLOCK);
115-
blockCountPerChannel *= 2; // for send and recv
116-
117-
int preferredChannel = smCount / blockCountPerChannel;
118-
return std::max(preferredChannel, 1); // at least one channel
119-
}
86+
int computeHelixMaxChannelCount(int cpSize, int smCount = 0);
12087

12188
/**
12289
* Compute the workspace size required per rank for the all-to-all operation.
12390
*
12491
* @param cpSize Number of context parallel ranks
12592
* @return Size in bytes
12693
*/
127-
inline size_t computeHelixWorkspaceSizePerRank(int cpSize)
128-
{
129-
static int maxChannelCount = 0;
130-
if (maxChannelCount == 0)
131-
{
132-
maxChannelCount = computeHelixMaxChannelCount(cpSize);
133-
}
134-
135-
// FIFO buffers: cpSize * channelCount pairs
136-
size_t fifoSize = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;
137-
138-
// Sender and receiver FIFO info structures
139-
size_t senderInfoSize = sizeof(FifoInfo) * cpSize * maxChannelCount;
140-
size_t receiverInfoSize = sizeof(FifoInfo) * cpSize * maxChannelCount;
141-
142-
return fifoSize + senderInfoSize + receiverInfoSize;
143-
}
94+
size_t computeHelixWorkspaceSizePerRank(int cpSize);
14495

14596
/**
14697
* Initialize workspace memory for a given rank.

cpp/tensorrt_llm/thop/alltoallOp.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,20 @@
1515
* limitations under the License.
1616
*/
1717

18-
#include "alltoallOp.h"
18+
// #include "tensorrt_llm/kernels/fusedMoeCommKernels.h"
19+
20+
// #include <c10/cuda/CUDAStream.h>
21+
// #include <torch/extension.h>
22+
// #include <vector>
23+
24+
25+
26+
#include "tensorrt_llm/kernels/helixAllToAll.h"
27+
#include "tensorrt_llm/kernels/fusedMoeCommKernels.h"
1928
#include "tensorrt_llm/common/opUtils.h"
2029
#include "tensorrt_llm/runtime/torchUtils.h"
2130
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
22-
#include "thUtils.h"
31+
#include "tensorrt_llm/thop/thUtils.h"
2332

2433
#include <NvInferRuntime.h>
2534
#include <c10/cuda/CUDAStream.h>

0 commit comments

Comments
 (0)