10
10
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include <iostream>
#include <stdint.h>
#include <vector>
13
14
14
15
#define TEST (FN ) \
15
16
{ \
@@ -478,6 +479,119 @@ __global__ void mma_kernel_m16n8k16_ptx_f16_f32(half *A, half *B, float *C, floa
478
479
}
479
480
}
480
481
482
+
483
// Tensor-core MMA test kernel: computes D = A * B + C on 16x8x16 bf16 tiles
// via the mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 PTX
// instruction (requires SM80+). One warp computes one tile; WARP_ID selects
// which of the replicated input/output matrices this warp works on.
// Fragment layout follows the PTX ISA m16n8k16 mapping: each lane owns rows
// THREAD_ROW and THREAD_ROW + 8 and the column pair starting at
// THREAD_COL * 2. OFFSET / IN_BOUND_* are helper macros defined earlier in
// this file.
template <int Shape_M, int Shape_N, int Shape_K>
__global__ void mma_kernel_m16n8k16_ptx_bf16_f32(__nv_bfloat16 *A, __nv_bfloat16 *B, float *C, float *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
  const int THREAD_IDX = threadIdx.x + blockIdx.x * blockDim.x;
  const int WARP_ID = THREAD_IDX / WARP_SIZE;
  const int LANE_ID = THREAD_IDX % WARP_SIZE;

  const int THREAD_ROW = LANE_ID / 4;
  const int THREAD_COL = LANE_ID % 4;

  int A_OFFSET = (WARP_ID % A_NUM_MAT) * Shape_M * Shape_K;
  int B_OFFSET = (WARP_ID % B_NUM_MAT) * Shape_K * Shape_N;
  int CD_OFFSET = (WARP_ID % CD_NUM_MAT) * Shape_M * Shape_N;

  // Zero-initialize the fragments: the original left them uninitialized, so
  // any skipped (out-of-bound) load fed indeterminate bits into the MMA.
  uint32_t a[4] = {0u, 0u, 0u, 0u};
  uint32_t b[2] = {0u, 0u};
  float c[4] = {0.f, 0.f, 0.f, 0.f};
  float d[4];

  // Load the A fragment: 8 bf16 values per lane, packed two per 32-bit
  // register. The first value of each k-pair goes to the high half-word and
  // the second to the low half-word; B below uses the same convention, so
  // the within-pair ordering cancels out in the MMA dot product.
  for (unsigned char i = 0; i < 8; i++) {
    // Rows 0-7 for i in {0,1,4,5}, rows 8-15 otherwise; the column (k)
    // index advances by 8 for the second half of the loads (i >= 4).
    int r_off = (i < 2 || (i >= 4 && i < 6)) ? 0 : 8;
    int c_off = (i >= 4) ? 8 : 0;

    int load_offset =
        A_OFFSET + OFFSET(THREAD_ROW + r_off,
                          (THREAD_COL * 2) + (i & 0x1) + c_off, Shape_K);
    if (IN_BOUND_A(load_offset)) {
      __nv_bfloat16 val = A[load_offset];
      uint16_t bits = *reinterpret_cast<uint16_t *>(&val);
      if ((i & 0x01) == 0) {
        a[i / 2] = uint32_t(bits) << 16; // first of pair -> high half
      } else {
        a[i / 2] |= bits; // second of pair -> low half
      }
    }
    // Note: register index i / 2 replaces the original running counter,
    // which only advanced inside the bound check and therefore
    // desynchronized whenever the second load of a pair was skipped.
  }

  // Load the B fragment: 4 bf16 values per lane, packed two per register,
  // same high/low convention as A.
  for (int i = 0; i < 4; i++) {
    int r_off = (i >= 2) ? 8 : 0;

    int load_offset = B_OFFSET + OFFSET((THREAD_COL * 2) + (i & 0x1) + r_off,
                                        THREAD_ROW, Shape_N);
    if (IN_BOUND_B(load_offset)) {
      __nv_bfloat16 val = B[load_offset];
      uint16_t bits = *reinterpret_cast<uint16_t *>(&val);
      if ((i & 0x01) == 0) {
        b[i / 2] = uint32_t(bits) << 16; // first of pair -> high half
      } else {
        b[i / 2] |= bits; // second of pair -> low half
      }
    }
  }

  // Load the C accumulator fragment: 4 f32 values per lane, rows THREAD_ROW
  // and THREAD_ROW + 8.
  for (int i = 0; i < 4; i++) {
    int load_offset;

    if (i < 2) {
      load_offset =
          CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
    } else {
      load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
                                       (THREAD_COL * 2) + (i & 0x1), Shape_N);
    }

    if (IN_BOUND_CD(load_offset)) {
      c[i] = C[load_offset];
    }
  }

  // d[] is written in full by the instruction, so it is a pure output:
  // "=f" replaces the original "+f", which made the asm read the
  // uninitialized d values as inputs.
  asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{ %0, %1, %2, %3 }, "
      "{ %4, %5, %6, %7 }, "
      "{ %8, %9 }, "
      "{ %10, %11, %12, %13 };"
      : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
        "r"(b[0]), "r"(b[1]),
        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));

  // Store the D fragment back with the same layout used for the C loads.
  for (int i = 0; i < 4; i++) {
    int load_offset;

    if (i < 2) {
      load_offset =
          CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
    } else {
      load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
                                       (THREAD_COL * 2) + (i & 0x1), Shape_N);
    }

    if (IN_BOUND_CD(load_offset)) {
      D[load_offset] = d[i];
    }
  }
}
594
+
481
595
template <int Shape_M, int Shape_N, int Shape_K>
482
596
__global__ void mma_kernel_m16n8k16_s8_s32 (int8_t *A, int8_t *B, int *C, int *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
483
597
const int THREAD_IDX = threadIdx .x + blockIdx .x * blockDim .x ;
@@ -865,6 +979,65 @@ bool run_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
865
979
return correct;
866
980
}
867
981
982
+
983
// Host-side driver for the bf16 MMA kernel: validates the dimensions,
// initializes host buffers, computes a CPU reference via
// matrix_multiplication_cpu, launches the kernel (one warp per 16x8 output
// tile), and compares device output against the reference.
// Returns true when the result matches.
bool run_test_mma_m16n8k16_bf16_f32(const int M, const int N, const int K) {
  int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
  calculate_num_matrices<16, 8, 16>(M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);

  if (A_NUM_MAT == 0 || B_NUM_MAT == 0 || CD_NUM_MAT == 0) {
    std::cerr << "Matrix dimensions are not compatible with m16n8k16"
              << std::endl;
    return false;
  }

  __nv_bfloat16 *d_A, *d_B;
  float *d_C, *d_D;
  // Heap-allocated host buffers: the original used runtime-sized stack
  // arrays (a compiler VLA extension, not standard C++) that risk stack
  // overflow for large dimensions.
  std::vector<__nv_bfloat16> h_A(M * K), h_B(K * N);
  std::vector<float> h_C(M * N), h_D(M * N);
  std::vector<float> h_D_ref(M * N);

  initialize_matrices(h_A.data(), h_B.data(), h_C.data(), h_D.data(), 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);

  matrix_multiplication_cpu(h_A.data(), h_B.data(), h_C.data(), h_D_ref.data(), 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);

  cudaMalloc(&d_A, M * K * sizeof(__nv_bfloat16));
  cudaMalloc(&d_B, K * N * sizeof(__nv_bfloat16));
  cudaMalloc(&d_C, M * N * sizeof(float));
  cudaMalloc(&d_D, M * N * sizeof(float));

  cudaMemcpy(d_A, h_A.data(), M * K * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
  cudaMemcpy(d_B, h_B.data(), K * N * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
  cudaMemcpy(d_C, h_C.data(), M * N * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_D, h_D.data(), M * N * sizeof(float), cudaMemcpyHostToDevice);

  // Launch config: up to no_mat_blocks warps (tiles) per block; when there
  // are fewer tiles than that, fall back to one block with one warp per tile.
  int no_mat_blocks = 4;
  int no_blocks = CD_NUM_MAT / no_mat_blocks;
  int no_threads;
  if (no_blocks) {
    no_threads = WARP_SIZE * no_mat_blocks;
  } else {
    no_blocks = 1;
    no_threads = WARP_SIZE * CD_NUM_MAT;
  }

  mma_kernel_m16n8k16_ptx_bf16_f32<16, 8, 16><<<no_blocks, no_threads>>>(
      d_A, d_B, d_C, d_D, M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
  // Launch errors (bad config, unsupported arch for bf16 mma) never surface
  // from the <<<>>> expression itself; check explicitly.
  cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    std::cerr << "Kernel launch failed: " << cudaGetErrorString(launch_err)
              << std::endl;
  }
  cudaDeviceSynchronize();
  cudaMemcpy(h_D.data(), d_D, M * N * sizeof(float), cudaMemcpyDeviceToHost);

  bool correct = check_result(M, N, h_D.data(), h_D_ref.data());

  std::cout << "m16n8k16 (f32.bf16.bf16.f32): " << (correct ? "PASSED" : "FAILED")
            << std::endl;

  cudaFree(d_A);
  cudaFree(d_B);
  cudaFree(d_C);
  cudaFree(d_D);

  return correct;
}
1040
+
868
1041
bool run_test_mma_m16n8k16_s8_s32 (const int M, const int N, const int K) {
869
1042
int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
870
1043
calculate_num_matrices<16 , 8 , 16 >(M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
@@ -1019,6 +1192,18 @@ bool launch_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
1019
1192
return true ;
1020
1193
}
1021
1194
1195
// Runs the bf16 m16n8k16 test once for the given dimensions and, on
// failure, reports the offending dims on stderr. Returns the test result.
bool launch_test_mma_m16n8k16_bf16_f32(const int M, const int N, const int K) {
  const bool passed = run_test_mma_m16n8k16_bf16_f32(M, N, K);
  if (passed)
    return true;

  std::cerr << "m16n8k16 (f32.bf16.bf16.f32) failed for dims: " << M << ", "
            << N << ", " << K << std::endl;
  return false;
}
1206
+
1022
1207
bool launch_test_mma_m16n8k16_s8_s32 (const int M, const int N, const int K) {
1023
1208
bool correct = run_test_mma_m16n8k16_s8_s32 (M, N, K);
1024
1209
@@ -1081,6 +1266,16 @@ bool mma_m16n8k16_f16_f32() {
1081
1266
return true ;
1082
1267
}
1083
1268
1269
// Test suite entry for the bf16 -> f32 m16n8k16 MMA: exercises the native
// tile shape (16x8x16) plus several multi-tile dimensions that replicate
// the A/B/CD matrices across warps. LAUNCH_TEST is a macro defined earlier
// in this file — presumably it invokes the matching launch_test_* function
// and propagates failure (TODO confirm against the macro definition).
bool mma_m16n8k16_bf16_f32() {
  LAUNCH_TEST(mma_m16n8k16_bf16_f32(16, 8, 16));
  LAUNCH_TEST(mma_m16n8k16_bf16_f32(32, 16, 32));
  LAUNCH_TEST(mma_m16n8k16_bf16_f32(16, 16, 16));
  LAUNCH_TEST(mma_m16n8k16_bf16_f32(16, 16, 32));
  LAUNCH_TEST(mma_m16n8k16_bf16_f32(32, 32, 32));

  return true;
}
1278
+
1084
1279
bool mma_m16n8k16_s8_s32 () {
1085
1280
LAUNCH_TEST (mma_m16n8k16_s8_s32 (16 , 8 , 16 ));
1086
1281
LAUNCH_TEST (mma_m16n8k16_s8_s32 (32 , 16 , 32 ));
@@ -1106,6 +1301,7 @@ int main() {
1106
1301
TEST (mma_m8n8k16_s8_s32);
1107
1302
TEST (mma_m16n8k8_f16_f32);
1108
1303
TEST (mma_m16n8k16_f16_f32);
1304
+ TEST (mma_m16n8k16_bf16_f32);
1109
1305
TEST (mma_m16n8k16_s8_s32);
1110
1306
TEST (mma_m16n8k32_s8_s32);
1111
1307
0 commit comments