@@ -100,14 +100,136 @@ __global__ void activationKernel(KernelParams params)
100100
101101// //////////////////////////////////////////////////////////////////////////////////////////////////
102102
103+ struct Float4Max
104+ {
105+ __device__ __forceinline__ float4 operator ()(float4 const & a, float4 const & b) const
106+ {
107+ float4 result;
108+ result.x = fmaxf (a.x , b.x );
109+ result.y = fmaxf (a.y , b.y );
110+ result.z = fmaxf (a.z , b.z );
111+ result.w = fmaxf (a.w , b.w );
112+ return result;
113+ }
114+ };
115+
116+ struct Float2Max
117+ {
118+ __device__ __forceinline__ float2 operator ()(float2 const & a, float2 const & b) const
119+ {
120+ float2 result;
121+ result.x = fmaxf (a.x , b.x );
122+ result.y = fmaxf (a.y , b.y );
123+ return result;
124+ }
125+ };
126+
127+ // //////////////////////////////////////////////////////////////////////////////////////////////////
128+
129+ template <typename VecType, int size>
130+ __device__ __forceinline__ VecType packedTypeFromArray (float data[size])
131+ {
132+ return {};
133+ }
134+
135+ template <>
136+ __device__ __forceinline__ float4 packedTypeFromArray<float4 , 4 >(float data[4 ])
137+ {
138+ float4 result;
139+ result.x = data[0 ];
140+ result.y = data[1 ];
141+ result.z = data[2 ];
142+ result.w = data[3 ];
143+ return result;
144+ }
145+
146+ template <>
147+ __device__ __forceinline__ float2 packedTypeFromArray<float2 , 2 >(float data[2 ])
148+ {
149+ float2 result;
150+ result.x = data[0 ];
151+ result.y = data[1 ];
152+ return result;
153+ }
154+
155+ template <>
156+ __device__ __forceinline__ float packedTypeFromArray<float , 1 >(float data[1 ])
157+ {
158+ return data[0 ];
159+ }
160+
161+ // //////////////////////////////////////////////////////////////////////////////////////////////////
162+
163+ template <typename PackedType, int size>
164+ __device__ __forceinline__ cutlass::Array<float , size> arrayFromPackedType (PackedType data)
165+ {
166+ return cutlass::Array<float , size>{};
167+ }
168+
169+ template <>
170+ __device__ __forceinline__ cutlass::Array<float , 4 > arrayFromPackedType<float4 , 4 >(float4 data)
171+ {
172+ return cutlass::Array<float , 4 >{data.x , data.y , data.z , data.w };
173+ }
174+
175+ template <>
176+ __device__ __forceinline__ cutlass::Array<float , 2 > arrayFromPackedType<float2 , 2 >(float2 data)
177+ {
178+ return cutlass::Array<float , 2 >{data.x , data.y };
179+ }
180+
181+ template <>
182+ __device__ __forceinline__ cutlass::Array<float , 1 > arrayFromPackedType<float , 1 >(float data)
183+ {
184+ return cutlass::Array<float , 1 >{data};
185+ }
186+
187+ // //////////////////////////////////////////////////////////////////////////////////////////////////
188+
189+ template <int NUM_TOKENS_PER_CTA>
190+ struct KernelTraits ;
191+
192+ template <>
193+ struct KernelTraits <4 >
194+ {
195+ using MaxOp = Float4Max;
196+ using PackedType = float4 ;
197+ };
198+
199+ template <>
200+ struct KernelTraits <2 >
201+ {
202+ using MaxOp = Float2Max;
203+ using PackedType = float2 ;
204+ };
205+
206+ template <>
207+ struct KernelTraits <1 >
208+ {
209+ #if CUDA_VERSION >= 12090
210+ using MaxOp = cuda::maximum<>;
211+ #else
212+ using MaxOp = cub::Max;
213+ #endif
214+ using PackedType = float ;
215+ };
216+
217+ // //////////////////////////////////////////////////////////////////////////////////////////////////
218+
219+ constexpr int DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA = 128 ;
220+
103221template <typename KernelParams>
104222__global__ void activationDeepSeekKernel (KernelParams params)
105223{
106224 using Type = typename KernelParams::Type;
107- using BlockReduce = cub::BlockReduce<float , 128 >;
225+ int32_t constexpr NumTokensPerCta = KernelParams::NumTokensPerCta;
226+ using KernelTraits = KernelTraits<NumTokensPerCta>;
227+ using MaxOp = typename KernelTraits::MaxOp;
228+ using PackedType = typename KernelTraits::PackedType;
229+ using BlockReduce = cub::BlockReduce<PackedType, DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA>;
108230
109- __shared__ float s_scaleOut ;
110- __shared__ typename BlockReduce::TempStorage temp_storage ;
231+ __shared__ float s_scaleOutArr[NumTokensPerCta] ;
232+ __shared__ typename BlockReduce::TempStorage tempStorage ;
111233
112234#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
113235 // immediately trigger the secondary kernel when using PDL, then wait on primary
@@ -117,55 +239,124 @@ __global__ void activationDeepSeekKernel(KernelParams params)
117239 cudaGridDependencySynchronize ();
118240 }
119241#endif
242+
243+ // The largest (finite) value that can be represented using E4m3.
244+ float constexpr E4m3MaxVal{448 .f };
245+
246+ int const totalNumPaddedTokens = params.totalNumPaddedTokens [0 ];
120247 // Loop over tokens
121- for (int tokenIdx = blockIdx .z ; tokenIdx < params.numTokens ; tokenIdx += gridDim .z )
248+ float scale1Arr[NumTokensPerCta];
249+ float scale2Arr[NumTokensPerCta];
250+ float dataX1Arr[NumTokensPerCta];
251+ float dataX2Arr[NumTokensPerCta];
252+ float outArr[NumTokensPerCta];
253+ float absOutArr[NumTokensPerCta];
254+ int permutedIdxArr[NumTokensPerCta];
255+
256+ // Loop over experts (topK) per token
257+ for (int k = blockIdx .z ; k < params.topK ; k += gridDim .z )
122258 {
123- // Look over experts per token
124- for ( int k = blockIdx . y ; k < params. topK ; k += gridDim .y )
259+ for ( int tokenCtaIdx = blockIdx . y * NumTokensPerCta; tokenCtaIdx < params. numTokens ;
260+ tokenCtaIdx += gridDim .y * NumTokensPerCta )
125261 {
126- int const expandedIdx = tokenIdx * params.topK + k;
127- int const permutedIdx = params.expandedIdxToPermutedIdx [expandedIdx];
128-
129- // Needed for expert parallelism
130- if (permutedIdx == -1 )
131- continue ;
132-
133- // Loop over hidden dim
134262 for (int hiddenIdx = threadIdx .x + blockDim .x * blockIdx .x ; hiddenIdx < params.innerDim / 2 ;
135263 hiddenIdx += blockDim .x * gridDim .x )
136264 {
137- int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
138-
139- int const totalNumPaddedTokens = params.totalNumPaddedTokens [0 ];
140-
141- int const scale1_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128 );
142- int const scale2_idx
143- = permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128 ) + (params.innerDim / 2 / 128 ));
144- float const scale1 = params.inDqSfsPtr [scale1_idx];
145- float const scale2 = params.inDqSfsPtr [scale2_idx];
265+ #pragma unroll
266+ for (int tokenInCtaIdx = 0 ; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
267+ {
268+ scale1Arr[tokenInCtaIdx] = 0 .0f ;
269+ scale2Arr[tokenInCtaIdx] = 0 .0f ;
270+ dataX1Arr[tokenInCtaIdx] = 0 .0f ;
271+ dataX2Arr[tokenInCtaIdx] = 0 .0f ;
272+ outArr[tokenInCtaIdx] = 0 .0f ;
273+ absOutArr[tokenInCtaIdx] = 0 .0f ;
274+ }
275+ #pragma unroll
276+ for (int tokenInCtaIdx = 0 ; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
277+ {
278+ int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
279+ if (tokenIdx >= params.numTokens )
280+ {
281+ break ;
282+ }
146283
147- float x1 = scale1 * (float ) params.inPtr [baseIdx];
148- float x2 = scale2 * (float ) params.inPtr [baseIdx + params.innerDim / 2 ];
284+ int const expandedIdx = tokenIdx * params.topK + k;
285+ int const permutedIdx = params.expandedIdxToPermutedIdx [expandedIdx];
286+ permutedIdxArr[tokenInCtaIdx] = permutedIdx;
287+ if (permutedIdx == -1 )
288+ {
289+ continue ;
290+ }
291+
292+ // Process blocks for this CTA
293+ int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
294+
295+ int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128 );
296+ int const scale2Idx
297+ = permutedIdx + totalNumPaddedTokens * ((hiddenIdx / 128 ) + (params.innerDim / 2 / 128 ));
298+
299+ scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr [scale1Idx];
300+ scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr [scale2Idx];
301+ dataX1Arr[tokenInCtaIdx] = static_cast <float >(params.inPtr [baseIdx]);
302+ dataX2Arr[tokenInCtaIdx] = static_cast <float >(params.inPtr [baseIdx + params.innerDim / 2 ]);
303+ }
149304
150- float act = silu (x2);
151- float out = act * x1;
305+ #pragma unroll
306+ for (int tokenInCtaIdx = 0 ; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
307+ {
308+ float x1 = scale1Arr[tokenInCtaIdx] * dataX1Arr[tokenInCtaIdx];
309+ float x2 = scale2Arr[tokenInCtaIdx] * dataX2Arr[tokenInCtaIdx];
310+ float act = silu (x2);
311+ float out = act * x1;
312+ outArr[tokenInCtaIdx] = out;
313+ absOutArr[tokenInCtaIdx] = fabsf (out);
314+ }
152315
153- // The largest (finite) value that can be represented using E4m3.
154- float constexpr E4m3MaxVal{448 .f };
316+ auto absOutPacked = packedTypeFromArray<PackedType, NumTokensPerCta>(absOutArr);
317+ auto aMaxPacked = BlockReduce (tempStorage).Reduce (absOutPacked, MaxOp{});
318+ auto aMaxArr = arrayFromPackedType<PackedType, NumTokensPerCta>(aMaxPacked);
155319
156- // Compute the absolute max
157- float aMax = BlockReduce (temp_storage).Reduce (fabsf (out), cuda::maximum<>());
158- if (threadIdx .x == 0 )
320+ #pragma unroll
321+ for (int tokenInCtaIdx = 0 ; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
159322 {
160- s_scaleOut = aMax / E4m3MaxVal;
161- int const scaleOut_idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128 );
162- params.outDqSfsPtr [scaleOut_idx] = aMax / E4m3MaxVal;
323+ if (threadIdx .x == 0 )
324+ {
325+ auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
326+ if (tokenIdx >= params.numTokens )
327+ {
328+ break ;
329+ }
330+ int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
331+ if (permutedIdx == -1 )
332+ {
333+ continue ;
334+ }
335+ s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
336+ int const scaleOut_idx
337+ = permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128 );
338+ params.outDqSfsPtr [scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
339+ }
163340 }
164341 __syncthreads ();
165- float const scaleOut = s_scaleOut;
166- __syncthreads ();
167- int const outIdx = permutedIdx * (params.innerDim / 2 ) + hiddenIdx;
168- params.outPtr [outIdx] = (Type) (out / scaleOut);
342+
343+ #pragma unroll
344+ for (int tokenInCtaIdx = 0 ; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++)
345+ {
346+ auto const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
347+ if (tokenIdx >= params.numTokens )
348+ {
349+ break ;
350+ }
351+ int const permutedIdx = permutedIdxArr[tokenInCtaIdx];
352+ if (permutedIdx == -1 )
353+ {
354+ continue ;
355+ }
356+ float const scaleOut = s_scaleOutArr[tokenInCtaIdx];
357+ int const outIdx = permutedIdx * (params.innerDim / 2 ) + hiddenIdx;
358+ params.outPtr [outIdx] = static_cast <Type>(outArr[tokenInCtaIdx] / scaleOut);
359+ }
169360 }
170361 }
171362 }
@@ -185,17 +376,48 @@ void run(Data const& data, void* stream)
185376
186377 if (data.mUseDeepSeekFp8 )
187378 {
188- int const numThreads = 128 ;
189- const dim3 grid (data.innerDim / 128 , data.topK , std::min (8192 , data.numTokens ));
379+ constexpr int NUM_ELTS_PER_LOAD = 1 ;
380+ constexpr int NUM_ELTS_PER_SF = 128 ;
381+
382+ int device{-1 };
383+ cudaGetDevice (&device);
384+ int numSms = 0 ;
385+ cudaDeviceGetAttribute (&numSms, cudaDevAttrMultiProcessorCount, device);
386+
387+ // Output dimension is innerDim / 2, and each scale block is 128 elements
388+ int const outputDim = data.innerDim / 2 ;
389+ int const numScaleBlocks = (outputDim + NUM_ELTS_PER_SF - 1 ) / NUM_ELTS_PER_SF;
390+ int const gridSizeX = (numScaleBlocks + NUM_ELTS_PER_LOAD - 1 ) / NUM_ELTS_PER_LOAD;
391+
392+ auto numCtas = gridSizeX * data.numTokens * data.topK ;
393+ // FIXME: This heuristic is based on a very short benchmark.
394+ int numTokensPerCta = 1 ;
395+ if (numCtas > numSms * 32 )
396+ {
397+ numTokensPerCta = 4 ;
398+ }
399+ else if (numCtas > numSms * 4 )
400+ {
401+ numTokensPerCta = 2 ;
402+ }
403+ else
404+ {
405+ numTokensPerCta = 1 ;
406+ }
407+
408+ int const gridSizeY = std::min (8192 , (data.numTokens + numTokensPerCta - 1 ) / numTokensPerCta);
409+
410+ const dim3 grid (gridSizeX, gridSizeY, data.topK );
190411
191- LAUNCH (data, activationDeepSeekKernel, grid, numThreads, 0 , stream);
412+ LAUNCH_ACTIVATION (
413+ data, activationDeepSeekKernel, numTokensPerCta, grid, DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0 , stream);
192414 }
193415 else
194416 {
195417 int const numThreads = 256 ;
196418 const dim3 grid (data.innerDim / 128 , data.topK , std::min (8192 , data.numTokens ));
197419
198- LAUNCH (data, activationKernel, grid, numThreads, 0 , stream);
420+ LAUNCH_ACTIVATION (data, activationKernel, 1 , grid, numThreads, 0 , stream);
199421 }
200422}
201423
0 commit comments