Skip to content

Commit 68b9d9b

Browse files
authored
[CUDA] BF16 MoE and qMoE (microsoft#25572)
Add support of bfloat16 in MoE and qMoE cuda ops.
1 parent 866c7e3 commit 68b9d9b

File tree

14 files changed

+544
-391
lines changed

14 files changed

+544
-391
lines changed

docs/ContribOperators.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3129,8 +3129,8 @@ This version of the operator has been available since version 1 of the 'com.micr
31293129
#### Type Constraints
31303130

31313131
<dl>
3132-
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
3133-
<dd>Constrain input and output types to float or float16 tensors.</dd>
3132+
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
3133+
<dd>Constrain input and output types to float, float16 or bfloat16 tensors.</dd>
31343134
</dl>
31353135

31363136

@@ -4543,19 +4543,19 @@ This version of the operator has been available since version 1 of the 'com.micr
45434543
<dd>2D input tensor with shape (num_rows, num_experts)</dd>
45444544
<dt><tt>fc1_experts_weights</tt> : T1</dt>
45454545
<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).</dd>
4546-
<dt><tt>fc1_scales</tt> : T</dt>
4546+
<dt><tt>fc1_scales</tt> : T2</dt>
45474547
<dd>2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
45484548
<dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
45494549
<dd>2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu</dd>
45504550
<dt><tt>fc2_experts_weights</tt> : T1</dt>
45514551
<dd>3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
4552-
<dt><tt>fc2_scales</tt> : T</dt>
4552+
<dt><tt>fc2_scales</tt> : T2</dt>
45534553
<dd>2D input tensor with shape (num_experts, hidden_size)</dd>
45544554
<dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
45554555
<dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
45564556
<dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
45574557
<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
4558-
<dt><tt>fc3_scales</tt> (optional) : T</dt>
4558+
<dt><tt>fc3_scales</tt> (optional) : T2</dt>
45594559
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
45604560
<dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
45614561
<dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
@@ -4571,10 +4571,12 @@ This version of the operator has been available since version 1 of the 'com.micr
45714571
#### Type Constraints
45724572

45734573
<dl>
4574-
<dt><tt>T</tt> : tensor(float16)</dt>
4575-
<dd>Constrain input and output types to float or float16 tensors.</dd>
4574+
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16)</dt>
4575+
<dd>Constrain input and output types to float16 or bfloat16 tensors.</dd>
45764576
<dt><tt>T1</tt> : tensor(uint8)</dt>
45774577
<dd>Constrain weights type to uint8 tensors.</dd>
4578+
<dt><tt>T2</tt> : tensor(float), tensor(float16)</dt>
4579+
<dd>Constrain scales type to float or float16 tensors.</dd>
45784580
</dl>
45794581

45804582

docs/OperatorKernels.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -949,15 +949,15 @@ Do not modify directly.*
949949
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
950950
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
951951
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)|
952-
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
952+
|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
953953
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
954954
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
955955
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
956956
|PackedAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
957957
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
958958
|PagedAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* cumulative_sequence_length:**S**<br> *in* past_seqlens:**S**<br> *in* block_table:**S**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* key_cache_out:**T**<br> *out* value_cache_out:**T**|1+|**S** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
959959
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
960-
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float16)<br/> **T1** = tensor(uint8)|
960+
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float16)|
961961
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* attention_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
962962
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
963963
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|

onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop);
9292
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop);
9393
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE);
9494
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE);
95-
class CUDA_MS_OP_CLASS_NAME(1, QMoE);
95+
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE);
96+
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE);
97+
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE);
9698
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention);
9799
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention);
98100
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention);
@@ -307,7 +309,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
307309
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop)>,
308310
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE)>,
309311
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE)>,
310-
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, QMoE)>,
312+
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE)>,
313+
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE)>,
314+
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE)>,
311315
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention)>,
312316
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention)>,
313317
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention)>,
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#if defined(_MSC_VER)
5+
#pragma warning(push)
6+
#pragma warning(disable : 4100)
7+
#pragma warning(disable : 4244)
8+
#pragma warning(disable : 4200)
9+
#endif
10+
11+
#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"
12+
13+
#if defined(_MSC_VER)
14+
#pragma warning(pop)
15+
#endif
16+
17+
namespace ort_fastertransformer {
18+
template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>;
19+
} // namespace ort_fastertransformer
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#if defined(_MSC_VER)
5+
#pragma warning(push)
6+
#pragma warning(disable : 4100)
7+
#pragma warning(disable : 4244)
8+
#pragma warning(disable : 4200)
9+
#endif
10+
11+
#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"
12+
13+
#if defined(_MSC_VER)
14+
#pragma warning(pop)
15+
#endif
16+
17+
namespace ort_fastertransformer {
18+
template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>;
19+
} // namespace ort_fastertransformer
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#if defined(_MSC_VER)
5+
#pragma warning(push)
6+
#pragma warning(disable : 4100)
7+
#pragma warning(disable : 4244)
8+
#pragma warning(disable : 4200)
9+
#endif
10+
11+
#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h"
12+
13+
#if defined(_MSC_VER)
14+
#pragma warning(pop)
15+
#endif
16+
17+
namespace ort_fastertransformer {
18+
template class MoeGemmRunner<__nv_bfloat16, uint8_t>;
19+
} // namespace ort_fastertransformer

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
#include "cutlass_heuristic.h"
5454
#include "moe_gemm_kernels.h"
5555

56+
#include <cuda_bf16.h>
57+
5658
#include <limits>
5759
#include <math.h>
5860
#include <sstream>
@@ -66,8 +68,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w
6668
int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts,
6769
CutlassGemmConfig gemm_config, const int multi_processor_count,
6870
cudaStream_t stream, int* kernel_occupancy = nullptr) {
69-
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value,
70-
"Specialized for half, float");
71+
static_assert(cutlass::platform::is_same<T, half>::value || cutlass::platform::is_same<T, float>::value || cutlass::platform::is_same<T, __nv_bfloat16>::value,
72+
"Specialized for half, float, bfloat16");
7173

7274
static_assert(cutlass::platform::is_same<T, WeightType>::value ||
7375
cutlass::platform::is_same<WeightType, uint8_t>::value ||
@@ -76,12 +78,11 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w
7678

7779
// The cutlass type for the input elements. This is needed to convert half to cutlass::half_t and __nv_bfloat16 to cutlass::bfloat16_t; other types are used as-is.
7880
using ElementType_ =
79-
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, T>::type;
81+
typename cutlass::platform::conditional<cutlass::platform::is_same<T, half>::value, cutlass::half_t, typename cutlass::platform::conditional<cutlass::platform::is_same<T, __nv_bfloat16>::value, cutlass::bfloat16_t, T>::type>::type;
8082
using ElementType = ElementType_;
8183

8284
using CutlassWeightType_ =
83-
typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t,
84-
WeightType>::type;
85+
typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, half>::value, cutlass::half_t, typename cutlass::platform::conditional<cutlass::platform::is_same<WeightType, __nv_bfloat16>::value, cutlass::bfloat16_t, WeightType>::type>::type;
8586

8687
using CutlassWeightType = CutlassWeightType_;
8788

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,13 @@
3838

3939
#include "moe_kernel.h"
4040

41+
#include <cuda_runtime_api.h>
4142
#include <cub/cub.cuh>
4243
#include <cub/device/device_radix_sort.cuh>
4344
#include <cub/util_type.cuh>
4445

46+
#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
47+
4548
namespace ort_fastertransformer {
4649
static constexpr int WARP_SIZE = 32;
4750

@@ -103,11 +106,16 @@ void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows
103106
dim3 block(std::min(intermediate_size, 1024));
104107
dim3 grid(num_rows);
105108

109+
DUMP_TENSOR_INIT();
110+
DUMP_TENSOR("swiglu input", input, num_rows, 2 * intermediate_size);
111+
106112
if constexpr (interleaved) {
107113
swiglu_kernel_interleaved<T><<<grid, block, 0, stream>>>(output, input, intermediate_size, num_rows, swiglu_alpha);
108114
} else {
109115
swiglu_kernel_chunked<T><<<grid, block, 0, stream>>>(output, input, intermediate_size, num_rows, swiglu_alpha);
110116
}
117+
118+
DUMP_TENSOR("swiglu output", output, num_rows, intermediate_size);
111119
}
112120

113121
// ====================== Softmax things ===============================
@@ -838,11 +846,15 @@ void CutlassMoeFCRunner<T, WeightType, Enable>::configure_ws_ptrs(char* ws_ptr,
838846
}
839847

840848
namespace {
841-
842-
struct __align__(8) Half4 {
849+
typedef struct __CUDA_ALIGN__(8) {
843850
half2 x;
844851
half2 y;
845-
};
852+
} half2_2;
853+
854+
typedef struct __CUDA_ALIGN__(8) {
855+
__nv_bfloat162 x;
856+
__nv_bfloat162 y;
857+
} __nv_bfloat162_2;
846858

847859
// TODO(wy): move to common header
848860
template <typename T>
@@ -853,7 +865,11 @@ struct T4<float> {
853865
};
854866
template <>
855867
struct T4<half> {
856-
using Type = Half4;
868+
using Type = half2_2;
869+
};
870+
template <>
871+
struct T4<__nv_bfloat16> {
872+
using Type = __nv_bfloat162_2;
857873
};
858874

859875
template <typename T>
@@ -866,6 +882,10 @@ template <>
866882
struct T2<half> {
867883
using Type = half2;
868884
};
885+
template <>
886+
struct T2<__nv_bfloat16> {
887+
using Type = __nv_bfloat162;
888+
};
869889

870890
inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); }
871891

@@ -882,15 +902,27 @@ inline __device__ half2 operator*(const half2 a, const half2 b) { return make_ha
882902
#endif
883903

884904
// TODO(wy): use cuda common header and investigate pipeline build issue.
885-
inline __device__ Half4 operator*(const Half4 a, const Half4 b) {
905+
inline __device__ half2_2 operator*(const half2_2 a, const half2_2 b) {
886906
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \
887907
((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
888-
Half4 result;
908+
half2_2 result;
889909
result.x = a.x * b.x;
890910
result.y = a.y * b.y;
891911
return result;
892912
#else
893-
return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
913+
return half2_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
914+
#endif
915+
}
916+
917+
inline __device__ __nv_bfloat162_2 operator*(const __nv_bfloat162_2 a, const __nv_bfloat162_2 b) {
918+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && \
919+
((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2)))
920+
__nv_bfloat162_2 result;
921+
result.x = a.x * b.x;
922+
result.y = a.y * b.y;
923+
return result;
924+
#else
925+
return __nv_bfloat162_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)};
894926
#endif
895927
}
896928

@@ -1291,18 +1323,26 @@ template void topk_gating_softmax_kernelLauncher(const float*, const bool*, floa
12911323
int, bool, bool, cudaStream_t);
12921324
template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int,
12931325
int, bool, bool, cudaStream_t);
1326+
template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, const bool*, __nv_bfloat16*, __nv_bfloat16*, int*, int*, int, int,
1327+
int, bool, bool, cudaStream_t);
12941328

12951329
// ==================== Variable batched GEMM specializations ==================================
12961330
template class CutlassMoeFCRunner<float, float>;
12971331
template class CutlassMoeFCRunner<half, half>;
1332+
template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>;
1333+
// For qMoE:
12981334
template class CutlassMoeFCRunner<half, cutlass::uint4b_t>;
12991335
template class CutlassMoeFCRunner<half, uint8_t>;
1336+
template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>;
1337+
template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>;
13001338

13011339
// ===================== Specializations for init routing =========================
13021340
template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int,
13031341
cudaStream_t);
13041342
template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int,
13051343
cudaStream_t);
1344+
template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const int*, int*, int, int, int, int,
1345+
cudaStream_t);
13061346

13071347
// ==================== Specializations for final routing ===================================
13081348
template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*,
@@ -1317,6 +1357,8 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl
13171357
const float*, const int*, const int*, int, int, int, cudaStream_t);
13181358
template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*,
13191359
const half*, const int*, const int*, int, int, int, cudaStream_t);
1360+
template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*,
1361+
const __nv_bfloat16*, const int*, const int*, int, int, int, cudaStream_t);
13201362

13211363
template void invokeSwiGLU<float, true>(float*, float const*, int, int, float, cudaStream_t);
13221364
template void invokeSwiGLU<half, true>(half*, half const*, int, int, float, cudaStream_t);

onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,4 @@ class CutlassMoeFCRunner {
178178
std::vector<int64_t> total_rows_before_expert_host_;
179179
};
180180

181-
template <typename WeightType>
182-
class CutlassMoeFCRunner<float, WeightType, typename std::enable_if_t<!std::is_same<float, WeightType>::value>> {
183-
public:
184-
CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer);
185-
186-
size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) {
187-
return 0;
188-
}
189-
};
190-
191181
} // namespace ort_fastertransformer

0 commit comments

Comments
 (0)