@@ -794,6 +794,39 @@ void LaunchAddBiasTranspose(
794
794
}
795
795
}
796
796
797
// BFloat16 specialization of LaunchAddBiasTranspose.
//
// When both head sizes are even and rotary embedding is not requested, the
// buffers are reinterpreted as __nv_bfloat162 so each thread moves two bf16
// elements at a time (head sizes are halved accordingly). Otherwise it falls
// back to the scalar BFloat16 kernel, which also handles the rotary path.
// enable_half4 is ignored for bf16 (vectorization is decided here instead).
template <>
void LaunchAddBiasTranspose<BFloat16>(
    cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const BFloat16* input, const BFloat16* biases, BFloat16* output,
    bool /*enable_half4*/, const int v_head_size,
    BFloat16* qkv_add_bias, int total_matrix_count,
    bool do_rotary, int rotary_embedding, int past_sequence_length) {
  total_matrix_count = std::max(num_matrices, total_matrix_count);

  // v_head_size == -1 means "no separate V head size"; that still allows the
  // vectorized path as long as qk_head_size is even.
  const bool qk_even = (qk_head_size & 1) == 0;
  const bool v_even = (v_head_size == -1) || ((v_head_size & 1) == 0);

  if (qk_even && v_even && !do_rotary) {
    // Two bf16 lanes per element: halve the head sizes for the bf162 kernel.
    // NOTE(review): when v_head_size is -1, -1 / 2 truncates to 0 here, same
    // as the original — presumably the kernel ignores H_v in that case; confirm.
    const int half_qk = qk_head_size / 2;
    const int half_v = v_head_size / 2;
    InvokeAddBiasTranspose<__nv_bfloat162>(
        stream, num_matrices, format, max_threads_per_block,
        batch_size, sequence_length, num_heads, half_qk,
        reinterpret_cast<const __nv_bfloat162*>(input),
        reinterpret_cast<const __nv_bfloat162*>(biases),
        reinterpret_cast<__nv_bfloat162*>(output),
        reinterpret_cast<__nv_bfloat162*>(qkv_add_bias),
        half_v, total_matrix_count);
  } else {
    // Scalar fallback: odd head size and/or rotary embedding requested.
    InvokeAddBiasTranspose<BFloat16>(
        stream, num_matrices, format, max_threads_per_block,
        batch_size, sequence_length, num_heads, qk_head_size,
        input, biases, output,
        qkv_add_bias, v_head_size, total_matrix_count,
        do_rotary, rotary_embedding, past_sequence_length);
  }
}
797
830
template <>
798
831
void LaunchAddBiasTranspose (
799
832
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
@@ -888,6 +921,20 @@ void LaunchAddBiasTransposeTrt(
888
921
ORT_ENFORCE (false , " Shall not call this since fused kernel does not support float input." );
889
922
}
890
923
924
// BFloat16 specialization of LaunchAddBiasTransposeTrt.
//
// The fused TRT attention path is not available for bfloat16, so this
// specialization exists only to fail loudly if a caller reaches it.
// All parameters are intentionally unused.
template <>
void LaunchAddBiasTransposeTrt<BFloat16>(
    cudaStream_t /*stream*/, const int /*max_threads_per_block*/,
    const int /*batch_size*/, const int /*sequence_length*/,
    const int /*num_heads*/, const int /*head_size*/,
    const BFloat16* /*biases*/,
    const BFloat16* /*query*/,
    const BFloat16* /*key*/,
    const BFloat16* /*value*/,
    BFloat16* /*output*/,
    bool /*is_cross_attention*/, int /*kv_sequence_length*/) {
  ORT_ENFORCE(false, "BF16 not supported for LaunchAddBiasTransposeTrt.");
}
891
938
template <>
892
939
void LaunchAddBiasTransposeTrt (
893
940
cudaStream_t stream, const int max_threads_per_block,
@@ -1049,6 +1096,38 @@ void LaunchAddBias(
1049
1096
}
1050
1097
}
1051
1098
1099
// BFloat16 specialization of LaunchAddBias for separate Q/K/V inputs.
//
// Adds biases to query/key/value and writes q/k/v. If both head sizes are
// even, the buffers are reinterpreted as __nv_bfloat162 (two bf16 values per
// element) and the head sizes are halved so the vectorized kernel can be used;
// otherwise the scalar BFloat16 kernel runs.
template <>
void LaunchAddBias<BFloat16>(
    cudaStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int kv_sequence_length,
    const int num_heads, const int head_size, const int v_head_size,
    const BFloat16* biases, const BFloat16* query, const BFloat16* key, const BFloat16* value,
    BFloat16* q, BFloat16* k, BFloat16* v) {
  // Both sizes even <=> no set low bit in either.
  const bool can_vectorize = (((head_size | v_head_size) & 1) == 0);

  if (can_vectorize) {
    const int half_qk = head_size / 2;
    const int half_v = v_head_size / 2;
    InvokeAddBias<__nv_bfloat162>(
        stream, max_threads_per_block,
        batch_size, sequence_length, kv_sequence_length, num_heads, half_qk, half_v,
        reinterpret_cast<const __nv_bfloat162*>(biases),
        reinterpret_cast<const __nv_bfloat162*>(query),
        reinterpret_cast<const __nv_bfloat162*>(key),
        reinterpret_cast<const __nv_bfloat162*>(value),
        reinterpret_cast<__nv_bfloat162*>(q),
        reinterpret_cast<__nv_bfloat162*>(k),
        reinterpret_cast<__nv_bfloat162*>(v));
  } else {
    // Scalar fallback for odd head sizes.
    InvokeAddBias<BFloat16>(
        stream, max_threads_per_block,
        batch_size, sequence_length, kv_sequence_length, num_heads,
        head_size, v_head_size,
        biases, query, key, value, q, k, v);
  }
}
1052
1131
template <typename T>
1053
1132
void InvokeAddBias (
1054
1133
cudaStream_t stream, const int max_threads_per_block,
@@ -1125,6 +1204,31 @@ void LaunchAddBias(
1125
1204
}
1126
1205
}
1127
1206
1207
// BFloat16 specialization of LaunchAddBias for a query-only input.
//
// Adds biases to query and writes q. For an even head size the buffers are
// reinterpreted as __nv_bfloat162 and the head size halved so the vectorized
// kernel processes two bf16 values per element; odd head sizes use the scalar
// BFloat16 kernel.
template <>
void LaunchAddBias<BFloat16>(
    cudaStream_t stream, const int max_threads_per_block,
    const int batch_size, const int sequence_length,
    const int num_heads, const int head_size,
    const BFloat16* biases, const BFloat16* query, BFloat16* q) {
  if ((head_size & 1) != 0) {
    // Scalar fallback for odd head size.
    InvokeAddBias<BFloat16>(
        stream, max_threads_per_block,
        batch_size, sequence_length, num_heads, head_size,
        biases, query, q);
    return;
  }

  const int half_head = head_size / 2;
  InvokeAddBias<__nv_bfloat162>(
      stream, max_threads_per_block,
      batch_size, sequence_length, num_heads, half_head,
      reinterpret_cast<const __nv_bfloat162*>(biases),
      reinterpret_cast<const __nv_bfloat162*>(query),
      reinterpret_cast<__nv_bfloat162*>(q));
}
1128
1232
} // namespace cuda
1129
1233
} // namespace contrib
1130
1234
} // namespace onnxruntime
0 commit comments