@@ -36,6 +36,32 @@ namespace tensorrt_llm
3636{
3737namespace kernels
3838{
39+
40+ namespace
41+ {
42+ // ============================================================================
43+ // Helix-specific FIFO constants
44+ // Note: Helix uses 128KB FIFO entries vs 256KB in FusedMoe
45+ // ============================================================================
46+
47+ constexpr int HELIX_FIFO_DEPTH = 4 ;
48+ constexpr int HELIX_FIFO_ENTRY_BYTES = 128 * 1024 ;
49+ constexpr int HELIX_FIFO_TOTAL_BYTES = HELIX_FIFO_ENTRY_BYTES * HELIX_FIFO_DEPTH;
50+ constexpr int HELIX_FIFO_ENTRY_128B_COUNT = HELIX_FIFO_ENTRY_BYTES / BYTES_PER_128B_BLOCK;
51+ constexpr int HELIX_FIFO_TOTAL_U64 = HELIX_FIFO_TOTAL_BYTES / sizeof (uint64_t );
52+
53+ // ============================================================================
54+ // Implementation-only structures
55+ // ============================================================================
56+
// Identifies one sender/receiver pair on one channel.
// Member order is kept as-is so positional (aggregate) initialization is unaffected.
struct HelixPairInfo
{
    // Rank that produces data; the sender-side FIFO info is looked up in this rank's workspace.
    int senderRank;
    // Rank that consumes data; the receiver-side FIFO info is looked up in this rank's workspace.
    int receiverRank;
    // Channel index; selects the per-channel FIFO info slot for this pair.
    int channel;
    // NOTE(review): presumably the number of channels active for this launch - confirm against the kernel.
    int runChannelCount;
};
64+
3965// WARP_SIZE, WARP_MASK, and other constants are defined in moeCommKernelsCommon.h
4066
4167// ============================================================================
@@ -170,36 +196,36 @@ __device__ __forceinline__ uint64_t* getFifoBasePtr(HelixAllToAllParams const& p
170196 return mappedMemory + fifoOffset;
171197}
172198
173- __device__ __forceinline__ FifoInfo* getSenderFifoInfo (HelixAllToAllParams const & params, HelixPairInfo const & pairInfo)
199+ __device__ __forceinline__ HelixFifoInfo* getSenderHelixFifoInfo (HelixAllToAllParams const & params, HelixPairInfo const & pairInfo)
174200{
175- // SenderSideFifoInfo is physically located at sender rank
201+ // SenderSideHelixFifoInfo is physically located at sender rank
176202 int mappedMemoryRank = pairInfo.senderRank ;
177203 int rankInsideMappedMemory = pairInfo.receiverRank ;
178204
179205 auto * mappedMemory = reinterpret_cast <uint8_t *>(params.workspace + mappedMemoryRank * params.workspaceStrideInU64 );
180206 size_t fieldOffset = static_cast <size_t >(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount ;
181207 mappedMemory += fieldOffset;
182- mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof (FifoInfo );
183- mappedMemory += pairInfo.channel * sizeof (FifoInfo );
208+ mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof (HelixFifoInfo );
209+ mappedMemory += pairInfo.channel * sizeof (HelixFifoInfo );
184210
185- return reinterpret_cast <FifoInfo *>(mappedMemory);
211+ return reinterpret_cast <HelixFifoInfo *>(mappedMemory);
186212}
187213
188- __device__ __forceinline__ FifoInfo* getReceiverFifoInfo (
214+ __device__ __forceinline__ HelixFifoInfo* getReceiverHelixFifoInfo (
189215 HelixAllToAllParams const & params, HelixPairInfo const & pairInfo)
190216{
191- // ReceiverSideFifoInfo is physically located at receiver rank
217+ // ReceiverSideHelixFifoInfo is physically located at receiver rank
192218 int mappedMemoryRank = pairInfo.receiverRank ;
193219 int rankInsideMappedMemory = pairInfo.senderRank ;
194220
195221 auto * mappedMemory = reinterpret_cast <uint8_t *>(params.workspace + mappedMemoryRank * params.workspaceStrideInU64 );
196222 size_t fieldOffset = static_cast <size_t >(HELIX_FIFO_TOTAL_BYTES) * params.cpSize * params.maxChannelCount ;
197- fieldOffset += sizeof (FifoInfo ) * params.cpSize * params.maxChannelCount ;
223+ fieldOffset += sizeof (HelixFifoInfo ) * params.cpSize * params.maxChannelCount ;
198224 mappedMemory += fieldOffset;
199- mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof (FifoInfo );
200- mappedMemory += pairInfo.channel * sizeof (FifoInfo );
225+ mappedMemory += rankInsideMappedMemory * params.maxChannelCount * sizeof (HelixFifoInfo );
226+ mappedMemory += pairInfo.channel * sizeof (HelixFifoInfo );
201227
202- return reinterpret_cast <FifoInfo *>(mappedMemory);
228+ return reinterpret_cast <HelixFifoInfo *>(mappedMemory);
203229}
204230
205231__device__ __forceinline__ void startWorkspaceS2G (
@@ -315,8 +341,8 @@ __global__ void helixAllToAllKernel(HelixAllToAllParams params)
315341
316342 // Get FIFO pointers
317343 uint64_t * fifoBase = getFifoBasePtr (params, pairInfo);
318- FifoInfo * senderFifo = getSenderFifoInfo (params, pairInfo);
319- FifoInfo * receiverFifo = getReceiverFifoInfo (params, pairInfo);
344+ HelixFifoInfo * senderFifo = getSenderHelixFifoInfo (params, pairInfo);
345+ HelixFifoInfo * receiverFifo = getReceiverHelixFifoInfo (params, pairInfo);
320346
321347 int fifoEntry128ByteIndexBase = HELIX_FIFO_ENTRY_128B_COUNT;
322348 int fifoEntryIndex = -1 ;
@@ -572,6 +598,46 @@ void launchHelixAllToAllImpl(HelixAllToAllParams const& params, cudaStream_t str
572598 TLLM_CUDA_CHECK (cudaLaunchKernelEx (&config, kernel_instance, params));
573599}
574600
601+ } // anonymous namespace
602+
603+ // ============================================================================
604+ // Public API Functions
605+ // ============================================================================
606+
607+ int computeHelixMaxChannelCount (int cpSize, int smCount)
608+ {
609+ if (smCount == 0 )
610+ {
611+ int deviceId = 0 ;
612+ TLLM_CUDA_CHECK (cudaGetDevice (&deviceId));
613+ TLLM_CUDA_CHECK (cudaDeviceGetAttribute (&smCount, cudaDevAttrMultiProcessorCount, deviceId));
614+ }
615+
616+ int blockCountPerChannel = ceil_div (cpSize, MAX_GROUP_COUNT_PER_BLOCK);
617+ blockCountPerChannel *= 2 ; // for send and recv
618+
619+ int preferredChannel = smCount / blockCountPerChannel;
620+ return std::max (preferredChannel, 1 ); // at least one channel
621+ }
622+
623+ size_t computeHelixWorkspaceSizePerRank (int cpSize)
624+ {
625+ static int maxChannelCount = 0 ;
626+ if (maxChannelCount == 0 )
627+ {
628+ maxChannelCount = computeHelixMaxChannelCount (cpSize);
629+ }
630+
631+ // FIFO buffers: cpSize * channelCount pairs
632+ size_t fifoSize = static_cast <size_t >(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;
633+
634+ // Sender and receiver FIFO info structures
635+ size_t senderInfoSize = sizeof (HelixFifoInfo) * cpSize * maxChannelCount;
636+ size_t receiverInfoSize = sizeof (HelixFifoInfo) * cpSize * maxChannelCount;
637+
638+ return fifoSize + senderInfoSize + receiverInfoSize;
639+ }
640+
575641void launchHelixAllToAll (HelixAllToAllParams const & params, bool allowVariableField1, cudaStream_t stream)
576642{
577643 if (allowVariableField1)
@@ -593,8 +659,8 @@ void initializeHelixWorkspace(uint64_t* local_workspace_ptr, int cpSize, cudaStr
593659 int maxChannelCount = computeHelixMaxChannelCount (cpSize);
594660 // Calculate sizes with channel dimension
595661 size_t fifoSize = static_cast <size_t >(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;
596- size_t senderInfoSize = sizeof (FifoInfo ) * cpSize * maxChannelCount;
597- size_t receiverInfoSize = sizeof (FifoInfo ) * cpSize * maxChannelCount;
662+ size_t senderInfoSize = sizeof (HelixFifoInfo ) * cpSize * maxChannelCount;
663+ size_t receiverInfoSize = sizeof (HelixFifoInfo ) * cpSize * maxChannelCount;
598664
599665 // Initialize FIFO buffers to 0xFFFFFFFF (-1 for signed integer types)
600666 TLLM_CUDA_CHECK (cudaMemsetAsync (local_workspace_ptr, 0xFF , fifoSize, stream));
0 commit comments