forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelixAllToAll.cu
More file actions
683 lines (578 loc) · 26.4 KB
/
helixAllToAll.cu
File metadata and controls
683 lines (578 loc) · 26.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/cudaAsyncOps.cuh"
#include "tensorrt_llm/kernels/helixAllToAll.h"
#include "tensorrt_llm/kernels/ll128Proto.cuh"
#include "tensorrt_llm/kernels/moeCommKernelsCommon.h"
#include <algorithm>
#include <tuple>
#include <unordered_map>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
namespace
{
// ============================================================================
// Structure declarations and definitions
// ============================================================================
// ALIGN_256 is defined in moeCommKernelsCommon.h
// Flow-control counters for one (peer, channel) FIFO.
// - `head` is advanced by sender warps after they finish filling a FIFO entry.
// - `tail` is advanced by receiver warps after they consume an entry; the receiver
//   writes the new value into both its own record and the sender's record, and the
//   sender polls its record to know when an entry slot becomes free.
// Both fields are volatile because another rank writes them through the mapped
// workspace while this rank is polling them.
// ALIGN_256 comes from moeCommKernelsCommon.h. NOTE(review): presumably it gives each
// record its own 256-byte slot. sizeof(HelixFifoInfo) is part of the workspace
// layout (see the get*HelixFifoInfo helpers), so do not change this struct lightly.
struct ALIGN_256 HelixFifoInfo
{
volatile int64_t head;
volatile int64_t tail;
};
// ============================================================================
// Helix-specific FIFO constants
// Note: Helix uses 128KB FIFO entries vs 256KB in FusedMoe
// ============================================================================
// Number of entries in each (peer, channel) FIFO ring; the sender may run at most
// this many entries ahead of the receiver.
constexpr int HELIX_FIFO_DEPTH = 4;
// Size of a single FIFO entry: 128 KiB.
constexpr int HELIX_FIFO_ENTRY_BYTES = 128 * 1024;
// Whole ring for one (peer, channel) pair, in bytes.
constexpr int HELIX_FIFO_TOTAL_BYTES = HELIX_FIFO_DEPTH * HELIX_FIFO_ENTRY_BYTES;
// One entry expressed in 128-byte blocks (the LL128 transfer unit).
constexpr int HELIX_FIFO_ENTRY_128B_COUNT = HELIX_FIFO_ENTRY_BYTES / BYTES_PER_128B_BLOCK;
// Whole ring expressed in uint64_t words (the workspace element type).
constexpr int HELIX_FIFO_TOTAL_U64 = HELIX_FIFO_TOTAL_BYTES / static_cast<int>(sizeof(uint64_t));
// ============================================================================
// Implementation-only structures
// ============================================================================
// Describes the directed (sender -> receiver) link that one warp of
// helixAllToAllKernel services.
struct HelixPairInfo
{
// Rank writing into the FIFO: cpRank for sender warps, the peer for receiver warps.
int senderRank;
// Rank reading from the FIFO; the FIFO storage lives in this rank's workspace.
int receiverRank;
// Channel index (blockIdx.y); selects one of maxChannelCount FIFOs per peer.
int channel;
// Number of channels in this launch (gridDim.y); entries are strided by this value.
int runChannelCount;
};
// WARP_SIZE, WARP_MASK, and other constants are defined in moeCommKernelsCommon.h
// ============================================================================
// Helper Functions
// ============================================================================
// Byte size of one entry of a field: elementCount elements of elementSize bytes each.
__host__ __device__ inline int getFieldSize(HelixFieldInfo const& fieldInfo)
{
    auto const count = fieldInfo.elementCount;
    auto const bytesPerElement = fieldInfo.elementSize;
    return count * bytesPerElement;
}
// Address of entry `dataIndex` of a field: dataPtr + dataIndex * stride, with
// stride in bytes because dataPtr is a byte pointer.
// The index is widened to int64_t before the multiply, so the byte offset cannot
// overflow 32-bit arithmetic on large tensors (stride's declared type lives in
// helixAllToAll.h). The parameter is named `dataIndex` rather than `blockIdx` so it
// does not shadow the CUDA built-in.
__host__ __device__ inline uint8_t* getPtr(HelixFieldInfo const& fieldInfo, int dataIndex)
{
    return fieldInfo.dataPtr + static_cast<int64_t>(dataIndex) * fieldInfo.stride;
}
// Waits until everything started by g2sAllFields has landed in shared memory.
// First it drains all committed cp.async groups (the fixed-size field-1 path uses
// ldgsts + cp_async_commit_group). Then it waits on the mbarrier that tracks the bulk
// copies. phaseParity is passed by pointer; presumably smemBarWait advances it for
// the next phase.
__device__ __forceinline__ void waitG2sAllFields(uint64_t* smemBar, uint32_t* phaseParity)
{
cp_async_wait_group<0>();
smemBarWait(smemBar, phaseParity);
}
// Rounds `size` up to a multiple of BYTES_PER_128B_BLOCK (the 128-byte LL128 block
// size) using align_up.
__host__ __device__ __forceinline__ int align128(int size)
{
return align_up(size, BYTES_PER_128B_BLOCK);
}
// ============================================================================
// G2S (Global to Shared) Operations
// ============================================================================
// Bulk-copies entry `dataIndex` of `fieldInfo` from global memory to
// shmemBase + shmemOffset. Completion is tracked by the mbarrier `smemBar`.
// Only lane 0 issues the copy, and fields of size 0 are skipped.
__device__ __forceinline__ void g2sField(
    HelixFieldInfo const& fieldInfo, int dataIndex, uint8_t* shmemBase, int shmemOffset, uint64_t* smemBar, int laneId)
{
    if (laneId != 0)
    {
        return;
    }
    int const bytes = getFieldSize(fieldInfo);
    if (bytes <= 0)
    {
        return;
    }
    cp_async_bulk_g2s(shmemBase + shmemOffset, getPtr(fieldInfo, dataIndex), bytes, smemBar);
}
// Starts loading both fields of entry `dataIndex` into shared memory, back to back:
// field 0 at offset 0 and field 1 directly after it.
//
// ALLOW_VARIABLE_FIELD1 == true : field 1 is bulk-copied on `smemBar`, like field 0.
// ALLOW_VARIABLE_FIELD1 == false: field 1 is treated as a single 8-byte value. Lane 0
//                                 loads it with an 8-byte ldgsts (cp.async), committed
//                                 as its own cp.async group.
//
// Returns the number of bytes tracked by `smemBar`, which the caller passes to
// mbarrier_arrive_expect_tx. The fixed-size field-1 copy is not included in it.
template <bool ALLOW_VARIABLE_FIELD1>
__device__ __forceinline__ int g2sAllFields(
    HelixFieldInfo const* fieldInfo, int dataIndex, uint8_t* shmemBase, uint64_t* smemBar, int laneId)
{
    int const field0Bytes = getFieldSize(fieldInfo[0]);
    g2sField(fieldInfo[0], dataIndex, shmemBase, 0, smemBar, laneId);

    if constexpr (!ALLOW_VARIABLE_FIELD1)
    {
        int* field1Dst = reinterpret_cast<int*>(shmemBase + field0Bytes);
        int const* field1Src = reinterpret_cast<int const*>(getPtr(fieldInfo[1], dataIndex));
        ldgsts<8>(field1Dst, field1Src, laneId == 0);
        cp_async_commit_group();
        return field0Bytes;
    }
    else
    {
        g2sField(fieldInfo[1], dataIndex, shmemBase, field0Bytes, smemBar, laneId);
        return field0Bytes + getFieldSize(fieldInfo[1]);
    }
}
// ============================================================================
// S2G (Shared to Global) Operations
// ============================================================================
// Bulk-copies shmemBase + shmemOffset back to entry `dataIndex` of `fieldInfo` in
// global memory. Only lane 0 issues the copy, and fields of size 0 are skipped.
// The caller commits the bulk-async group.
__device__ __forceinline__ void s2gField(
    HelixFieldInfo const& fieldInfo, int dataIndex, uint8_t* shmemBase, int shmemOffset, int laneId)
{
    if (laneId != 0)
    {
        return;
    }
    int const bytes = getFieldSize(fieldInfo);
    if (bytes <= 0)
    {
        return;
    }
    cp_async_bulk_s2g(getPtr(fieldInfo, dataIndex), shmemBase + shmemOffset, bytes);
}
// Writes both fields of entry `dataIndex` from shared memory (packed back to back:
// field 0 at offset 0, field 1 directly after it) into params' receive buffers.
// Field 0 always goes through a bulk copy. Field 1 is bulk-copied when
// ALLOW_VARIABLE_FIELD1; otherwise lane 0 stores it as a single float2.
// All lanes commit the bulk-async group at the end.
template <bool ALLOW_VARIABLE_FIELD1>
__device__ __forceinline__ void s2gAllFields(
    HelixFieldInfo const* fieldInfo, int dataIndex, uint8_t* shmemBase, int laneId)
{
    int const field1Offset = getFieldSize(fieldInfo[0]);
    s2gField(fieldInfo[0], dataIndex, shmemBase, 0, laneId);

    if constexpr (ALLOW_VARIABLE_FIELD1)
    {
        s2gField(fieldInfo[1], dataIndex, shmemBase, field1Offset, laneId);
    }
    else if (laneId == 0)
    {
        float2 const value = *reinterpret_cast<float2 const*>(shmemBase + field1Offset);
        *reinterpret_cast<float2*>(getPtr(fieldInfo[1], dataIndex)) = value;
    }
    cp_async_bulk_commit_group();
}
// ============================================================================
// Workspace FIFO Operations
// ============================================================================
// Returns the start of the FIFO ring used by link `pairInfo`.
// The ring lives in the receiver's workspace segment (segments are
// workspaceStrideInU64 words apart) and is indexed
// [senderRank][channel][HELIX_FIFO_TOTAL_U64 words].
// All offsets are computed in size_t. rank * workspaceStrideInU64 in particular can
// exceed 2^31 words for large cpSize * channel counts, so 32-bit math is not safe.
__device__ __forceinline__ uint64_t* getFifoBasePtr(HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
{
    // FIFO is physically located at receiver rank
    int const mappedMemoryRank = pairInfo.receiverRank;
    int const rankInsideMappedMemory = pairInfo.senderRank;
    auto* mappedMemory = params.workspace + static_cast<size_t>(mappedMemoryRank) * params.workspaceStrideInU64;
    // Navigate to the right FIFO: [peer_rank][channel]
    size_t const fifoIndex = static_cast<size_t>(rankInsideMappedMemory) * params.maxChannelCount + pairInfo.channel;
    return mappedMemory + fifoIndex * HELIX_FIFO_TOTAL_U64;
}
// Returns the sender-side HelixFifoInfo for link `pairInfo`.
// It lives in the sender's workspace segment, after the FIFO data region
// (cpSize * maxChannelCount rings of HELIX_FIFO_TOTAL_BYTES), indexed
// [receiverRank][channel]. Offsets are computed in size_t to avoid 32-bit overflow
// of rank * workspaceStrideInU64.
__device__ __forceinline__ HelixFifoInfo* getSenderHelixFifoInfo(
    HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
{
    // SenderSideHelixFifoInfo is physically located at sender rank
    int const mappedMemoryRank = pairInfo.senderRank;
    int const rankInsideMappedMemory = pairInfo.receiverRank;
    size_t const pairCount = static_cast<size_t>(params.cpSize) * params.maxChannelCount;
    auto* mappedMemory = reinterpret_cast<uint8_t*>(
        params.workspace + static_cast<size_t>(mappedMemoryRank) * params.workspaceStrideInU64);
    // Skip the FIFO data region
    mappedMemory += static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * pairCount;
    // Select [receiverRank][channel] inside the sender-info array
    size_t const slot = static_cast<size_t>(rankInsideMappedMemory) * params.maxChannelCount + pairInfo.channel;
    mappedMemory += slot * sizeof(HelixFifoInfo);
    return reinterpret_cast<HelixFifoInfo*>(mappedMemory);
}
// Returns the receiver-side HelixFifoInfo for link `pairInfo`.
// It lives in the receiver's workspace segment, after the FIFO data region and the
// sender-info array (each cpSize * maxChannelCount records), indexed
// [senderRank][channel]. Offsets are computed in size_t to avoid 32-bit overflow of
// rank * workspaceStrideInU64.
__device__ __forceinline__ HelixFifoInfo* getReceiverHelixFifoInfo(
    HelixAllToAllParams const& params, HelixPairInfo const& pairInfo)
{
    // ReceiverSideHelixFifoInfo is physically located at receiver rank
    int const mappedMemoryRank = pairInfo.receiverRank;
    int const rankInsideMappedMemory = pairInfo.senderRank;
    size_t const pairCount = static_cast<size_t>(params.cpSize) * params.maxChannelCount;
    auto* mappedMemory = reinterpret_cast<uint8_t*>(
        params.workspace + static_cast<size_t>(mappedMemoryRank) * params.workspaceStrideInU64);
    // Skip the FIFO data region and the sender-info array
    mappedMemory += (static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) + sizeof(HelixFifoInfo)) * pairCount;
    // Select [senderRank][channel] inside the receiver-info array
    size_t const slot = static_cast<size_t>(rankInsideMappedMemory) * params.maxChannelCount + pairInfo.channel;
    mappedMemory += slot * sizeof(HelixFifoInfo);
    return reinterpret_cast<HelixFifoInfo*>(mappedMemory);
}
// Bulk-copies `send128ByteCount` 128-byte blocks from shared memory into a FIFO
// entry, starting at block offset `fifo128ByteOffset`. Lane 0 issues the copy and
// every lane commits the bulk-async group.
// Note: helixAllToAllKernel in this file uses the register-path variant
// startWorkspaceS2GReg instead.
__device__ __forceinline__ void startWorkspaceS2G(
    uint64_t* fifoEntry, uint8_t* shmemBase, int send128ByteCount, int fifo128ByteOffset, int laneId)
{
    if (laneId == 0)
    {
        int const byteCount = send128ByteCount * BYTES_PER_128B_BLOCK;
        uint64_t* dst = fifoEntry + fifo128ByteOffset * BYTES_PER_128B_BLOCK / sizeof(uint64_t);
        cp_async_bulk_s2g(dst, shmemBase, byteCount);
    }
    cp_async_bulk_commit_group();
}
// Register-path copy of `send128ByteCount` 128-byte blocks from shared memory into a
// FIFO entry, starting at block offset `fifo128ByteOffset`. All lanes of the warp
// participate: lane i moves 16-byte (int4) chunks i, i + WARP_SIZE, ...
__device__ __forceinline__ void startWorkspaceS2GReg(
    uint64_t* fifoEntry, uint8_t* sharedMemoryBase, int send128ByteCount, int fifo128ByteOffset, int laneId)
{
    int4 const* src = reinterpret_cast<int4 const*>(sharedMemoryBase);
    int4* dst = reinterpret_cast<int4*>(fifoEntry + fifo128ByteOffset * UINT64_PER_128B_BLOCK);
    int const vecCount = send128ByteCount * BYTES_PER_128B_BLOCK / sizeof(int4);
#pragma unroll 4
    for (int idx = laneId; idx < vecCount; idx += WARP_SIZE)
    {
        dst[idx] = src[idx];
    }
}
// Bulk-copies the part of a transfer that has not yet been validated: blocks
// [loaded128ByteCount, allLoad128ByteCount) of the FIFO data starting at block
// `fifo128ByteOffset`, landing at the same block offset inside shmemBase.
// Lane 0 issues the copy. Every lane then calls mbarrier_arrive_expect_tx, and only
// lane 0 registers the expected byte count. Returns the value of that call.
__device__ __forceinline__ uint64_t startWorkspaceG2S(uint8_t* shmemBase, uint64_t* fifoEntry, int allLoad128ByteCount,
    int fifo128ByteOffset, int loaded128ByteCount, uint64_t* smemBar, int laneId)
{
    bool const isLeader = (laneId == 0);
    int const remainingBlocks = allLoad128ByteCount - loaded128ByteCount;
    int const copyByteCount = remainingBlocks * BYTES_PER_128B_BLOCK;
    if (isLeader)
    {
        uint8_t* dst = shmemBase + loaded128ByteCount * BYTES_PER_128B_BLOCK;
        uint64_t* src = fifoEntry + (fifo128ByteOffset + loaded128ByteCount) * UINT64_PER_128B_BLOCK;
        cp_async_bulk_g2s(dst, src, copyByteCount, smemBar);
    }
    return mbarrier_arrive_expect_tx(smemBar, isLeader ? copyByteCount : 0);
}
// LL128Proto is now defined in ll128Proto.cuh
// ============================================================================
// Size helpers
// ============================================================================
// Shared memory needed for one unpacked entry. Each field is rounded up to 16 bytes,
// then the total is rounded up to a whole 128-byte block. Field 1 is a single float2
// on the fixed-size path and arbitrary-sized when ALLOW_VARIABLE_FIELD1 is used.
__host__ __device__ __forceinline__ int computeTotalUnpackedSize(HelixFieldInfo const* fields)
{
    int const field0Bytes = align_up(getFieldSize(fields[0]), 16);
    int const field1Bytes = align_up(getFieldSize(fields[1]), 16);
    return align128(field0Bytes + field1Bytes);
}

// Packed size of one entry. Field 0 must be 16-byte aligned, so packing adds nothing
// over the unpacked layout and the two sizes coincide.
__host__ __device__ __forceinline__ int computeTotalPackedSize(HelixFieldInfo const* fields)
{
    return computeTotalUnpackedSize(fields);
}

// Bytes moved through the FIFO for one entry once LL128Proto's overhead is included.
// The kernel divides this by BYTES_PER_128B_BLOCK, so it is expected to be
// 128-byte aligned.
__host__ __device__ __forceinline__ int computeProtoTransferSize(HelixFieldInfo const* fields)
{
    int const packedBytes = computeTotalPackedSize(fields);
    return LL128Proto::computeProtoTransfer128ByteAlignedSize(packedBytes);
}
// ============================================================================
// Main All-to-All Kernel
// ============================================================================
// Helix all-to-all exchange over per-(peer, channel) FIFOs in the mapped workspace.
//
// Launch layout (see launchHelixAllToAllImpl):
//   blockDim  = (WARP_SIZE, groupCountPerCta): each warp ("group") serves one peer rank.
//   gridDim.x = ceil(cpSize / groupCountPerCta) CTAs, enough to cover all peers.
//   gridDim.y = channel count; entries are distributed round-robin across channels.
//   gridDim.z = 2: z == 0 are sender CTAs, z == 1 are receiver CTAs.
// Dynamic shared memory: groupCountPerCta * max(unpacked size, proto transfer size)
// bytes, one slice per warp. Static shared memory holds one mbarrier per warp.
//
// Sender warps load entry (entryIdx, peerRank) of params.sendFields into shared
// memory, wrap it with LL128Proto::protoPack (given the current `head`), and store it
// into the FIFO that lives in the receiver's workspace.
// Receiver warps repeatedly bulk-copy the FIFO data until
// LL128Proto::checkDataReceivedInShm (given the current `tail`) accepts all of it.
// They then unpack it and store it into params.recvFields.
// Flow control: a sender waits while `tail + HELIX_FIFO_DEPTH <= head`. A receiver
// publishes its `tail` into both its own and the sender's HelixFifoInfo.
// PDL (SM90+): every CTA calls cudaGridDependencySynchronize(). Sender CTAs trigger
// the dependent launch immediately; receiver CTAs trigger it on their last entry.
template <bool ALLOW_VARIABLE_FIELD1>
__global__ void helixAllToAllKernel(HelixAllToAllParams params)
{
extern __shared__ uint8_t allWarpShmem[];
__shared__ uint64_t allWarpSmemBar[MAX_GROUP_COUNT_PER_BLOCK];
bool isSender = (blockIdx.z == 0);
// Each warp is a group handling a different peer rank
// (lane 0's threadIdx.y is broadcast so the value is explicitly warp-uniform)
int group = __shfl_sync(WARP_MASK, threadIdx.y, 0);
int laneId = threadIdx.x % WARP_SIZE;
int runChannelCount = gridDim.y;
// Compute peer rank: blockIdx.x determines which set of peers, group
// determines which peer in that set
int peerRank = blockIdx.x * blockDim.y + group;
if (peerRank >= params.cpSize)
{
return;
}
// Setup pair info for this communication
HelixPairInfo pairInfo;
pairInfo.channel = blockIdx.y;
pairInfo.runChannelCount = runChannelCount;
pairInfo.senderRank = isSender ? params.cpRank : peerRank;
pairInfo.receiverRank = isSender ? peerRank : params.cpRank;
// Initialize barrier for this group
initSmemBar(&allWarpSmemBar[group], laneId);
uint32_t phaseParity = 0;
// Get shared memory for this group
int singlePackedSize = computeTotalPackedSize(params.sendFields);
int singlePacked128ByteCount = singlePackedSize / BYTES_PER_128B_BLOCK;
int singleUnpackedSize = computeTotalUnpackedSize(params.sendFields);
int singleProtoTransferSize = computeProtoTransferSize(params.sendFields);
int singleProtoTransfer128ByteCount = singleProtoTransferSize / BYTES_PER_128B_BLOCK;
int singleShmSize = std::max(singleUnpackedSize, singleProtoTransferSize);
uint8_t* shmem = allWarpShmem + group * singleShmSize;
// Get FIFO pointers
uint64_t* fifoBase = getFifoBasePtr(params, pairInfo);
HelixFifoInfo* senderFifo = getSenderHelixFifoInfo(params, pairInfo);
HelixFifoInfo* receiverFifo = getReceiverHelixFifoInfo(params, pairInfo);
// Start as if the current entry were already full, so the first iteration opens a
// fresh FIFO entry. fifoEntryIndex == -1 means "no entry opened yet".
int fifoEntry128ByteIndexBase = HELIX_FIFO_ENTRY_128B_COUNT;
int fifoEntryIndex = -1;
// regardless of sender or receiver, we wait for the previous kernel here
// receiver blocks do not need to wait at all, but they should not start
// to stress the memory system regardless
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (isSender)
{
// sender blocks should trigger next kernel immediately, s.t. they
// do not block the next kernel from starting
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Sender logic: send data from cpRank's slice to peerRank
int64_t head = senderFifo->head;
int64_t tail = senderFifo->tail;
// Each channel processes entries with stride
// Start at channel index, increment by total channel count
for (int entryIdx = pairInfo.channel; entryIdx < params.entryCount; entryIdx += runChannelCount)
{
// dataIndex points to the data for peerRank in this entry
int dataIndex = entryIdx * params.cpSize + peerRank;
// Load data from global to shared, then arrive on barrier
int loadedSize = g2sAllFields<ALLOW_VARIABLE_FIELD1>(
params.sendFields, dataIndex, shmem, &allWarpSmemBar[group], laneId);
uint64_t arriveState = mbarrier_arrive_expect_tx(&allWarpSmemBar[group], laneId == 0 ? loadedSize : 0);
// update FIFO entry index and head if needed
// (when this transfer does not fit in the current entry: publish the finished
// entry via head++, then spin until the receiver's tail shows that the ring slot
// for the new head is free, i.e. at most HELIX_FIFO_DEPTH entries in flight)
if (fifoEntry128ByteIndexBase + singleProtoTransfer128ByteCount > HELIX_FIFO_ENTRY_128B_COUNT)
{
if (fifoEntryIndex >= 0)
{
head++;
__syncwarp();
senderFifo->head = head;
}
fifoEntryIndex = head % HELIX_FIFO_DEPTH;
fifoEntry128ByteIndexBase = 0;
while (tail + HELIX_FIFO_DEPTH <= head)
{
tail = senderFifo->tail;
}
__syncwarp();
}
// wait for data to be loaded into shared memory
waitG2sAllFields(&allWarpSmemBar[group], &phaseParity);
// note: we don't need to pack anything, fields are already packed in
// shared memory
// protoPack works in place on shmem and is given `head`; NOTE(review): presumably
// head is the LL128 validity flag the receiver checks against its matching `tail`.
LL128Proto::protoPack(shmem, head, singlePacked128ByteCount, fifoEntry128ByteIndexBase, laneId);
uint64_t* fifoEntry = fifoBase + fifoEntryIndex * (HELIX_FIFO_ENTRY_BYTES / sizeof(uint64_t));
// Copy from shared to workspace FIFO
startWorkspaceS2GReg(fifoEntry, shmem, singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase, laneId);
fifoEntry128ByteIndexBase += singleProtoTransfer128ByteCount;
// ensure that we can over-write shmem in next iteration
// (it must be fully read by all threads when doing S2G above)
__syncwarp();
}
// Publish the last (possibly partially filled) entry.
// NOTE(review): fifoEntry128ByteIndexBase starts at HELIX_FIFO_ENTRY_128B_COUNT, so
// head also advances when this warp processed no entries. The receiver path advances
// tail the same way, which keeps them in step as long as both ranks use the same
// entryCount and channel count.
if (fifoEntry128ByteIndexBase > 0)
{
head++;
senderFifo->head = head;
}
}
else
{
// Receiver logic: receive data from peerRank to cpRank's slice
int64_t tail = receiverFifo->tail;
bool needRelease = false;
// Each channel processes entries with stride
// Start at channel index, increment by total channel count
for (int entryIdx = pairInfo.channel; entryIdx < params.entryCount; entryIdx += runChannelCount)
{
// receiver blocks should trigger next kernel at last iteration
// note: some blocks might not even go into this for-loop, but they
// would exit which is equivalent to the pre-exit trigger
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
if (entryIdx + runChannelCount >= params.entryCount)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// dataIndex points to where we receive data from peerRank in this entry
int dataIndex = entryIdx * params.cpSize + peerRank;
int loaded128ByteCount = 0;
// Move to the next FIFO entry when this transfer does not fit in the current one.
// Publishing the advanced tail (releasing the consumed entry) is deferred until
// the first read of the new entry has been issued below.
if (fifoEntry128ByteIndexBase + singleProtoTransfer128ByteCount > HELIX_FIFO_ENTRY_128B_COUNT)
{
if (fifoEntryIndex >= 0)
{
tail++;
needRelease = true;
}
fifoEntryIndex = tail % HELIX_FIFO_DEPTH;
fifoEntry128ByteIndexBase = 0;
// receiver doesn't need to wait on FIFO entry being readable: it's
// always readable
__syncwarp();
}
uint64_t* fifoEntry = fifoBase + fifoEntryIndex * (HELIX_FIFO_ENTRY_BYTES / sizeof(uint64_t));
// Re-read the not-yet-validated remainder until LL128Proto reports that the whole
// transfer has arrived.
while (loaded128ByteCount < singleProtoTransfer128ByteCount)
{
startWorkspaceG2S(shmem, fifoEntry, singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase,
loaded128ByteCount, &allWarpSmemBar[group], laneId);
if (needRelease)
{
receiverFifo->tail = tail;
senderFifo->tail = tail;
needRelease = false;
}
smemBarWait(&allWarpSmemBar[group], &phaseParity);
loaded128ByteCount += LL128Proto::template checkDataReceivedInShm<false>(shmem, tail,
singleProtoTransfer128ByteCount, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId);
}
LL128Proto::protoUnpack(
shmem, tail, singlePacked128ByteCount, fifoEntry128ByteIndexBase, loaded128ByteCount, laneId);
// note: fields are already unpacked in shared memory
s2gAllFields<ALLOW_VARIABLE_FIELD1>(params.recvFields, dataIndex, shmem, laneId);
// wait for data to be read from shared memory
cp_async_bulk_wait_group_read<0>();
// note: LL128Proto doesn't need rearm
// rearmFifoBuffer();
fifoEntry128ByteIndexBase += singleProtoTransfer128ByteCount;
}
// Release the last entry. This mirrors the sender's final head++ (including the
// zero-iteration case).
if (fifoEntry128ByteIndexBase > 0)
{
tail++;
receiverFifo->tail = tail;
senderFifo->tail = tail;
}
}
}
// ============================================================================
// Compute actual channel count
// ============================================================================
// Hash for the (deviceId, cpSize, singleShmSize) memoization key used by
// computeChannelAndGroupCount: the XOR of the three components.
struct hash_cache_key
{
    size_t operator()(std::tuple<int, int, int> const& x) const
    {
        auto const& [first, second, third] = x;
        return first ^ second ^ third;
    }
};
// Chooses the launch shape for helixAllToAllKernel<ALLOW_VARIABLE_FIELD1>.
//
// Returns (channelCount, groupCountPerCta, totalDynamicShmemSize):
//   groupCountPerCta      - warps (peers) per CTA: starts at
//                           min(cpSize, MAX_GROUP_COUNT_PER_BLOCK) and shrinks until the
//                           CTA's shared memory fits the device opt-in limit;
//   totalDynamicShmemSize - groupCountPerCta * max(unpacked size, proto transfer size);
//   channelCount          - max(1, smCount / (2 * ceil(cpSize / groupCountPerCta))).
// Side effect: raises cudaFuncAttributeMaxDynamicSharedMemorySize on the kernel when
// more than 48 KiB of dynamic shared memory is required.
// Results are memoized per (device, cpSize, per-warp shared memory size) for each
// template instantiation.
// Throws (TLLM_CHECK) if cpSize is not positive or if not even one warp's slice fits.
// The zero-group case is checked explicitly because it would otherwise reach
// ceil_div(cpSize, 0).
template <bool ALLOW_VARIABLE_FIELD1>
std::tuple<int, int, int> computeChannelAndGroupCount(int cpSize, HelixFieldInfo const* fields)
{
    TLLM_CHECK_WITH_INFO(cpSize > 0, "cpSize must be positive, got %d", cpSize);

    static std::unordered_map<std::tuple<int, int, int>, std::tuple<int, int, int>, hash_cache_key> cache;
    int deviceId = 0;
    TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
    int singleShmSize = std::max(computeTotalUnpackedSize(fields), computeProtoTransferSize(fields));
    auto key = std::make_tuple(deviceId, cpSize, singleShmSize);
    auto it = cache.find(key);
    if (it != cache.end())
    {
        return it->second;
    }

    auto* kernel = &helixAllToAllKernel<ALLOW_VARIABLE_FIELD1>;

    // The opt-in limit bounds static + dynamic shared memory together, so subtract
    // the kernel's static shared memory (the per-warp mbarrier array).
    int maxOptinShmSize = 0;
    TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&maxOptinShmSize, cudaDevAttrMaxSharedMemoryPerBlockOptin, deviceId));
    cudaFuncAttributes funcAttrs{};
    TLLM_CUDA_CHECK(cudaFuncGetAttributes(&funcAttrs, kernel));
    int maxDynamicShmSize = maxOptinShmSize - static_cast<int>(funcAttrs.sharedSizeBytes);

    // Shrink the number of warps per CTA until their shared memory slices fit.
    int groupCountPerCta = std::min(cpSize, MAX_GROUP_COUNT_PER_BLOCK);
    while (groupCountPerCta > 0 && singleShmSize * groupCountPerCta > maxDynamicShmSize)
    {
        --groupCountPerCta;
    }
    TLLM_CHECK_WITH_INFO(groupCountPerCta > 0, "Single packed size %d exceeds limit %d", singleShmSize,
        maxDynamicShmSize);
    int totalDynamicShmemSize = singleShmSize * groupCountPerCta;

    // Set shared memory attribute if needed
    if (totalDynamicShmemSize > 48 * 1024)
    {
        TLLM_CUDA_CHECK(
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, totalDynamicShmemSize));
    }

    int blockCountPerChannel = ceil_div(cpSize, groupCountPerCta);
    blockCountPerChannel *= 2; // for send and recv
    int smCount = 0;
    TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
    // TODO: we might only want to use half the SMs to overlap with other kernels.
    // note that overlap with FMHA is almost impossible because it must use
    // all SMs and probably uses >50% shmem per SM.
    // overlap with the subsequent BMM / out proj GEMMs might be possible,
    // so we need experiments to see whether it makes sense.
    int channelCount = std::max(smCount / blockCountPerChannel, 1);
    auto value = std::make_tuple(channelCount, groupCountPerCta, totalDynamicShmemSize);
    cache[key] = value;
    return value;
}
// ============================================================================
// Host Launch Function
// ============================================================================
// Launches helixAllToAllKernel<ALLOW_VARIABLE_FIELD1> on `stream`.
// Grid: x = CTAs needed to cover all peer ranks (groupCountPerCta warps each),
//       y = channel count (params.channelCount if positive, else the computed default),
//       z = 2 (sender CTAs, receiver CTAs).
// Programmatic dependent launch is requested when common::getEnvEnablePDL() is true.
// Throws if params.maxChannelCount disagrees with computeHelixMaxChannelCount(cpSize),
// or if an explicit channel count exceeds it.
template <bool ALLOW_VARIABLE_FIELD1>
void launchHelixAllToAllImpl(HelixAllToAllParams const& params, cudaStream_t stream)
{
    int const expectedMaxChannelCount = computeHelixMaxChannelCount(params.cpSize);
    TLLM_CHECK_WITH_INFO(params.maxChannelCount == expectedMaxChannelCount,
        "maxChannelCount %d does not match computed maxChannelCount %d", params.maxChannelCount,
        expectedMaxChannelCount);

    auto const launchShape = computeChannelAndGroupCount<ALLOW_VARIABLE_FIELD1>(params.cpSize, params.sendFields);
    int const groupCountPerCta = std::get<1>(launchShape);
    int const totalDynamicShmemSize = std::get<2>(launchShape);
    int channelCount = std::get<0>(launchShape);
    if (params.channelCount > 0)
    {
        channelCount = params.channelCount;
        TLLM_CHECK_WITH_INFO(channelCount <= expectedMaxChannelCount, "channelCount %d exceeds maxChannelCount %d",
            channelCount, expectedMaxChannelCount);
    }

    cudaLaunchAttribute launchAttrs[1];
    launchAttrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    launchAttrs[0].val.programmaticStreamSerializationAllowed = common::getEnvEnablePDL();

    cudaLaunchConfig_t config;
    config.gridDim = dim3(ceil_div(params.cpSize, groupCountPerCta), channelCount, 2);
    config.blockDim = dim3(WARP_SIZE, groupCountPerCta);
    config.dynamicSmemBytes = totalDynamicShmemSize;
    config.stream = stream;
    config.attrs = launchAttrs;
    config.numAttrs = 1;

    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &helixAllToAllKernel<ALLOW_VARIABLE_FIELD1>, params));
}
} // anonymous namespace
// ============================================================================
// Public API Functions
// ============================================================================
// Upper bound on the channel count for `cpSize` ranks. Each channel needs
// 2 * ceil(cpSize / MAX_GROUP_COUNT_PER_BLOCK) CTAs (senders + receivers); the result
// is how many such channel sets fit in the SM count, and at least 1.
// smCount == 0 means "query the current device".
// Throws if cpSize is not positive; otherwise the division below would be by zero.
int computeHelixMaxChannelCount(int cpSize, int smCount)
{
    TLLM_CHECK_WITH_INFO(cpSize > 0, "cpSize must be positive, got %d", cpSize);
    if (smCount == 0)
    {
        int deviceId = 0;
        TLLM_CUDA_CHECK(cudaGetDevice(&deviceId));
        TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, deviceId));
    }
    int blockCountPerChannel = ceil_div(cpSize, MAX_GROUP_COUNT_PER_BLOCK);
    blockCountPerChannel *= 2; // for send and recv
    int preferredChannel = smCount / blockCountPerChannel;
    return std::max(preferredChannel, 1); // at least one channel
}
// Bytes of workspace each rank must allocate for Helix all-to-all with `cpSize` ranks.
// Layout (shared with getFifoBasePtr, the get*HelixFifoInfo helpers and
// initializeHelixWorkspace):
//   [cpSize * maxChannelCount FIFO rings of HELIX_FIFO_TOTAL_BYTES]
//   [cpSize * maxChannelCount sender-side HelixFifoInfo]
//   [cpSize * maxChannelCount receiver-side HelixFifoInfo]
// maxChannelCount is recomputed on every call because it depends on cpSize and on
// the current device. A function-local cache would keep returning the first result,
// which can be too small for later calls with another cpSize or device, and would
// also be an unsynchronized write shared between threads.
size_t computeHelixWorkspaceSizePerRank(int cpSize)
{
    int const maxChannelCount = computeHelixMaxChannelCount(cpSize);
    // FIFO buffers: cpSize * channelCount pairs
    size_t fifoSize = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * cpSize * maxChannelCount;
    // Sender and receiver FIFO info structures
    size_t senderInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
    size_t receiverInfoSize = sizeof(HelixFifoInfo) * cpSize * maxChannelCount;
    return fifoSize + senderInfoSize + receiverInfoSize;
}
// Public entry point: selects the kernel instantiation for allowVariableField1
// (true: field 1 is variable-sized and bulk-copied; false: field 1 is a single
// float2) and launches it on `stream`.
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream)
{
    auto* const launchImpl
        = allowVariableField1 ? &launchHelixAllToAllImpl<true> : &launchHelixAllToAllImpl<false>;
    launchImpl(params, stream);
}
// ============================================================================
// Workspace Initialization
// ============================================================================
// Resets this rank's workspace, enqueued on `stream`. Every FIFO byte is set to 0xFF
// and all sender/receiver HelixFifoInfo records (head/tail) are zeroed. The region
// sizes follow the same layout as computeHelixWorkspaceSizePerRank.
void initializeHelixWorkspace(uint64_t* local_workspace_ptr, int cpSize, cudaStream_t stream)
{
    size_t const pairCount = static_cast<size_t>(cpSize) * computeHelixMaxChannelCount(cpSize);
    size_t const fifoBytes = static_cast<size_t>(HELIX_FIFO_TOTAL_BYTES) * pairCount;
    size_t const infoBytes = 2 * sizeof(HelixFifoInfo) * pairCount; // sender + receiver records

    auto* const workspaceBytes = reinterpret_cast<uint8_t*>(local_workspace_ptr);
    // FIFO payload: all bits set (reads back as -1 in signed integer types)
    TLLM_CUDA_CHECK(cudaMemsetAsync(workspaceBytes, 0xFF, fifoBytes, stream));
    // Head/tail bookkeeping for both sides: zero
    TLLM_CUDA_CHECK(cudaMemsetAsync(workspaceBytes + fifoBytes, 0, infoBytes, stream));
}
} // namespace kernels
TRTLLM_NAMESPACE_END