Skip to content

Commit f2ebaf2

Browse files
authored
[None][feat] TRT-LLM Gen MoE optimize DeepSeek Fp8 activation kernel (#9175)
Signed-off-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
1 parent 6dd2fcd commit f2ebaf2

File tree

2 files changed

+306
-44
lines changed

2 files changed

+306
-44
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.cu

Lines changed: 265 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,136 @@ __global__ void activationKernel(KernelParams params)
100100

101101
////////////////////////////////////////////////////////////////////////////////////////////////////
102102

103+
// Functor computing the component-wise maximum of two float4 values.
// Used as a custom reduction operator so four per-token maxima can be
// reduced through a single cub::BlockReduce pass.
struct Float4Max
{
    __device__ __forceinline__ float4 operator()(float4 const& a, float4 const& b) const
    {
        return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w));
    }
};
115+
116+
// Functor computing the component-wise maximum of two float2 values.
// Two-token counterpart of Float4Max for the block-wide max reduction.
struct Float2Max
{
    __device__ __forceinline__ float2 operator()(float2 const& a, float2 const& b) const
    {
        return make_float2(fmaxf(a.x, b.x), fmaxf(a.y, b.y));
    }
};
126+
127+
////////////////////////////////////////////////////////////////////////////////////////////////////
128+
129+
// Packs a float array into the matching CUDA vector type (float4 / float2 / plain float).
//
// Only the explicit specializations below are valid. The primary template must never be
// instantiated, so fail at compile time instead of silently returning a value-initialized
// vector (the original fallback behavior), which would mask a bad (VecType, size) pairing.
template <typename VecType, int size>
__device__ __forceinline__ VecType packedTypeFromArray(float data[size])
{
    // Dependent-false assert: only fires if this primary template is actually instantiated.
    static_assert(sizeof(VecType) == 0, "packedTypeFromArray: unsupported (VecType, size) combination");
    return {};
}

// 4-element pack: data[0..3] -> float4{x, y, z, w}.
template <>
__device__ __forceinline__ float4 packedTypeFromArray<float4, 4>(float data[4])
{
    float4 result;
    result.x = data[0];
    result.y = data[1];
    result.z = data[2];
    result.w = data[3];
    return result;
}

// 2-element pack: data[0..1] -> float2{x, y}.
template <>
__device__ __forceinline__ float2 packedTypeFromArray<float2, 2>(float data[2])
{
    float2 result;
    result.x = data[0];
    result.y = data[1];
    return result;
}

// 1-element pack: degenerate case, the "packed vector" is a scalar float.
template <>
__device__ __forceinline__ float packedTypeFromArray<float, 1>(float data[1])
{
    return data[0];
}
160+
161+
////////////////////////////////////////////////////////////////////////////////////////////////////
162+
163+
// Unpacks a CUDA vector type (float4 / float2 / plain float) into a cutlass::Array of floats.
//
// Only the explicit specializations below are valid. The primary template must never be
// instantiated, so fail at compile time instead of silently returning a default-constructed
// array (the original fallback behavior), which would mask a bad (PackedType, size) pairing.
template <typename PackedType, int size>
__device__ __forceinline__ cutlass::Array<float, size> arrayFromPackedType(PackedType data)
{
    // Dependent-false assert: only fires if this primary template is actually instantiated.
    static_assert(sizeof(PackedType) == 0, "arrayFromPackedType: unsupported (PackedType, size) combination");
    return cutlass::Array<float, size>{};
}

// float4 -> {x, y, z, w}.
template <>
__device__ __forceinline__ cutlass::Array<float, 4> arrayFromPackedType<float4, 4>(float4 data)
{
    return cutlass::Array<float, 4>{data.x, data.y, data.z, data.w};
}

// float2 -> {x, y}.
template <>
__device__ __forceinline__ cutlass::Array<float, 2> arrayFromPackedType<float2, 2>(float2 data)
{
    return cutlass::Array<float, 2>{data.x, data.y};
}

// float -> single-element array.
template <>
__device__ __forceinline__ cutlass::Array<float, 1> arrayFromPackedType<float, 1>(float data)
{
    return cutlass::Array<float, 1>{data};
}
186+
187+
////////////////////////////////////////////////////////////////////////////////////////////////////
188+
189+
// Compile-time launch configuration selected by the number of tokens each CTA processes.
// PackedType batches one float per token through the CUB block reduction; MaxOp is the
// matching element-wise maximum functor for that packed type.
template <int NUM_TOKENS_PER_CTA>
struct KernelTraits;

// Four tokens per CTA: lane-wise float4 maximum.
template <>
struct KernelTraits<4>
{
    using PackedType = float4;
    using MaxOp = Float4Max;
};

// Two tokens per CTA: lane-wise float2 maximum.
template <>
struct KernelTraits<2>
{
    using PackedType = float2;
    using MaxOp = Float2Max;
};

// One token per CTA: plain scalar maximum.
template <>
struct KernelTraits<1>
{
    using PackedType = float;
#if CUDA_VERSION >= 12090
    // libcu++ functional object, available from CUDA 12.9 onwards.
    using MaxOp = cuda::maximum<>;
#else
    using MaxOp = cub::Max;
#endif
};
216+
217+
////////////////////////////////////////////////////////////////////////////////////////////////////
218+
219+
// Threads per CTA for the DeepSeek FP8 activation kernel; also fixes the
// cub::BlockReduce block size used inside the kernel.
constexpr int DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA = 128;
220+
103221
template <typename KernelParams>
104222
__global__ void activationDeepSeekKernel(KernelParams params)
105223
{
106224
using Type = typename KernelParams::Type;
107-
using BlockReduce = cub::BlockReduce<float, 128>;
225+
int32_t constexpr NumTokensPerCta = KernelParams::NumTokensPerCta;
226+
using KernelTraits = KernelTraits<NumTokensPerCta>;
227+
using MaxOp = typename KernelTraits::MaxOp;
228+
using PackedType = typename KernelTraits::PackedType;
229+
using BlockReduce = cub::BlockReduce<PackedType, DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA>;
108230

109-
__shared__ float s_scaleOut;
110-
__shared__ typename BlockReduce::TempStorage temp_storage;
231+
__shared__ float s_scaleOutArr[NumTokensPerCta];
232+
__shared__ typename BlockReduce::TempStorage tempStorage;
111233

112234
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
113235
// immediately trigger the secondary kernel when using PDL, then wait on primary
@@ -117,55 +239,124 @@ __global__ void activationDeepSeekKernel(KernelParams params)
117239
cudaGridDependencySynchronize();
118240
}
119241
#endif
242+
243+
// The largest (finite) value that can be represented using E4m3.
244+
float constexpr E4m3MaxVal{448.f};
245+
246+
int const totalNumPaddedTokens = params.totalNumPaddedTokens[0];
120247
// Loop over tokens
121-
for (int tokenIdx = blockIdx.z; tokenIdx < params.numTokens; tokenIdx += gridDim.z)
248+
float scale1Arr[NumTokensPerCta];
249+
float scale2Arr[NumTokensPerCta];
250+
float dataX1Arr[NumTokensPerCta];
251+
float dataX2Arr[NumTokensPerCta];
252+
float outArr[NumTokensPerCta];
253+
float absOutArr[NumTokensPerCta];
254+
int permutedIdxArr[NumTokensPerCta];
255+
256+
// Loop over tokens
257+
for (int k = blockIdx.z; k < params.topK; k += gridDim.z)
122258
{
123-
// Look over experts per token
124-
for (int k = blockIdx.y; k < params.topK; k += gridDim.y)
259+
for (int tokenCtaIdx = blockIdx.y * NumTokensPerCta; tokenCtaIdx < params.numTokens;
260+
tokenCtaIdx += gridDim.y * NumTokensPerCta)
125261
{
126-
int const expandedIdx = tokenIdx * params.topK + k;
127-
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
128-
129-
// Needed for expert parallelism
130-
if (permutedIdx == -1)
131-
continue;
132-
133-
// Loop over hidden dim
134262
for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2;
135263
hiddenIdx += blockDim.x * gridDim.x)
136264
{
137-
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
138-
139-
int const totalNumPaddedTokens = params.totalNumPaddedTokens[0];
140-
141-
int const scale1_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
142-
int const scale2_idx
143-
= permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128));
144-
float const scale1 = params.inDqSfsPtr[scale1_idx];
145-
float const scale2 = params.inDqSfsPtr[scale2_idx];
265+
#pragma unroll
266+
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
267+
{
268+
scale1Arr[tokenInCtaIdx] = 0.0f;
269+
scale2Arr[tokenInCtaIdx] = 0.0f;
270+
dataX1Arr[tokenInCtaIdx] = 0.0f;
271+
dataX2Arr[tokenInCtaIdx] = 0.0f;
272+
outArr[tokenInCtaIdx] = 0.0f;
273+
absOutArr[tokenInCtaIdx] = 0.0f;
274+
}
275+
#pragma unroll
276+
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
277+
{
278+
int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
279+
if (tokenIdx >= params.numTokens)
280+
{
281+
break;
282+
}
146283

147-
float x1 = scale1 * (float) params.inPtr[baseIdx];
148-
float x2 = scale2 * (float) params.inPtr[baseIdx + params.innerDim / 2];
284+
int const expandedIdx = tokenIdx * params.topK + k;
285+
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
286+
permutedIdxArr[tokenInCtaIdx] = permutedIdx;
287+
if (permutedIdx == -1)
288+
{
289+
continue;
290+
}
291+
292+
// Process blocks for this CTA
293+
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
294+
295+
int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
296+
int const scale2Idx
297+
= permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128));
298+
299+
scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx];
300+
scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx];
301+
dataX1Arr[tokenInCtaIdx] = static_cast<float>(params.inPtr[baseIdx]);
302+
dataX2Arr[tokenInCtaIdx] = static_cast<float>(params.inPtr[baseIdx + params.innerDim / 2]);
303+
}
149304

150-
float act = silu(x2);
151-
float out = act * x1;
305+
#pragma unroll
306+
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
307+
{
308+
float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
309+
float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
310+
float act = silu(x2);
311+
float out = act * x1;
312+
outArr[tokenInCtaIdx] = out;
313+
absOutArr[tokenInCtaIdx] = fabsf(out);
314+
}
152315

153-
// The largest (finite) value that can be represented using E4m3.
154-
float constexpr E4m3MaxVal{448.f};
316+
auto absOutPacked = packedTypeFromArray<PackedType, NumTokensPerCta>(absOutArr);
317+
auto aMaxPacked = BlockReduce(tempStorage).Reduce(absOutPacked, MaxOp{});
318+
auto aMaxArr = arrayFromPackedType<PackedType, NumTokensPerCta>(aMaxPacked);
155319

156-
// Compute the absolute max
157-
float aMax = BlockReduce(temp_storage).Reduce(fabsf(out), cuda::maximum<>());
158-
if (threadIdx.x == 0)
320+
#pragma unroll
321+
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
159322
{
160-
s_scaleOut = aMax / E4m3MaxVal;
161-
int const scaleOut_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
162-
params.outDqSfsPtr[scaleOut_idx] = aMax / E4m3MaxVal;
323+
if (threadIdx.x == 0)
324+
{
325+
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
326+
if (tokenIdx >= params.numTokens)
327+
{
328+
break;
329+
}
330+
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
331+
if (permutedIdx == -1)
332+
{
333+
continue;
334+
}
335+
s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
336+
int const scaleOut_idx
337+
= permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
338+
params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
339+
}
163340
}
164341
__syncthreads();
165-
float const scaleOut = s_scaleOut;
166-
__syncthreads();
167-
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
168-
params.outPtr[outIdx] = (Type) (out / scaleOut);
342+
343+
#pragma unroll
344+
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
345+
{
346+
auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
347+
if (tokenIdx >= params.numTokens)
348+
{
349+
break;
350+
}
351+
int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
352+
if (permutedIdx == -1)
353+
{
354+
continue;
355+
}
356+
float const scaleOut = s_scaleOutArr[tokenInCtaIdx];
357+
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
358+
params.outPtr[outIdx] = static_cast<Type>(outArr[tokenInCtaIdx] / scaleOut);
359+
}
169360
}
170361
}
171362
}
@@ -185,17 +376,48 @@ void run(Data const& data, void* stream)
185376

186377
if (data.mUseDeepSeekFp8)
187378
{
188-
int const numThreads = 128;
189-
const dim3 grid(data.innerDim / 128, data.topK, std::min(8192, data.numTokens));
379+
constexpr int NUM_ELTS_PER_LOAD = 1;
380+
constexpr int NUM_ELTS_PER_SF = 128;
381+
382+
int device{-1};
383+
cudaGetDevice(&device);
384+
int numSms = 0;
385+
cudaDeviceGetAttribute(&numSms, cudaDevAttrMultiProcessorCount, device);
386+
387+
// Output dimension is innerDim / 2, and each scale block is 128 elements
388+
int const outputDim = data.innerDim / 2;
389+
int const numScaleBlocks = (outputDim + NUM_ELTS_PER_SF - 1) / NUM_ELTS_PER_SF;
390+
int const gridSizeX = (numScaleBlocks + NUM_ELTS_PER_LOAD - 1) / NUM_ELTS_PER_LOAD;
391+
392+
auto numCtas = gridSizeX * data.numTokens * data.topK;
393+
// FIXME: This is a heuristic based on a very short benchmark.
394+
int numTokensPerCta = 1;
395+
if (numCtas > numSms * 32)
396+
{
397+
numTokensPerCta = 4;
398+
}
399+
else if (numCtas > numSms * 4)
400+
{
401+
numTokensPerCta = 2;
402+
}
403+
else
404+
{
405+
numTokensPerCta = 1;
406+
}
407+
408+
int const gridSizeY = std::min(8192, (data.numTokens + numTokensPerCta - 1) / numTokensPerCta);
409+
410+
const dim3 grid(gridSizeX, gridSizeY, data.topK);
190411

191-
LAUNCH(data, activationDeepSeekKernel, grid, numThreads, 0, stream);
412+
LAUNCH_ACTIVATION(
413+
data, activationDeepSeekKernel, numTokensPerCta, grid, DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0, stream);
192414
}
193415
else
194416
{
195417
int const numThreads = 256;
196418
const dim3 grid(data.innerDim / 128, data.topK, std::min(8192, data.numTokens));
197419

198-
LAUNCH(data, activationKernel, grid, numThreads, 0, stream);
420+
LAUNCH_ACTIVATION(data, activationKernel, 1, grid, numThreads, 0, stream);
199421
}
200422
}
201423

0 commit comments

Comments
 (0)