Commit 8301eea

Attention CUDA BFloat16 Support (#25974)
### Description

Attention BFloat16 support for CUDA: extends the kernel implementations to accept BF16 input/output tensors.

### Motivation and Context

We already have BFloat16 support for GQA (Group Query Attention), but not for the regular Attention operator, which many models (e.g. the visual encoder of Gemma 3) need for inference, since BF16 provides FP32-like numerical stability at lower memory and compute cost.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d530b29 commit 8301eea

23 files changed: +529 -34 lines changed
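As a quick usage sketch (not part of this commit): the snippet below shows how an application might feed bfloat16 inputs to a model that contains the com.microsoft Attention operator, running on the CUDA execution provider. The model path, input shape, and tensor names are placeholders; the calls are the public ONNX Runtime C++ API.

```cpp
// Hypothetical usage sketch: run a model with BF16 Attention on the CUDA EP.
// Model path, shape, and tensor names below are placeholders.
#include <onnxruntime_cxx_api.h>
#include <cstdint>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "bf16_attention_demo");
  Ort::SessionOptions options;
  OrtCUDAProviderOptions cuda_options{};               // device 0, default settings
  options.AppendExecutionProvider_CUDA(cuda_options);  // BF16 Attention runs on the CUDA EP
  Ort::Session session(env, "model_with_bf16_attention.onnx", options);

  // A [batch, sequence, hidden] input holding raw bfloat16 bit patterns.
  const std::vector<int64_t> shape{1, 128, 768};
  std::vector<uint16_t> data(1 * 128 * 768, 0);
  Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value input = Ort::Value::CreateTensor(
      mem_info, data.data(), data.size() * sizeof(uint16_t),
      shape.data(), shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16);

  const char* input_names[] = {"input"};    // placeholder names
  const char* output_names[] = {"output"};
  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, &input, 1, output_names, 1);
  return outputs.empty() ? 1 : 0;
}
```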

docs/ContribOperators.md
Lines changed: 1 addition & 1 deletion

@@ -199,7 +199,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Type Constraints
 
 <dl>
-<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
+<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
 <dd>Constrain input and output types to float tensors.</dd>
 <dt><tt>M</tt> : tensor(int32)</dt>
 <dd>Constrain mask index to integer types</dd>
docs/OperatorKernels.md
Lines changed: 1 addition & 1 deletion

@@ -940,7 +940,7 @@ Do not modify directly.*
 | |
 | |
 |**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* attention_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
+|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* attention_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
 |BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
Lines changed: 104 additions & 0 deletions

@@ -794,6 +794,39 @@ void LaunchAddBiasTranspose(
   }
 }
 
+template <>
+void LaunchAddBiasTranspose<BFloat16>(
+    cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
+    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
+    const BFloat16* input, const BFloat16* biases, BFloat16* output,
+    bool /*enable_half4*/, const int v_head_size,
+    BFloat16* qkv_add_bias, int total_matrix_count,
+    bool do_rotary, int rotary_embedding, int past_sequence_length) {
+  total_matrix_count = std::max(num_matrices, total_matrix_count);
+  if (0 == (qk_head_size & 1) && (v_head_size == -1 || 0 == (v_head_size & 1)) && !do_rotary) {
+    const int H = qk_head_size / 2;
+    const int H_v = v_head_size / 2;
+
+    const __nv_bfloat162* input2 = reinterpret_cast<const __nv_bfloat162*>(input);
+    const __nv_bfloat162* biases2 = reinterpret_cast<const __nv_bfloat162*>(biases);
+    __nv_bfloat162* output2 = reinterpret_cast<__nv_bfloat162*>(output);
+    __nv_bfloat162* qkv_add_bias2 = reinterpret_cast<__nv_bfloat162*>(qkv_add_bias);
+
+    InvokeAddBiasTranspose<__nv_bfloat162>(
+        stream, num_matrices, format, max_threads_per_block,
+        batch_size, sequence_length, num_heads, H,
+        input2, biases2, output2, qkv_add_bias2,
+        H_v, total_matrix_count);
+  } else {
+    InvokeAddBiasTranspose<BFloat16>(
+        stream, num_matrices, format, max_threads_per_block,
+        batch_size, sequence_length, num_heads, qk_head_size,
+        input, biases, output,
+        qkv_add_bias, v_head_size, total_matrix_count,
+        do_rotary, rotary_embedding, past_sequence_length);
+  }
+}
+
 template <>
 void LaunchAddBiasTranspose(
     cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,

@@ -888,6 +921,20 @@ void LaunchAddBiasTransposeTrt(
   ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input.");
 }
 
+template <>
+void LaunchAddBiasTransposeTrt<BFloat16>(
+    cudaStream_t /*stream*/, const int /*max_threads_per_block*/,
+    const int /*batch_size*/, const int /*sequence_length*/,
+    const int /*num_heads*/, const int /*head_size*/,
+    const BFloat16* /*biases*/,
+    const BFloat16* /*query*/,
+    const BFloat16* /*key*/,
+    const BFloat16* /*value*/,
+    BFloat16* /*output*/,
+    bool /*is_cross_attention*/, int /*kv_sequence_length*/) {
+  ORT_ENFORCE(false, "BF16 not supported for LaunchAddBiasTransposeTrt.");
+}
+
 template <>
 void LaunchAddBiasTransposeTrt(
     cudaStream_t stream, const int max_threads_per_block,

@@ -1049,6 +1096,38 @@ void LaunchAddBias(
   }
 }
 
+template <>
+void LaunchAddBias<BFloat16>(
+    cudaStream_t stream, const int max_threads_per_block,
+    const int batch_size, const int sequence_length, const int kv_sequence_length,
+    const int num_heads, const int head_size, const int v_head_size,
+    const BFloat16* biases, const BFloat16* query, const BFloat16* key, const BFloat16* value,
+    BFloat16* q, BFloat16* k, BFloat16* v) {
+  if (0 == (head_size & 1) && 0 == (v_head_size & 1)) {
+    const int H = head_size / 2;
+    const int H_v = v_head_size / 2;
+    const __nv_bfloat162* query2 = reinterpret_cast<const __nv_bfloat162*>(query);
+    const __nv_bfloat162* key2 = reinterpret_cast<const __nv_bfloat162*>(key);
+    const __nv_bfloat162* value2 = reinterpret_cast<const __nv_bfloat162*>(value);
+    const __nv_bfloat162* biases2 = reinterpret_cast<const __nv_bfloat162*>(biases);
+    __nv_bfloat162* q2 = reinterpret_cast<__nv_bfloat162*>(q);
+    __nv_bfloat162* k2 = reinterpret_cast<__nv_bfloat162*>(k);
+    __nv_bfloat162* v2 = reinterpret_cast<__nv_bfloat162*>(v);
+
+    InvokeAddBias<__nv_bfloat162>(
+        stream, max_threads_per_block,
+        batch_size, sequence_length, kv_sequence_length, num_heads, H, H_v,
+        biases2, query2, key2, value2, q2, k2, v2);
+
+  } else {
+    InvokeAddBias<BFloat16>(
+        stream, max_threads_per_block,
+        batch_size, sequence_length, kv_sequence_length, num_heads,
+        head_size, v_head_size,
+        biases, query, key, value, q, k, v);
+  }
+}
+
 template <typename T>
 void InvokeAddBias(
     cudaStream_t stream, const int max_threads_per_block,

@@ -1125,6 +1204,31 @@ void LaunchAddBias(
   }
 }
 
+template <>
+void LaunchAddBias<BFloat16>(
+    cudaStream_t stream, const int max_threads_per_block,
+    const int batch_size, const int sequence_length,
+    const int num_heads, const int head_size,
+    const BFloat16* biases, const BFloat16* query, BFloat16* q) {
+  if (0 == (head_size & 1)) {
+    const int H = head_size / 2;
+    const __nv_bfloat162* query2 = reinterpret_cast<const __nv_bfloat162*>(query);
+    const __nv_bfloat162* biases2 = reinterpret_cast<const __nv_bfloat162*>(biases);
+    __nv_bfloat162* q2 = reinterpret_cast<__nv_bfloat162*>(q);
+
+    InvokeAddBias<__nv_bfloat162>(
+        stream, max_threads_per_block,
+        batch_size, sequence_length, num_heads, H,
+        biases2, query2, q2);
+
+  } else {
+    InvokeAddBias<BFloat16>(
+        stream, max_threads_per_block,
+        batch_size, sequence_length, num_heads, head_size,
+        biases, query, q);
+  }
+}
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
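The specializations above all follow the same pattern: when the head size is even, the BFloat16 buffers are reinterpreted as packed __nv_bfloat162 pairs so each thread handles two values at once; otherwise the scalar path is used. Below is a minimal, standalone sketch of that idea (illustrative kernel and launcher names, not the InvokeAddBiasTranspose / InvokeAddBias kernels from this file).

```cuda
// Minimal sketch of the BF16x2 vectorization pattern (illustrative names, not ORT code).
// Assumes sm_80+ and width <= 2048 so a single block covers one row.
#include <cuda_bf16.h>

// Adds a per-column bias to a [rows, 2 * half_width] BF16 matrix, two elements per thread.
__global__ void AddBiasBf16x2(const __nv_bfloat162* input,
                              const __nv_bfloat162* bias,
                              __nv_bfloat162* output,
                              int half_width) {
  const int row = blockIdx.x;
  const int col = threadIdx.x;  // indexes a pair of adjacent bfloat16 values
  if (col < half_width) {
    const int idx = row * half_width + col;
    output[idx] = __hadd2(input[idx], bias[col]);  // packed bfloat16x2 add
  }
}

// Host-side dispatch mirroring the "even head size" check used above.
void LaunchAddBiasExample(cudaStream_t stream, const __nv_bfloat16* input,
                          const __nv_bfloat16* bias, __nv_bfloat16* output,
                          int rows, int width) {
  if ((width & 1) == 0) {
    const int half_width = width / 2;
    AddBiasBf16x2<<<rows, half_width, 0, stream>>>(
        reinterpret_cast<const __nv_bfloat162*>(input),
        reinterpret_cast<const __nv_bfloat162*>(bias),
        reinterpret_cast<__nv_bfloat162*>(output),
        half_width);
  }
  // An odd width would fall back to an element-wise (__nv_bfloat16) kernel, as above.
}
```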

onnxruntime/contrib_ops/cuda/bert/attention.cc
Lines changed: 10 additions & 6 deletions

@@ -36,20 +36,24 @@ constexpr int kPresentOutputIndex = 1;
 
 REGISTER_KERNEL_TYPED(float)
 REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
 
 template <typename T>
 Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
   kernel_options_ = this->GetAttentionKernelOptions();
 
-  disable_fused_self_attention_ = sizeof(T) != 2 || !kernel_options_->UseTrtFusedAttention();
+  constexpr bool kIsFp16 = std::is_same<T, MLFloat16>::value;
+  constexpr bool kIsBf16 = std::is_same<T, BFloat16>::value;
+  constexpr bool kIs16bit = kIsFp16 || kIsBf16;
 
-  enable_trt_flash_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtFlashAttention();
+  // We only support FP16 for TRT fused/flash/causal attention.
+  disable_fused_self_attention_ = !kIsFp16 || !kernel_options_->UseTrtFusedAttention();
+  enable_trt_flash_attention_ = kIsFp16 && kernel_options_->UseTrtFlashAttention();
+  enable_fused_causal_attention_ = kIsFp16 && kernel_options_->UseTrtCausalAttention();
 
-  enable_fused_causal_attention_ = sizeof(T) == 2 && kernel_options_->UseTrtCausalAttention();
+  disable_memory_efficient_attention_ = kIsBf16 || !kernel_options_->UseEfficientAttention();
 
-  disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention();
-
-  disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention();
+  disable_flash_attention_ = !kIs16bit || !kernel_options_->UseFlashAttention();
 }
 
 template <typename T>
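The net effect of the constructor changes: TRT fused, TRT flash, and fused causal attention remain FP16-only, the memory-efficient path is disabled for BF16, and flash attention accepts either 16-bit float type. A self-contained sketch of that compile-time gating follows (the MLFloat16/BFloat16 structs below are stand-ins, not the ORT types):

```cpp
// Self-contained sketch of the type-based gating above (stand-in types, not ORT's).
#include <cstdint>
#include <type_traits>

struct MLFloat16 { uint16_t val; };  // stand-in for onnxruntime::MLFloat16
struct BFloat16 { uint16_t val; };   // stand-in for onnxruntime::BFloat16

template <typename T>
struct AttentionKernelSupport {
  static constexpr bool kIsFp16 = std::is_same<T, MLFloat16>::value;
  static constexpr bool kIsBf16 = std::is_same<T, BFloat16>::value;
  static constexpr bool kIs16bit = kIsFp16 || kIsBf16;

  static constexpr bool trt_fused = kIsFp16;          // TRT fused/flash/causal: FP16 only
  static constexpr bool memory_efficient = !kIsBf16;  // memory-efficient attention: no BF16
  static constexpr bool flash = kIs16bit;             // flash attention: FP16 or BF16
};

static_assert(!AttentionKernelSupport<BFloat16>::trt_fused, "BF16 never uses TRT fused kernels");
static_assert(!AttentionKernelSupport<BFloat16>::memory_efficient, "BF16 skips memory-efficient path");
static_assert(AttentionKernelSupport<BFloat16>::flash, "BF16 may use flash attention");
static_assert(!AttentionKernelSupport<float>::flash, "FP32 never uses flash attention");
```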

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Lines changed: 17 additions & 0 deletions

@@ -952,6 +952,13 @@ Status QkvToContext(
     Stream* ort_stream,
     contrib::AttentionParameters& parameters,
     AttentionData<T>& data) {
+  if constexpr (std::is_same<T, BFloat16>::value || std::is_same<QK, BFloat16>::value) {
+    if (device_prop.major < 8) {
+      ORT_THROW("BF16 Attention requires Ampere (sm_80)+ with BF16 support. This GPU (",
+                device_prop.name, ", cc ", device_prop.major, ".", device_prop.minor, ") is not supported.");
+    }
+  }
+
   auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
   const int max_threads_per_block = device_prop.maxThreadsPerBlock;
   const int batch_size = parameters.batch_size;

@@ -1040,6 +1047,8 @@ template struct AttentionData<float>;
 
 template struct AttentionData<half>;
 
+template struct AttentionData<BFloat16>;
+
 template Status QkvToContext<float>(
     const cudaDeviceProp& device_prop,
     cublasHandle_t& cublas,

@@ -1056,6 +1065,14 @@ template Status QkvToContext<half>(
     contrib::AttentionParameters& parameters,
     AttentionData<half>& data);
 
+template Status QkvToContext<BFloat16>(
+    const cudaDeviceProp& device_prop,
+    cublasHandle_t& cublas,
+    cudnnHandle_t& cudnn,
+    Stream* ort_stream,
+    contrib::AttentionParameters& parameters,
+    AttentionData<BFloat16>& data);
+
 template Status QkvToContext<float, half>(
     const cudaDeviceProp& device_prop,
     cublasHandle_t& cublas,
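The new guard at the top of QkvToContext rejects BF16 on pre-Ampere GPUs. A standalone version of the same check using the CUDA runtime API directly (ORT already has the cudaDeviceProp in hand, so this helper is illustrative only):

```cpp
// Standalone sketch of the sm_80 guard above, using the CUDA runtime API directly.
#include <cuda_runtime.h>
#include <cstdio>

bool DeviceSupportsBf16Attention(int device_id) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;
  }
  // Packed __nv_bfloat162 math and BF16 tensor cores require compute capability 8.0+.
  return prop.major >= 8;
}

int main() {
  std::printf("BF16 Attention supported on device 0: %s\n",
              DeviceSupportsBf16Attention(0) ? "yes" : "no");
  return 0;
}
```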

onnxruntime/contrib_ops/cuda/bert/attention_impl.h
Lines changed: 13 additions & 0 deletions

@@ -3,6 +3,7 @@
 
 #pragma once
 
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cublas_v2.h>
 #include <gsl/gsl>

@@ -96,6 +97,10 @@ Status LaunchTransCtx(cudaStream_t stream,
                       const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                       const int max_threads_per_block, const bool reversed_bs, const half* input, half* output);
 
+Status LaunchTransCtx(cudaStream_t stream,
+                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
+                      const int max_threads_per_block, const bool reversed_bs, const BFloat16* input, BFloat16* output);
+
 // BxSxMxNxH or SxBxMxNxH (reversed_bs is true) => MxBxNxSxH
 Status LaunchTransQkv(cudaStream_t stream, const int matrix_num,
                       const int sequence_length, const int batch_size, const int head_size, const int num_heads,

@@ -107,12 +112,20 @@ Status LaunchTransQkv(cudaStream_t stream, const int matrix_num,
                       const int max_threads_per_block, const bool reversed_bs, const half* input, half* output,
                       int total_matrix_count = -1);
 
+Status LaunchTransQkv(cudaStream_t stream, const int matrix_num,
+                      const int sequence_length, const int batch_size, const int head_size, const int num_heads,
+                      const int max_threads_per_block, const bool reversed_bs, const BFloat16* input, BFloat16* output,
+                      int total_matrix_count = -1);
+
 Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
                               const float* input, float* output, cudaStream_t stream, const int max_threads_per_block);
 
 Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
                               const half* input, half* output, cudaStream_t stream, const int max_threads_per_block);
 
+Status Transpose_BSNH_to_BNSH(const int batch_size, const int sequence_length, const int num_heads, const int head_size,
+                              const BFloat16* input, BFloat16* output, cudaStream_t stream, const int max_threads_per_block);
+
 template <typename T>
 Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size,
                            int sequence_length, int total_sequence_length,
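The new BFloat16 overloads mirror the existing float/half declarations. For reference, Transpose_BSNH_to_BNSH reorders a [batch, sequence, num_heads, head_size] buffer into [batch, num_heads, sequence, head_size]; here is a CPU sketch of that index mapping (for clarity only; the CUDA kernels perform the same remapping in parallel):

```cpp
// CPU reference for the BSNH -> BNSH layout change performed by Transpose_BSNH_to_BNSH.
#include <cstddef>
#include <vector>

template <typename T>
void TransposeBsnhToBnsh(int batch, int seq, int num_heads, int head_size,
                         const std::vector<T>& input, std::vector<T>& output) {
  output.resize(input.size());
  for (int b = 0; b < batch; ++b) {
    for (int s = 0; s < seq; ++s) {
      for (int n = 0; n < num_heads; ++n) {
        for (int h = 0; h < head_size; ++h) {
          const std::size_t src = ((static_cast<std::size_t>(b) * seq + s) * num_heads + n) * head_size + h;
          const std::size_t dst = ((static_cast<std::size_t>(b) * num_heads + n) * seq + s) * head_size + h;
          output[dst] = input[src];
        }
      }
    }
  }
}
```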

onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu
Lines changed: 65 additions & 0 deletions

@@ -197,6 +197,59 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
   return CUDA_CALL(cudaGetLastError());
 }
 
+Status LaunchConcatTensorToTensor(cudaStream_t stream,
+                                  const int all_sequence_length,
+                                  const int sequence_length,
+                                  const int batch_size,
+                                  const int head_size,
+                                  const int num_heads,
+                                  const int max_threads_per_block,
+                                  const int matrix_num,
+                                  const BFloat16* tensor_in,
+                                  const BFloat16* tensor_add,
+                                  BFloat16* tensor_out) {
+  assert(num_heads <= max_threads_per_block);
+  const dim3 grid(all_sequence_length, batch_size, matrix_num);
+  if (0 == (head_size & 1)) {
+    const int H = head_size / 2;
+    if (H * num_heads <= max_threads_per_block) {
+      const dim3 block(H, num_heads, 1);
+      ConcatTensorToTensor<__nv_bfloat162><<<grid, block, 0, stream>>>(
+          sequence_length,
+          reinterpret_cast<const __nv_bfloat162*>(tensor_in),
+          reinterpret_cast<const __nv_bfloat162*>(tensor_add),
+          reinterpret_cast<__nv_bfloat162*>(tensor_out));
+    } else {
+      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
+      ConcatTensorToTensorLarge<__nv_bfloat162><<<grid, block, 0, stream>>>(
+          sequence_length,
+          H,
+          reinterpret_cast<const __nv_bfloat162*>(tensor_in),
+          reinterpret_cast<const __nv_bfloat162*>(tensor_add),
+          reinterpret_cast<__nv_bfloat162*>(tensor_out));
+    }
+  } else {
+    if (head_size * num_heads <= max_threads_per_block) {
+      const dim3 block(head_size, num_heads, 1);
+      ConcatTensorToTensor<__nv_bfloat16><<<grid, block, 0, stream>>>(
+          sequence_length,
+          reinterpret_cast<const __nv_bfloat16*>(tensor_in),
+          reinterpret_cast<const __nv_bfloat16*>(tensor_add),
+          reinterpret_cast<__nv_bfloat16*>(tensor_out));
+    } else {
+      const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
+      ConcatTensorToTensorLarge<__nv_bfloat16><<<grid, block, 0, stream>>>(
+          sequence_length,
+          head_size,
+          reinterpret_cast<const __nv_bfloat16*>(tensor_in),
+          reinterpret_cast<const __nv_bfloat16*>(tensor_add),
+          reinterpret_cast<__nv_bfloat16*>(tensor_out));
+    }
+  }
+
+  return CUDA_CALL(cudaGetLastError());
+}
+
 #ifndef USE_ROCM  // exclude the following from hipify since they are not used in ROCM EP
 
 // ----------------------------------------------------------------------------------

@@ -332,6 +385,18 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
                                                     const half* bias,
                                                     const half* qkv_buffer,
                                                     half* present);
+
+template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
+                                                    const int max_sequence_length,
+                                                    const int total_sequence_length,
+                                                    const int sequence_length,
+                                                    const int batch_size,
+                                                    const int head_size,
+                                                    const int num_heads,
+                                                    const int max_threads_per_block,
+                                                    const BFloat16* bias,
+                                                    const BFloat16* qkv_buffer,
+                                                    BFloat16* present);
 #endif
 
 // Kernel to append new and past kv in either BSNH or BNSH format
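The BF16 overload of LaunchConcatTensorToTensor keeps the existing launch strategy: a grid over (all_sequence_length, batch_size, matrix_num), one block covering a sequence position across all heads, 2x vectorization when head_size is even, and the "Large" kernel variant when head_size times num_heads exceeds max_threads_per_block. That host-side selection logic, isolated as a sketch (names are illustrative, not ORT code):

```cpp
// Host-side sketch of the block-size selection above (illustrative names, no ORT types).
struct ConcatLaunchChoice {
  int block_x;         // threads along the head dimension (element pairs when vectorized)
  int block_y;         // one thread row per head
  bool vectorized;     // reinterpret BF16 pairs as __nv_bfloat162
  bool large_variant;  // each thread loops when one block cannot cover head_size * num_heads
};

ConcatLaunchChoice ChooseConcatLaunch(int head_size, int num_heads, int max_threads_per_block) {
  ConcatLaunchChoice c{};
  c.vectorized = (head_size % 2 == 0);
  const int width = c.vectorized ? head_size / 2 : head_size;
  c.block_y = num_heads;
  if (width * num_heads <= max_threads_per_block) {
    c.block_x = width;
    c.large_variant = false;
  } else {
    c.block_x = max_threads_per_block / num_heads;
    c.large_variant = true;
  }
  return c;
}
```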

onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.h
Lines changed: 13 additions & 0 deletions

@@ -3,6 +3,7 @@
 
 #pragma once
 #include "core/providers/cuda/shared_inc/cuda_utils.h"
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include "core/framework/allocator.h"
 #include "core/providers/cuda/cuda_common.h"

@@ -38,6 +39,18 @@ Status LaunchConcatTensorToTensor(cudaStream_t stream,
                                   const half* tensor_add,
                                   half* tensor_out);
 
+Status LaunchConcatTensorToTensor(cudaStream_t stream,
+                                  const int all_sequence_length,
+                                  const int sequence_length,
+                                  const int batch_size,
+                                  const int head_size,
+                                  const int num_heads,
+                                  const int max_threads_per_block,
+                                  const int matrix_num,
+                                  const BFloat16* tensor_in,
+                                  const BFloat16* tensor_add,
+                                  BFloat16* tensor_out);
+
 template <typename T>
 Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
                                            const int max_sequence_length,

onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
Lines changed: 10 additions & 2 deletions

@@ -744,9 +744,11 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
 #endif
 
   if (nullptr != data.gemm_buffer) {  // Attention operator
-    ORT_RETURN_IF_ERROR(PrepareQkv_Attention<T>(parameters, data, stream, max_threads_per_block));
+    ORT_RETURN_IF_ERROR(PrepareQkv_Attention<T>(
+        parameters, data, stream, max_threads_per_block));
   } else {  // MultiHeadAttention operator
-    ORT_RETURN_IF_ERROR(PrepareQkv_MultiHeadAttention<T>(parameters, data, stream, max_threads_per_block));
+    ORT_RETURN_IF_ERROR(PrepareQkv_MultiHeadAttention<T>(
+        parameters, data, stream, max_threads_per_block));
   }
 
   assert(data.qkv_format != AttentionQkvFormat::UNKNOWN);

@@ -776,6 +778,12 @@ template Status PrepareQkv<half>(
     cudaStream_t stream,
     int max_threads_per_block);
 
+template Status PrepareQkv<BFloat16>(
+    contrib::AttentionParameters& parameters,
+    AttentionData<BFloat16>& data,
+    cudaStream_t stream,
+    int max_threads_per_block);
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
