Commit f216868

[SYCLomatic] Add E2E test for mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 (#933)
* Add E2E test for mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32
1 parent d3f3418 commit f216868

File tree

1 file changed: +196 −0 lines changed

features/feature_case/asm/asm_mma.cu

Lines changed: 196 additions & 0 deletions
@@ -10,6 +10,7 @@
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
 #include <iostream>
+#include <stdint.h>

 #define TEST(FN) \
 { \
@@ -478,6 +479,119 @@ __global__ void mma_kernel_m16n8k16_ptx_f16_f32(half *A, half *B, float *C, float *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
   }
 }

+
+template <int Shape_M, int Shape_N, int Shape_K>
+__global__ void mma_kernel_m16n8k16_ptx_bf16_f32(__nv_bfloat16 *A, __nv_bfloat16 *B, float *C, float *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
+  const int THREAD_IDX = threadIdx.x + blockIdx.x * blockDim.x;
+  const int WARP_ID = THREAD_IDX / WARP_SIZE;
+  const int LANE_ID = THREAD_IDX % WARP_SIZE;
+
+  const int THREAD_ROW = LANE_ID / 4;
+  const int THREAD_COL = LANE_ID % 4;
+
+  int A_OFFSET = (WARP_ID % A_NUM_MAT) * Shape_M * Shape_K;
+  int B_OFFSET = (WARP_ID % B_NUM_MAT) * Shape_K * Shape_N;
+  int CD_OFFSET = (WARP_ID % CD_NUM_MAT) * Shape_M * Shape_N;
+
+  uint32_t a[4];
+  uint32_t b[2];
+  float c[4];
+  float d[4];
+
+  unsigned char indx = 0;
+  for (unsigned char i = 0; i < 8; i++) {
+    int r_off = 8;
+    if (i < 2 || (i >= 4 && i < 6)) {
+      r_off = 0;
+    }
+
+    int c_off = 0;
+    if (i >= 4) {
+      c_off = 8;
+    }
+
+    int load_offset =
+        A_OFFSET + OFFSET(THREAD_ROW + r_off,
+                          (THREAD_COL * 2) + (i & 0x1) + c_off, Shape_K);
+    if (IN_BOUND_A(load_offset)) {
+      __nv_bfloat16 val = A[load_offset];
+      if ((i & 0x01) == 0) {
+        // First value of pair, put to high
+        a[indx] = uint32_t(*reinterpret_cast<uint16_t *>(&val) << 16);
+      } else {
+        // Second value of pair, put to low
+        a[indx] = a[indx] | *reinterpret_cast<uint16_t *>(&val);
+        indx++;
+      }
+    }
+  }
+
+  indx = 0;
+  for (int i = 0; i < 4; i++) {
+    int r_off = 0;
+    if (i >= 2) {
+      r_off = 8;
+    }
+
+    int load_offset = B_OFFSET + OFFSET((THREAD_COL * 2) + (i & 0x1) + r_off,
+                                        THREAD_ROW, Shape_N);
+    if (IN_BOUND_B(load_offset)) {
+      __nv_bfloat16 val = B[load_offset];
+      if ((i & 0x01) == 0) {
+        // First value of pair, put to high
+        b[indx] = uint32_t(*reinterpret_cast<uint16_t *>(&val) << 16);
+      } else {
+        // Second value of pair, put to low
+        b[indx] = b[indx] | *reinterpret_cast<uint16_t *>(&val);
+        indx++;
+      }
+    }
+  }
+
+  for (int i = 0; i < 4; i++) {
+    int load_offset;
+
+    if (i < 2) {
+      load_offset =
+          CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
+    } else {
+      load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
+                                       (THREAD_COL * 2) + (i & 0x1), Shape_N);
+    }
+
+    if (IN_BOUND_CD(load_offset)) {
+      c[i] = C[load_offset];
+    }
+  }
+
+  asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+      " { %0, %1, %2, %3 }, "
+      " { %4, %5, %6, %7 }, "
+      " { %8, %9 }, "
+      " { %10, %11, %12, %13 };"
+      : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
+      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
+        "r"(b[0]), "r"(b[1]),
+        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+
+
+  for (int i = 0; i < 4; i++) {
+    int load_offset;
+
+    if (i < 2) {
+      load_offset =
+          CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
+    } else {
+      load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
+                                       (THREAD_COL * 2) + (i & 0x1), Shape_N);
+    }
+
+    if (IN_BOUND_CD(load_offset)) {
+      D[load_offset] = d[i];
+    }
+  }
+}
+
 template<int Shape_M, int Shape_N, int Shape_K>
 __global__ void mma_kernel_m16n8k16_s8_s32(int8_t *A, int8_t *B, int *C, int *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
   const int THREAD_IDX = threadIdx.x + blockIdx.x * blockDim.x;
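
The fragment setup in the kernel above hinges on one trick: the .bf16 variants of mma.sync consume their A and B fragments as 32-bit registers that each carry two bf16 values, so the load loops fuse consecutive elements pairwise, first value into the high half, second into the low half. A minimal standalone sketch of that packing (the helper name pack_bf16_pair is ours for illustration, not part of the commit):

#include <cuda_bf16.h>
#include <stdint.h>

// Pack two bf16 values into one 32-bit mma fragment register, mirroring the
// kernel's load loops: first element of the pair in the high 16 bits, second
// element in the low 16 bits.
__host__ __device__ inline uint32_t pack_bf16_pair(__nv_bfloat16 first,
                                                   __nv_bfloat16 second) {
  uint32_t hi = *reinterpret_cast<uint16_t *>(&first);
  uint32_t lo = *reinterpret_cast<uint16_t *>(&second);
  return (hi << 16) | lo;
}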
@@ -865,6 +979,65 @@ bool run_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
   return correct;
 }

+
+bool run_test_mma_m16n8k16_bf16_f32(const int M, const int N, const int K) {
+  int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
+  calculate_num_matrices<16, 8, 16>(M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+
+  if (A_NUM_MAT == 0 || B_NUM_MAT == 0 || CD_NUM_MAT == 0) {
+    std::cerr << "Matrix dimensions are not compatible with m16n8k16"
+              << std::endl;
+    return false;
+  }
+
+  __nv_bfloat16 *d_A, *d_B;
+  float *d_C, *d_D;
+  __nv_bfloat16 h_A[M * K], h_B[K * N];
+  float h_C[M * N], h_D[M * N];
+  float h_D_ref[M * N];
+
+  initialize_matrices(h_A, h_B, h_C, h_D, 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+
+  matrix_multiplication_cpu(h_A, h_B, h_C, h_D_ref, 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+
+  cudaMalloc(&d_A, M * K * sizeof(__nv_bfloat16));
+  cudaMalloc(&d_B, K * N * sizeof(__nv_bfloat16));
+  cudaMalloc(&d_C, M * N * sizeof(float));
+  cudaMalloc(&d_D, M * N * sizeof(float));
+
+  cudaMemcpy(d_A, h_A, M * K * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_B, h_B, K * N * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_C, h_C, M * N * sizeof(float), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_D, h_D, M * N * sizeof(float), cudaMemcpyHostToDevice);
+
+  int no_mat_blocks = 4;
+  int no_blocks = CD_NUM_MAT / no_mat_blocks;
+  int no_threads;
+  if (no_blocks) {
+    no_threads = WARP_SIZE * no_mat_blocks;
+  } else {
+    no_blocks = 1;
+    no_threads = WARP_SIZE * CD_NUM_MAT;
+  }
+
+  mma_kernel_m16n8k16_ptx_bf16_f32<16, 8, 16><<<no_blocks, no_threads>>>(
+      d_A, d_B, d_C, d_D, M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+  cudaDeviceSynchronize();
+  cudaMemcpy(h_D, d_D, M * N * sizeof(float), cudaMemcpyDeviceToHost);
+
+  bool correct = check_result(M, N, h_D, h_D_ref);
+
+  std::cout << "m16n8k16 (f32.bf16.bf16.f32): " << (correct ? "PASSED" : "FAILED")
+            << std::endl;
+
+  cudaFree(d_A);
+  cudaFree(d_B);
+  cudaFree(d_C);
+  cudaFree(d_D);
+
+  return correct;
+}
+
 bool run_test_mma_m16n8k16_s8_s32(const int M, const int N, const int K) {
   int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
   calculate_num_matrices<16, 8, 16>(M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
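
The expected output h_D_ref comes from matrix_multiplication_cpu, which is defined earlier in the file and not shown in this diff. Since mma.sync computes D = A×B + C, a per-tile host reference for this shape would look roughly like the sketch below (ref_mma_tile is illustrative only, not the file's actual helper; it assumes a 16×16 row-major A tile and a 16×8 row-major B tile, matching the offsets the kernel uses):

#include <cuda_bf16.h>

// Illustrative per-tile reference for m16n8k16: D = A * B + C, with bf16
// inputs widened to float before accumulating.
void ref_mma_tile(const __nv_bfloat16 *A, const __nv_bfloat16 *B,
                  const float *C, float *D) {
  for (int m = 0; m < 16; ++m) {
    for (int n = 0; n < 8; ++n) {
      float acc = C[m * 8 + n];
      for (int k = 0; k < 16; ++k)
        acc += __bfloat162float(A[m * 16 + k]) * __bfloat162float(B[k * 8 + n]);
      D[m * 8 + n] = acc;
    }
  }
}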
@@ -1019,6 +1192,18 @@ bool launch_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
   return true;
 }

+bool launch_test_mma_m16n8k16_bf16_f32(const int M, const int N, const int K) {
+  bool correct = run_test_mma_m16n8k16_bf16_f32(M, N, K);
+
+  if (!correct) {
+    std::cerr << "m16n8k16 (f32.bf16.bf16.f32) failed for dims: " << M << ", "
+              << N << ", " << K << std::endl;
+    return false;
+  }
+
+  return true;
+}
+
 bool launch_test_mma_m16n8k16_s8_s32(const int M, const int N, const int K) {
   bool correct = run_test_mma_m16n8k16_s8_s32(M, N, K);

@@ -1081,6 +1266,16 @@ bool mma_m16n8k16_f16_f32() {
   return true;
 }

+bool mma_m16n8k16_bf16_f32() {
+  LAUNCH_TEST(mma_m16n8k16_bf16_f32(16, 8, 16));
+  LAUNCH_TEST(mma_m16n8k16_bf16_f32(32, 16, 32));
+  LAUNCH_TEST(mma_m16n8k16_bf16_f32(16, 16, 16));
+  LAUNCH_TEST(mma_m16n8k16_bf16_f32(16, 16, 32));
+  LAUNCH_TEST(mma_m16n8k16_bf16_f32(32, 32, 32));
+
+  return true;
+}
+
 bool mma_m16n8k16_s8_s32() {
   LAUNCH_TEST(mma_m16n8k16_s8_s32(16, 8, 16));
   LAUNCH_TEST(mma_m16n8k16_s8_s32(32, 16, 32));
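
Every shape exercised above is a whole multiple of the 16×8×16 tile. The file's calculate_num_matrices<16, 8, 16> is defined elsewhere and not shown in this diff; assuming it simply counts whole tiles per operand and reports zero for incompatible dimensions (which run_test treats as an error), its behavior would resemble this sketch:

// Illustrative only: the actual calculate_num_matrices is not part of this
// diff. Assumed behavior: count whole tiles per operand, or report 0 when the
// dimensions do not divide evenly by the tile shape.
template <int TM, int TN, int TK>
void calculate_num_matrices_sketch(int M, int N, int K, int &a_mats,
                                   int &b_mats, int &cd_mats) {
  if (M % TM || N % TN || K % TK) {
    a_mats = b_mats = cd_mats = 0;
    return;
  }
  a_mats = (M / TM) * (K / TK);
  b_mats = (K / TK) * (N / TN);
  cd_mats = (M / TM) * (N / TN);
}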
@@ -1106,6 +1301,7 @@ int main() {
   TEST(mma_m8n8k16_s8_s32);
   TEST(mma_m16n8k8_f16_f32);
   TEST(mma_m16n8k16_f16_f32);
+  TEST(mma_m16n8k16_bf16_f32);
   TEST(mma_m16n8k16_s8_s32);
   TEST(mma_m16n8k32_s8_s32);

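One portability note: the bf16 variants of mma.sync require compute capability 8.0 (Ampere) or newer, so this test must be built for sm_80 or later. A guard like the following could be used to skip the bf16 case on older devices (illustrative; the committed test does not perform this check):

#include <cuda_runtime.h>

// Returns true when the current device can execute bf16 mma.sync (sm_80+).
bool device_supports_bf16_mma() {
  int dev = 0, major = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev);
  return major >= 8;
}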