forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmoeUtils.cu
More file actions
460 lines (414 loc) · 20.2 KB
/
moeUtils.cu
File metadata and controls
460 lines (414 loc) · 20.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/cuteDslKernels/moeUtils.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh"
#include "tensorrt_llm/kernels/quantization.cuh"
#include "tensorrt_llm/kernels/quantization.h"
#include <cuda_fp4.h>
#include <cute/numeric/numeric_types.hpp>
namespace tensorrt_llm::kernels::cute_dsl
{
namespace
{
// One 128-bit vector: the unit used for bulk row copies.
using ElemCopyType = uint4;
// One 32-bit word: the unit used for scale-factor copies.
using SFCopyType = uint32_t;
using ActivationType = tensorrt_llm::kernels::cutlass_kernels::ActivationType;

// Bit width of a single T element. __nv_fp4_e2m1 is counted at its packed width of 4 bits; every other type
// uses cute::sizeof_bits_v.
template <typename T>
auto constexpr bitsPerElem()
{
#ifdef ENABLE_FP4
    constexpr bool kIsPackedFp4 = std::is_same_v<T, __nv_fp4_e2m1>;
#else
    constexpr bool kIsPackedFp4 = false;
#endif
    return kIsPackedFp4 ? 4 : cute::sizeof_bits_v<T>;
}

// Number of T elements that fit into one Word-sized copy.
template <typename Word, typename T>
auto constexpr elemsPerWord()
{
    return bitsPerElem<Word>() / bitsPerElem<T>();
}

// T elements moved per ElemCopyType (128-bit) copy.
template <typename T>
auto constexpr elemPerCopy()
{
    return elemsPerWord<ElemCopyType, T>();
}

// T scale-factor elements moved per SFCopyType (32-bit) copy.
template <typename T>
auto constexpr sfElemPerCopy()
{
    return elemsPerWord<SFCopyType, T>();
}
} // namespace
// Gathers token rows (and, for NVFP4, their scale factors) into permuted order.
//
// Launch: 1-D grid with kThreadsPerBlock threads per block. Block b handles permuted rows b, b + gridDim.x, ...
// below num_non_exiting_tiles[0] * tile_size. Its threads split each row into ElemCopyType (16-byte) chunks.
// Rows at or beyond tile_idx_to_mn_limit[row / tile_size] are skipped.
// The source token of a permuted row is permuted_idx_to_expanded_idx[row] / top_k.
// For NVFP4, input_sf is read with a linear per-token offset, while permuted_sf is written at the offsets returned
// by get_sf_out_offset_128x4 (swizzled layout).
// On SM90+, griddepcontrol.wait / launch_dependents bracket the work for programmatic dependent launch.
template <typename InputType, typename SFType, int32_t kSFVecSize, int32_t kThreadsPerBlock>
__global__ void moePermuteKernel(InputType const* input, InputType* permuted_output, SFType const* input_sf,
    SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
    int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
{
    constexpr int32_t kElemPerCopy = elemPerCopy<InputType>();
    constexpr int32_t kSFElemPerCopy = sfElemPerCopy<SFType>();
    // 64-bit so that (row index) * (copies per row) cannot overflow.
    int64_t const copiesPerRow = hidden_size / kElemPerCopy;
    auto const* const inVec = reinterpret_cast<ElemCopyType const*>(input);
    auto* const outVec = reinterpret_cast<ElemCopyType*>(permuted_output);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
#endif

    int32_t const rowCount = num_non_exiting_tiles[0] * tile_size;
    for (int32_t row = blockIdx.x; row < rowCount; row += gridDim.x)
    {
        bool const isPadding = row >= tile_idx_to_mn_limit[row / tile_size];
        if (isPadding)
        {
            continue;
        }

        int32_t const srcToken = permuted_idx_to_expanded_idx[row] / top_k;
        ElemCopyType const* const srcRow = inVec + srcToken * copiesPerRow;
        ElemCopyType* const dstRow = outVec + row * copiesPerRow;
        for (int32_t c = threadIdx.x; c < copiesPerRow; c += kThreadsPerBlock)
        {
            dstRow[c] = srcRow[c];
        }

#ifdef ENABLE_FP4
        if constexpr (std::is_same_v<InputType, __nv_fp4_e2m1>)
        {
            int32_t const sfCols = hidden_size / kSFVecSize;
            int64_t const sfCopiesPerRow = sfCols / kSFElemPerCopy;
            auto const* const sfIn = reinterpret_cast<SFCopyType const*>(input_sf);
            auto* const sfOut = reinterpret_cast<SFCopyType*>(permuted_sf);
            for (int32_t c = threadIdx.x; c < sfCopiesPerRow; c += kThreadsPerBlock)
            {
                // Source scale factors use a plain row-major offset; the destination offset is swizzled.
                int64_t const srcWord = srcToken * sfCopiesPerRow + c;
                int64_t const dstWord = get_sf_out_offset_128x4(/* batchIdx= */ std::nullopt, row,
                                            c * kSFElemPerCopy, /* numRows= */ std::nullopt, sfCols)
                    / kSFElemPerCopy;
                sfOut[dstWord] = sfIn[srcWord];
            }
        }
#endif
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Host launcher for moePermuteKernel.
//
// Gathers every valid permuted row from `input` into `permuted_output` on `stream`. For NVFP4 it also copies the
// matching scale factors from `input_sf` into the swizzled `permuted_sf`. The grid is
// min(SM count, max_num_permuted_tokens) blocks of 256 threads. The kernel reads the real row count from
// num_non_exiting_tiles on the device.
//
// Preconditions, enforced with TLLM_CHECK_WITH_INFO (which throws):
//   - hidden_size is divisible by the number of InputType elements per 16-byte copy;
//   - max_num_permuted_tokens >= 0;
//   - NVFP4 only: max_num_permuted_tokens % 128 == 0, hidden_size % 64 == 0, input_sf and permuted_sf non-null.
// max_num_permuted_tokens == 0 is a no-op, because launching a zero-block grid is an invalid launch configuration.
// An error returned by cudaLaunchKernelEx is reported by throwing instead of being dropped.
// NOTE(review): smCount is cached in a function-local static on the first call (presumably for the device that
// is current at that time).
template <typename InputType, typename SFType>
void moePermute(InputType const* input, InputType* permuted_output, SFType const* input_sf, SFType* permuted_sf,
    int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
    int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size,
    int32_t const top_k, int32_t const tile_size, cudaStream_t stream)
{
    int32_t constexpr kThreadsPerBlock = 256;
    int32_t constexpr kSFVecSize = 16;
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
    TLLM_CHECK_WITH_INFO(max_num_permuted_tokens >= 0, "max_num_permuted_tokens must be non-negative, got %d.",
        max_num_permuted_tokens);
#ifdef ENABLE_FP4
    if constexpr (std::is_same_v<InputType, __nv_fp4_e2m1>)
    {
        int32_t constexpr kSFMAlignment = 128;
        int32_t constexpr kSFKAlignment = 4;
        int32_t constexpr kSFElemPerCopy = sfElemPerCopy<SFType>();
        // Each 32-bit SF copy must cover exactly one K-group of the 128x4 swizzled layout.
        static_assert(kSFElemPerCopy == kSFKAlignment);
        TLLM_CHECK_WITH_INFO(max_num_permuted_tokens % kSFMAlignment == 0,
            "max_num_permuted_tokens must be divisible by %d.", kSFMAlignment);
        TLLM_CHECK_WITH_INFO(hidden_size % (kSFVecSize * kSFKAlignment) == 0, "hidden_size must be divisible by %d.",
            kSFVecSize * kSFKAlignment);
        TLLM_CHECK_WITH_INFO(input_sf != nullptr, "input_sf is required for NVFP4.");
        TLLM_CHECK_WITH_INFO(permuted_sf != nullptr, "permuted_sf is required for NVFP4.");
    }
#endif
    // Nothing to permute; a zero-block launch would fail with an invalid-configuration error.
    if (max_num_permuted_tokens == 0)
    {
        return;
    }

    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
    int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
    int32_t const threads = kThreadsPerBlock;
    auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;

    cudaLaunchConfig_t config;
    config.gridDim = blocks;
    config.blockDim = threads;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    // Allow programmatic dependent launch when enabled by the environment.
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaError_t const status = cudaLaunchKernelEx(&config, kernel, input, permuted_output, input_sf, permuted_sf,
        tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
    TLLM_CHECK_WITH_INFO(
        status == cudaSuccess, "Failed to launch moePermuteKernel: %s", cudaGetErrorString(status));
}
// Explicit instantiations of moePermute for the supported input dtypes. SFType is uint8_t in every case.
// The scale-factor arguments are only validated on the host and used by the kernel when InputType is
// __nv_fp4_e2m1.
#define INSTANTIATE_MOE_PERMUTE(InputType, SFType) \
    template void moePermute<InputType, SFType>(InputType const* input, InputType* permuted_output, \
        SFType const* input_sf, SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit, \
        int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles, \
        int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, \
        int32_t const tile_size, cudaStream_t stream)
INSTANTIATE_MOE_PERMUTE(half, uint8_t);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_PERMUTE(__nv_bfloat16, uint8_t);
#endif
#ifdef ENABLE_FP8
INSTANTIATE_MOE_PERMUTE(__nv_fp8_e4m3, uint8_t);
#endif
#ifdef ENABLE_FP4
INSTANTIATE_MOE_PERMUTE(__nv_fp4_e2m1, uint8_t);
#endif
#undef INSTANTIATE_MOE_PERMUTE
// Combines each token's top_k expert outputs back into one row.
//
// Launch: one block per token (blockIdx.x is the token index) with kThreadsPerBlock threads. Thread t handles
// 16-byte chunks t, t + kThreadsPerBlock, ... of the row.
// For slot s = token * top_k + k, the permuted row expanded_idx_to_permuted_idx[s] is weighted by topk_scales[s] and
// accumulated in float. Slots whose permuted index is negative contribute nothing. The sum is cast back to
// InputType and stored in `output`.
template <typename InputType, typename TopKScaleType, int32_t kThreadsPerBlock>
__global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* output,
    int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales, int32_t const hidden_size,
    int32_t const top_k)
{
    using AccumType = float;
    constexpr int32_t kElemPerCopy = elemPerCopy<InputType>();
    // 64-bit so that (row index) * (copies per row) cannot overflow.
    int64_t const copiesPerRow = hidden_size / kElemPerCopy;
    int32_t const token = blockIdx.x;
    int32_t const slotBase = token * top_k;
    auto const* const inVec = reinterpret_cast<ElemCopyType const*>(permuted_input);

    InputType chunk[kElemPerCopy];
    AccumType sum[kElemPerCopy];

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
#endif

    ElemCopyType* const outRow = reinterpret_cast<ElemCopyType*>(output) + token * copiesPerRow;
    for (int32_t c = threadIdx.x; c < copiesPerRow; c += kThreadsPerBlock)
    {
#pragma unroll
        for (int32_t e = 0; e < kElemPerCopy; ++e)
        {
            sum[e] = 0;
        }

        for (int32_t k = 0; k < top_k; ++k)
        {
            int32_t const slot = slotBase + k;
            int32_t const srcRow = expanded_idx_to_permuted_idx[slot];
            if (srcRow >= 0)
            {
                *reinterpret_cast<ElemCopyType*>(chunk) = inVec[srcRow * copiesPerRow + c];
                AccumType const weight = static_cast<AccumType>(topk_scales[slot]);
#pragma unroll
                for (int32_t e = 0; e < kElemPerCopy; ++e)
                {
                    sum[e] += static_cast<AccumType>(chunk[e]) * weight;
                }
            }
        }

#pragma unroll
        for (int32_t e = 0; e < kElemPerCopy; ++e)
        {
            chunk[e] = static_cast<InputType>(sum[e]);
        }
        outRow[c] = *reinterpret_cast<ElemCopyType*>(chunk);
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Host launcher for moeUnpermuteKernel: one block of 256 threads per token, launched on `stream`.
//
// For each of the num_tokens output rows, the kernel sums the token's top_k permuted rows weighted by topk_scales
// (float accumulation) and writes the result as InputType.
//
// Preconditions, enforced with TLLM_CHECK_WITH_INFO (which throws):
//   - hidden_size is divisible by the number of InputType elements per 16-byte copy;
//   - num_tokens >= 0.
// num_tokens == 0 is a no-op, because a zero-block grid is an invalid launch configuration. An error returned by
// cudaLaunchKernelEx is reported by throwing instead of being dropped.
template <typename InputType, typename TopKScaleType>
void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t const* expanded_idx_to_permuted_idx,
    TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
    cudaStream_t stream)
{
    int32_t constexpr kThreadsPerBlock = 256;
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
    TLLM_CHECK_WITH_INFO(num_tokens >= 0, "num_tokens must be non-negative, got %d.", num_tokens);
    // Nothing to combine; a zero-block launch would fail with an invalid-configuration error.
    if (num_tokens == 0)
    {
        return;
    }

    int32_t const blocks = num_tokens;
    int32_t const threads = kThreadsPerBlock;
    auto kernel = &moeUnpermuteKernel<InputType, TopKScaleType, kThreadsPerBlock>;

    cudaLaunchConfig_t config;
    config.gridDim = blocks;
    config.blockDim = threads;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    // Allow programmatic dependent launch when enabled by the environment.
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaError_t const status = cudaLaunchKernelEx(
        &config, kernel, permuted_input, output, expanded_idx_to_permuted_idx, topk_scales, hidden_size, top_k);
    TLLM_CHECK_WITH_INFO(
        status == cudaSuccess, "Failed to launch moeUnpermuteKernel: %s", cudaGetErrorString(status));
}
// Explicit instantiations of moeUnpermute. TopKScaleType is not spelled in the template-argument list; it is
// deduced from the type of the topk_scales parameter. Scales may be float or the same dtype as the activations.
#define INSTANTIATE_MOE_UNPERMUTE(InputType, TopKScaleType) \
    template void moeUnpermute<InputType>(InputType const* permuted_input, InputType* output, \
        int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales, int32_t const num_tokens, \
        int32_t const hidden_size, int32_t const top_k, cudaStream_t stream)
INSTANTIATE_MOE_UNPERMUTE(half, float);
INSTANTIATE_MOE_UNPERMUTE(half, half);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, float);
INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
#endif
#undef INSTANTIATE_MOE_UNPERMUTE
// Applies the MoE activation to every valid permuted row, optionally quantizing the result to NVFP4.
//
// Launch: 1-D grid with kThreadsPerBlock threads per block. Block b handles permuted rows b, b + gridDim.x, ...
// below num_non_exiting_tiles[0] * tile_size. Rows at or beyond tile_idx_to_mn_limit[row / tile_size] are skipped.
// Threads split each row into ElemCopyType (16-byte) chunks.
//
// The input row holds interm_size elements, or 2 * interm_size when ActFn::IS_GLU. For GLU, element j of the
// second half is passed to `act` as the first argument and element j of the first half as the second argument.
//
// When OutputType is __nv_fp4_e2m1:
//   - each chunk is converted by cvt_warp_fp16_to_fp4 using global_sf[0] and stored as one uint32_t;
//   - its scale factor is written into output_sf with the SWIZZLED layout.
// Otherwise the activated chunk is stored as InputType. A null global_sf is treated as a scale of 1.0f.
//
// NOTE(review): ActFn is default-constructed (`act{}`). Nothing from ActivationParams besides the selected
// activation type reaches this kernel. Confirm that SwigluBiasAdaptor needs no per-call parameters.
// NOTE(review): cvt_warp_fp16_to_fp4 is presumably warp-collective (name and kSFVecSize / kElemPerCopy lanes per
// scale factor). If so, warps that are only partially inside the `i < kCopyPerToken` loop (interm_size /
// kElemPerCopy not a multiple of 32) may be unsafe. Confirm against quantization.cuh.
template <typename InputType, typename OutputType, typename SFType, int32_t kSFVecSize, typename ActFn,
    int32_t kThreadsPerBlock>
__global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf,
    SFType* output_sf, int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
    int32_t const interm_size, int32_t const tile_size)
{
    using ComputeType = float;
    // For NVFP4 output, kElemPerCopy values (8 for 16-bit inputs) pack into one 32-bit word.
#ifdef ENABLE_FP4
    using ElemOutputCopyType = std::conditional_t<std::is_same_v<OutputType, __nv_fp4_e2m1>, uint32_t, ElemCopyType>;
#else
    using ElemOutputCopyType = ElemCopyType;
#endif
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    // Need int64_t to prevent overflow when computing pointer offsets.
    int64_t const kCopyPerToken = interm_size / kElemPerCopy;
    InputType rmem[kElemPerCopy];
    InputType rmemGate[kElemPerCopy];
    ActFn act{};
    // PDL (SM90+): wait for the preceding grid before reading its outputs.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.wait;");
#endif
    float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0];
    int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
    for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
    {
        int32_t const tile_idx = permuted_idx / tile_size;
        // Rows past the tile's valid limit are padding; the whole block skips them together.
        if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
        {
            continue;
        }
        // GLU input rows are twice as wide: [first half | second half].
        auto const* src_ptr
            = reinterpret_cast<ElemCopyType const*>(input) + permuted_idx * kCopyPerToken * (ActFn::IS_GLU ? 2 : 1);
        auto* dst_ptr = reinterpret_cast<ElemOutputCopyType*>(output) + permuted_idx * kCopyPerToken;
        for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
        {
            *reinterpret_cast<ElemCopyType*>(rmem) = src_ptr[i];
            if constexpr (ActFn::IS_GLU)
            {
                // The first operand of act comes from the second half of the row.
                *reinterpret_cast<ElemCopyType*>(rmemGate) = src_ptr[i + kCopyPerToken];
#pragma unroll
                for (int32_t j = 0; j < kElemPerCopy; j++)
                {
                    rmem[j] = static_cast<InputType>(
                        act(static_cast<ComputeType>(rmemGate[j]), static_cast<ComputeType>(rmem[j])));
                }
            }
            else
            {
#pragma unroll
                for (int32_t j = 0; j < kElemPerCopy; j++)
                {
                    rmem[j] = static_cast<InputType>(act(static_cast<ComputeType>(rmem[j])));
                }
            }
#ifdef ENABLE_FP4
            if constexpr (std::is_same_v<OutputType, __nv_fp4_e2m1>)
            {
                // Locate this chunk's scale-factor slot in the swizzled output_sf, then quantize the chunk.
                auto* sf_dst_ptr = cvt_quant_get_sf_out_offset<SFType, kSFVecSize / kElemPerCopy>(
                    /* batchIdx= */ std::nullopt, permuted_idx, i, /*numRows=*/std::nullopt, interm_size / kSFVecSize,
                    output_sf, QuantizationSFLayout::SWIZZLED);
                dst_ptr[i] = cvt_warp_fp16_to_fp4<InputType, kSFVecSize, false>(
                    *reinterpret_cast<PackedVec<InputType>*>(rmem), global_sf_val, sf_dst_ptr);
            }
            else
#endif
            {
                dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
            }
        }
    }
    // PDL (SM90+): allow dependent grids to start launching.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Host launcher for moeActivationKernel.
//
// Selects the kernel specialization for activation_params.activation_type and launches it on `stream`. The grid is
// min(SM count, max_num_permuted_tokens) blocks of 256 threads. The kernel reads the real row count from
// num_non_exiting_tiles on the device.
//
// Preconditions, enforced with TLLM_CHECK_WITH_INFO (which throws):
//   - interm_size is divisible by the number of InputType elements per 16-byte copy;
//   - max_num_permuted_tokens >= 0;
//   - NVFP4 output only: max_num_permuted_tokens % 128 == 0, interm_size % 64 == 0, global_sf and output_sf
//     non-null;
//   - the activation type is supported (Relu2 and unlisted types throw).
// max_num_permuted_tokens == 0 is a no-op after validation, because a zero-block grid is an invalid launch
// configuration. An error returned by cudaLaunchKernelEx is reported by throwing instead of being dropped.
template <typename InputType, typename OutputType, typename SFType>
void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
    int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
    cutlass_kernels::ActivationParams activation_params, int32_t const max_num_permuted_tokens,
    int32_t const interm_size, int32_t const tile_size, cudaStream_t stream)
{
    int32_t constexpr kThreadsPerBlock = 256;
    int32_t constexpr kSFVecSize = 16;
    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
    TLLM_CHECK_WITH_INFO(interm_size % kElemPerCopy == 0, "interm_size must be divisible by %d.", kElemPerCopy);
    TLLM_CHECK_WITH_INFO(max_num_permuted_tokens >= 0, "max_num_permuted_tokens must be non-negative, got %d.",
        max_num_permuted_tokens);
#ifdef ENABLE_FP4
    // NVFP4 is the OUTPUT format of this kernel (inputs are half/bf16), so the quantization requirements must be
    // keyed on OutputType. Keying them on InputType would make these checks dead code.
    if constexpr (std::is_same_v<OutputType, __nv_fp4_e2m1>)
    {
        int32_t constexpr kSFMAlignment = 128;
        int32_t constexpr kSFKAlignment = 4;
        TLLM_CHECK_WITH_INFO(max_num_permuted_tokens % kSFMAlignment == 0,
            "max_num_permuted_tokens must be divisible by %d.", kSFMAlignment);
        TLLM_CHECK_WITH_INFO(interm_size % (kSFVecSize * kSFKAlignment) == 0, "interm_size must be divisible by %d.",
            kSFVecSize * kSFKAlignment);
        TLLM_CHECK_WITH_INFO(global_sf != nullptr, "global_sf is required for NVFP4.");
        TLLM_CHECK_WITH_INFO(output_sf != nullptr, "output_sf is required for NVFP4.");
    }
#endif
    // Maps the runtime activation type to the matching kernel specialization; throws for unsupported types.
    auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
                                                                 float const* global_sf, SFType* output_sf,
                                                                 int32_t const* tile_idx_to_mn_limit,
                                                                 int32_t const* num_non_exiting_tiles,
                                                                 int32_t const interm_size, int32_t const tile_size)
    {
        switch (activation_type)
        {
        case ActivationType::Identity:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>;
        case ActivationType::Gelu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
        case ActivationType::Geglu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
        case ActivationType::Relu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>;
        case ActivationType::Silu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
        case ActivationType::Swiglu:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
        case ActivationType::SwigluBias:
            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
                kThreadsPerBlock>;
        case ActivationType::Relu2:
            // Unsupported activation type
            break;
        }
        TLLM_CHECK_WITH_INFO(false, "Unsupported activation type: %d", int(activation_type));
        return nullptr;
    };
    // Resolve (and validate) the activation even when there is no work, so bad types are always reported.
    auto kernel = get_act_kernel(activation_params.activation_type);

    // Nothing to activate; a zero-block launch would fail with an invalid-configuration error.
    if (max_num_permuted_tokens == 0)
    {
        return;
    }

    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
    int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
    int32_t const threads = kThreadsPerBlock;

    cudaLaunchConfig_t config;
    config.gridDim = blocks;
    config.blockDim = threads;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    // Allow programmatic dependent launch when enabled by the environment.
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    cudaError_t const status = cudaLaunchKernelEx(&config, kernel, input, output, global_sf, output_sf,
        tile_idx_to_mn_limit, num_non_exiting_tiles, interm_size, tile_size);
    TLLM_CHECK_WITH_INFO(
        status == cudaSuccess, "Failed to launch moeActivationKernel: %s", cudaGetErrorString(status));
}
// Explicit instantiations of moeActivation:
//   - same-dtype outputs for half (and bf16 when ENABLE_BF16);
//   - NVFP4-quantized outputs from half/bf16 inputs when ENABLE_FP4 is defined.
// SFType is uint8_t in every case.
#define INSTANTIATE_MOE_ACTIVATION(InputType, OutputType, SFType) \
    template void moeActivation<InputType, OutputType, SFType>(InputType const* input, OutputType* output, \
        float const* global_sf, SFType* output_sf, int32_t const* tile_idx_to_mn_limit, \
        int32_t const* num_non_exiting_tiles, cutlass_kernels::ActivationParams activation_params, \
        int32_t const max_num_permuted_tokens, int32_t const interm_size, int32_t const tile_size, \
        cudaStream_t stream)
INSTANTIATE_MOE_ACTIVATION(half, half, uint8_t);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_ACTIVATION(__nv_bfloat16, __nv_bfloat16, uint8_t);
#endif
#ifdef ENABLE_FP4
INSTANTIATE_MOE_ACTIVATION(half, __nv_fp4_e2m1, uint8_t);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_ACTIVATION(__nv_bfloat16, __nv_fp4_e2m1, uint8_t);
#endif
#endif
#undef INSTANTIATE_MOE_ACTIVATION
} // namespace tensorrt_llm::kernels::cute_dsl