
Commit 1653599

feat: Add MNNVL MoE A2A support (NVIDIA#3504)

Authored by Dongxu Yang

* add MNNVL memory mapping support
* add more MPI environment for trtllm-llmapi-launch
* add MoE communication and prepare kernels
* add MNNVL AlltoAll support for DeepSeekV3
* add output dump for throughput benchmark
* support dynamic kernel launch grid
* address review comments
* address review comments #2

Signed-off-by: Dongxu Yang <[email protected]>

1 parent 5794420 · commit 1653599

File tree: 14 files changed, +2,800 −37 lines

cpp/tensorrt_llm/kernels/moeCommKernels.cu

Lines changed: 769 additions & 0 deletions
(Large diffs are not rendered by default.)
Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <map>

#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm::kernels
{

#ifdef __CUDACC__
#define ALIGN_256 __align__(256)
#else
#define ALIGN_256 alignas(256)
#endif

struct ALIGN_256 MoeCommFifoConnInfo
{
    volatile uint64_t head; // write position
    volatile uint64_t tail; // read position
};

constexpr int WARP_SIZE = 32;
constexpr uint32_t WARP_MASK = 0xffffffff;

constexpr int RECV_FIFO_DEPTH = 8;
constexpr int RECV_FIFO_ENTRY_BYTES = 256 * 1024;
constexpr int RECV_FIFO_ENTRY_U64 = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t);
constexpr int RECV_FIFO_TOTAL_BYTES = RECV_FIFO_DEPTH * RECV_FIFO_ENTRY_BYTES;
constexpr int RECV_FIFO_TOTAL_U64 = RECV_FIFO_TOTAL_BYTES / sizeof(uint64_t);
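// Illustrative sanity checks (editorial addition, not in the commit): the FIFO
// constants above work out to a 256 KiB entry holding 32 Ki uint64 slots, and
// 8 such entries form a 2 MiB receive FIFO per (peer rank, channel) pair.
static_assert(RECV_FIFO_ENTRY_U64 == 32 * 1024, "256 KiB / 8 B = 32 Ki uint64");
static_assert(RECV_FIFO_TOTAL_BYTES == 2 * 1024 * 1024, "8 entries x 256 KiB = 2 MiB");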
class AllToAllChannelCommunicatorBase
{
public:
    static constexpr int GROUP_COUNT_PER_BLOCK = 8;
    static_assert(GROUP_COUNT_PER_BLOCK <= 8, "GROUP_COUNT_PER_BLOCK must be less than or equal to 8");
    static constexpr int WARP_PER_GROUP = 2;
    static constexpr int U64_DATA_REG_PER_THREAD = 8;
    // A packet is a warp-sized chunk of data that is sent or received in one go,
    // but may be split into multiple 64-bit registers, the number of which is U64_DATA_REG_PER_THREAD.
    static constexpr int PACKET_SIZE_IN_U64 = WARP_SIZE * U64_DATA_REG_PER_THREAD;
    static constexpr int PACKET_SIZE_IN_BYTES = PACKET_SIZE_IN_U64 * sizeof(uint64_t);
    static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 = (WARP_SIZE - 2) * U64_DATA_REG_PER_THREAD;
    static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET = DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 * sizeof(uint64_t);
    static constexpr int U64_ELT_COUNT_PER_PACKET = PACKET_SIZE_IN_BYTES / sizeof(uint64_t);

    static constexpr int PACKET_COUNT_PER_FIFO_ENTRY = RECV_FIFO_ENTRY_BYTES / PACKET_SIZE_IN_BYTES;

    static constexpr int GROUP_MAX_INDICE_COUNT
        = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t) / (WARP_SIZE * U64_DATA_REG_PER_THREAD);
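    // Illustrative sanity checks (editorial addition, not in the commit): a
    // packet is 32 lanes x 8 uint64 registers = 2 KiB, of which the payload is
    // (32 - 2) lanes x 8 uint64 = 1920 bytes; the remaining two lanes' worth of
    // registers is reserved for non-payload data per the formula above.
    static_assert(PACKET_SIZE_IN_BYTES == 2048, "32 lanes x 8 regs x 8 B");
    static_assert(DATA_PAYLOAD_SIZE_PER_PACKET == 1920, "30 lanes x 8 regs x 8 B");
    static_assert(PACKET_COUNT_PER_FIFO_ENTRY == 128, "256 KiB / 2 KiB per packet");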
    struct GroupSharedBuffer
    {
        int groupIndiceBuffer[GROUP_MAX_INDICE_COUNT];
        int groupStartIndice;
        int groupEndIndice;
    };

    static void setMaxUsableSmCount(int maxUsableSmCount)
    {
        TLLM_CHECK_WITH_INFO(AllToAllChannelCommunicatorBase::maxSmCountUsed == false,
            "setMaxUsableSmCount can be called only before it is used");
        int smCount = tensorrt_llm::common::getMultiProcessorCount();
        if (maxUsableSmCount > smCount)
        {
            TLLM_LOG_WARNING("setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead",
                maxUsableSmCount, smCount);
            maxUsableSmCount = smCount;
        }
        AllToAllChannelCommunicatorBase::maxSmCount = maxUsableSmCount;
    }

    static int getMaxUsableSmCount()
    {
        AllToAllChannelCommunicatorBase::maxSmCountUsed = true;
        if (AllToAllChannelCommunicatorBase::maxSmCount == -1)
        {
            int smCount = tensorrt_llm::common::getMultiProcessorCount();
            AllToAllChannelCommunicatorBase::maxSmCount = smCount;
        }
        return AllToAllChannelCommunicatorBase::maxSmCount;
    }

    static int computeMoeCommChannelCount(int epSize)
    {
        int smCount = getMaxUsableSmCount();
        int blockCountPerChannel = (epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK;
        blockCountPerChannel *= 2; // for send and recv
        TLLM_CHECK_WITH_INFO(
            blockCountPerChannel <= smCount, "GPU should support at least one channel, usableSmCount=%d", smCount);
        int preferredChannel = smCount / 2 / blockCountPerChannel; // use half of the SMs for communication
        int channelCount = std::max(preferredChannel, 1);          // at least one channel
        return channelCount;
    }
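    // Worked example (editorial addition, assuming a hypothetical 132-SM GPU
    // such as an H100-class part): with epSize = 8 and GROUP_COUNT_PER_BLOCK
    // = 8, blockCountPerChannel = ceil(8 / 8) * 2 = 2, so preferredChannel
    // = 132 / 2 / 2 = 33 and 33 channels are used.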
    static int getMoeCommChannelCount(int epSize)
    {
        static std::map<int, int> channelCountMap{};
        auto iter = channelCountMap.find(epSize);
        if (iter == channelCountMap.end())
        {
            auto channelCount = AllToAllChannelCommunicatorBase::computeMoeCommChannelCount(epSize);
            channelCountMap[epSize] = channelCount;
            return channelCount;
        }
        return iter->second;
    }

    static dim3 getLaunchBlockDim()
    {
        return dim3(WARP_SIZE * WARP_PER_GROUP, GROUP_COUNT_PER_BLOCK);
    }

    static dim3 getLaunchGridDim(int epSize)
    {
        int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
        return dim3((epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK, channelCount, 2);
    }
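    // Launch geometry (editorial note, not in the commit): each block is
    // dim3(WARP_SIZE * WARP_PER_GROUP, GROUP_COUNT_PER_BLOCK) = (64, 8),
    // i.e. 512 threads serving 8 peer-rank groups of 2 warps each; the grid is
    // (ceil(epSize / GROUP_COUNT_PER_BLOCK), channelCount, 2), where the z
    // dimension of size 2 presumably separates send blocks from recv blocks
    // (cf. the "for send and recv" factor in computeMoeCommChannelCount).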
protected:
    static int maxSmCount;
    static bool maxSmCountUsed;
};

inline size_t getMoeCommWorkspaceSize(int epSize)
{
    int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
    return RECV_FIFO_TOTAL_BYTES * epSize * channelCount + sizeof(MoeCommFifoConnInfo) * epSize * channelCount;
}
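// Worked example (editorial addition, reusing the illustrative channelCount of
// 33 from above): with epSize = 4, each rank's workspace is
// 2 MiB * 4 * 33 = 264 MiB of FIFO space plus 256 B * 4 * 33 = 33 KiB of
// MoeCommFifoConnInfo, since the 256-byte-aligned struct occupies
// sizeof(MoeCommFifoConnInfo) = 256 bytes.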
struct MoeEpWorldInfo
{
    int epSize;
    int epRank;
};

struct MoeExpertParallelInfo
{
    int expertCount = -1;
    int topK = 1;
};

struct SendRecvDataInfo
{
    int vectorSizeInU64;
    // pre-computed on the host side for the GPU kernel
    int dataPacketCountPerVector;
    int vectorCountPerFifoEntry;

    void ComputeDataPacketCountPerVector()
    {
        dataPacketCountPerVector
            = (vectorSizeInU64 * sizeof(uint64_t) + AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET - 1)
            / AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET;
    }

    void ComputeVectorCountPerFifoEntry()
    {
        ComputeDataPacketCountPerVector();
        vectorCountPerFifoEntry
            = AllToAllChannelCommunicatorBase::PACKET_COUNT_PER_FIFO_ENTRY / dataPacketCountPerVector;
    }

    void DoPreCompute()
    {
        ComputeDataPacketCountPerVector();
        ComputeVectorCountPerFifoEntry();
        assert(vectorCountPerFifoEntry <= AllToAllChannelCommunicatorBase::GROUP_MAX_INDICE_COUNT);
    }
};
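// Worked example (editorial addition): a hypothetical vector of 512 uint64
// (4 KiB) needs dataPacketCountPerVector = ceil(4096 / 1920) = 3 packets, so
// vectorCountPerFifoEntry = 128 / 3 = 42 vectors fit in one 256 KiB FIFO
// entry; 42 <= GROUP_MAX_INDICE_COUNT (128), so DoPreCompute's assert holds.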
// Struct holding the send/recv data pointer and its displacement information.
struct SendRecvDispls
{
    uint64_t* dataPtr;
    int const* rankCountCumSum;  // length = epSize
    int const* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if epRank > 0,
                                 // else rankCountCumSum[epRank]
    int vectorStrideInU64;

#ifdef __CUDACC__
    __inline__ __device__ int getCount(int rank) const
    {
        return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1];
    }

    __inline__ __device__ int getRankStart(int rank) const
    {
        return rank == 0 ? 0 : rankCountCumSum[rank - 1];
    }

    __inline__ __device__ int getRealVectorIndice(int globalVectorIndex) const
    {
        return rankLocalIndices[globalVectorIndex];
    }

    __inline__ __device__ uint64_t* getVectorDataPtr(int realVectorIndex) const
    {
        return dataPtr + realVectorIndex * vectorStrideInU64;
    }
#endif
};
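// Worked example (editorial addition): with epSize = 4 and
// rankCountCumSum = {3, 5, 9, 9}, getCount(2) = 9 - 5 = 4 vectors belong to
// rank 2, starting at getRankStart(2) = 5; rankLocalIndices[5..8] then map
// those positions back to rows of dataPtr, each row vectorStrideInU64 apart.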
struct MoeCommWorkspace
{
    uint64_t* workspacePtr;
    size_t rankStrideInU64;

#ifdef __CUDACC__
    __inline__ __device__ uint64_t* getFifoBasePtr(
        bool isSender, int epRank, int peerRank, int channel, int channelCount) const
    {
        // The FIFO itself lives on the receiver's side.
        if (isSender)
        {
            return workspacePtr + peerRank * rankStrideInU64 + (epRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
        }
        else
        {
            return workspacePtr + epRank * rankStrideInU64 + (peerRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
        }
    }

    __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(
        bool isSender, int epRank, int peerRank, int channel, int epSize, int channelCount) const
    {
        // The FIFO connection info lives on the sender's side.
        uint64_t* fifoInfoPtrU64 = workspacePtr + RECV_FIFO_TOTAL_U64 * channelCount * epSize;
        int strideIndice = isSender ? epRank : peerRank;
        int fifoInfoIndice = isSender ? peerRank : epRank;
        fifoInfoPtrU64 += strideIndice * rankStrideInU64;
        MoeCommFifoConnInfo* fifoInfoPtr = (MoeCommFifoConnInfo*) fifoInfoPtrU64;
        return fifoInfoPtr + fifoInfoIndice * channelCount + channel;
    }
#endif
};
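// Layout implied by the accessors above (editorial summary): each rank's slice
// of size rankStrideInU64 starts with epSize * channelCount receive FIFOs
// (RECV_FIFO_TOTAL_U64 uint64 each, indexed by sender rank and channel),
// followed by an array of epSize * channelCount MoeCommFifoConnInfo structs,
// matching the two terms of getMoeCommWorkspaceSize.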
void setMaxUsableSmCount(int smCount);

void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls,
    SendRecvDispls recvDispls, MoeCommWorkspace workspace, cudaStream_t stream);

void moeAllToAllPrepareIndices(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum,
    int* localGatheredIndices, // indices of gatheredTargetRankIds that have the local rank in their topK
    int* sendRankCountCumSum, int* sendRankLocalIndices, int* recvRankCountCumSum, int* recvRankLocalIndices,
    // the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum
    int* backwardRecvRankLocalIndices, cudaStream_t stream);

void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
    int localMaxTokenCount, int const* recvRankCountCumSum, int const* localGatherIndices, int const* gatheredExpertIds,
    float const* gatheredScales, int* localExpertIds, float* localScales, cudaStream_t stream);

} // namespace tensorrt_llm::kernels
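Taken together, the declarations above suggest a host-side call sequence along the following lines. This is a minimal illustrative sketch, not code from the commit: the include path, the helper name runMoeAllToAll, and the plain cudaMalloc (standing in for the MNNVL-shared mapping the commit adds elsewhere) are all assumptions, and the displacement buffers are presumed to come from moeAllToAllPrepareIndices.

#include <cuda_runtime.h>
#include "tensorrt_llm/kernels/moeCommKernels.h" // hypothetical path for the header above

using namespace tensorrt_llm::kernels;

void runMoeAllToAll(int epRank, int epSize, uint64_t* sendBuf, uint64_t* recvBuf,
    int const* sendCumSum, int const* sendIdx, int const* recvCumSum, int const* recvIdx,
    int vectorSizeInU64, cudaStream_t stream)
{
    // One workspace slice per rank; a real setup would map this so that peers
    // can write into it (e.g. via the MNNVL memory mapping this commit adds).
    size_t wsBytes = getMoeCommWorkspaceSize(epSize);
    uint64_t* wsPtr = nullptr;
    cudaMalloc(&wsPtr, wsBytes);      // placeholder for the real shared mapping
    cudaMemset(wsPtr, 0, wsBytes);    // assumption: FIFO head/tail start at zero

    // Assumption: all ranks' slices are laid out contiguously with this stride.
    MoeCommWorkspace workspace{wsPtr, wsBytes / sizeof(uint64_t)};

    SendRecvDataInfo dataInfo{};
    dataInfo.vectorSizeInU64 = vectorSizeInU64;
    dataInfo.DoPreCompute(); // fills dataPacketCountPerVector etc. on the host

    SendRecvDispls sendDispls{sendBuf, sendCumSum, sendIdx, vectorSizeInU64};
    SendRecvDispls recvDispls{recvBuf, recvCumSum, recvIdx, vectorSizeInU64};

    moeAllToAll(MoeEpWorldInfo{epSize, epRank}, dataInfo, sendDispls, recvDispls,
        workspace, stream);
}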

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -59,6 +59,7 @@ add_library(
   logitsBitmaskOp.cpp
   mambaConv1dOp.cpp
   moeOp.cpp
+  moeCommOp.cpp
   fp8BlockScaleMoe.cpp
   fp4BlockScaleMoe.cpp
   noAuxTcOp.cpp
