
Commit c7330ad

Add test for mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 (#935)
1 parent f216868 commit c7330ad

File tree

1 file changed (+258, -1)

features/feature_case/asm/asm_mma.cu

Lines changed: 258 additions & 1 deletion
@@ -96,6 +96,42 @@ __host__ void initialize_matrices(int8_t *A, int8_t *B, int *C, int *D, int M, i
 }


+__host__ void initialize_matrices(half *A, half *B, half *C, half *D, int M,
+                                  int N, int K, int A_NUM_MAT, int B_NUM_MAT,
+                                  int CD_NUM_MAT) {
+  for (int N_MAT = 0; N_MAT < A_NUM_MAT; N_MAT++) {
+    int A_OFFSET = N_MAT * M * K;
+
+    for (int i = 0; i < M; ++i) {
+      for (int j = 0; j < K; ++j) {
+        A[A_OFFSET + i * K + j] = __float2half((i * K + j) / 4.0f);
+      }
+    }
+  }
+
+  for (int N_MAT = 0; N_MAT < B_NUM_MAT; N_MAT++) {
+    int B_OFFSET = N_MAT * K * N;
+
+    for (int i = 0; i < K; ++i) {
+      for (int j = 0; j < N; ++j) {
+        B[B_OFFSET + i * N + j] = __float2half((i * N + j) / 4.0f);
+      }
+    }
+  }
+
+  for (int N_MAT = 0; N_MAT < CD_NUM_MAT; N_MAT++) {
+    int CD_OFFSET = N_MAT * M * N;
+
+    for (int i = 0; i < M; ++i) {
+      for (int j = 0; j < N; ++j) {
+        C[CD_OFFSET + i * N + j] = __float2half((i * N + j) * 1.0f);
+        D[CD_OFFSET + i * N + j] = 0.0f;
+      }
+    }
+  }
+}
+
 template <typename ABType, typename CDType>
 void matrix_multiplication_cpu(ABType *A, ABType *B, CDType *C, CDType *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
   for (int N_MAT = 0; N_MAT < CD_NUM_MAT; N_MAT++) {
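A note on the / 4.0f scaling in the new overload: quarter-integer values are exactly representable in fp16 up to 512 (the format carries an 11-bit significand), so for the tile sizes exercised by this test every initial element survives the float-to-half conversion unrounded. A minimal host-side check of that claim (my illustration, not part of the commit; builds with nvcc):

#include <cuda_fp16.h>
#include <cstdio>

int main() {
  // Verify that n / 4.0f round-trips float -> half -> float exactly for
  // every index a 32x32 tile can produce (largest value: 1023 / 4 = 255.75).
  for (int n = 0; n < 32 * 32; ++n) {
    float x = n / 4.0f;
    if (__half2float(__float2half(x)) != x) {
      std::printf("rounding at n=%d\n", n);
      return 1;
    }
  }
  std::printf("all values exact in fp16\n");
  return 0;
}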
@@ -114,7 +150,44 @@ void matrix_multiplication_cpu(ABType *A, ABType *B, CDType *C, CDType *D, int M
     }
   }
 }
+/*
+1. Properties of half (float16)
+Limited precision: float16 has only 5 bits of exponent and 10 bits of
+mantissa, giving roughly 3 decimal digits of precision.
+
+Non-uniform precision: the closer to zero, the higher the precision; the
+closer to the maximum value, the lower the precision.
+
+2. Differences in half support: CPU vs. GPU
+* On the GPU (e.g., Tensor Cores or native FP16 ALUs):
+Many GPUs, especially NVIDIA ones, support native half x half operations,
+either
+
+half x half -> half   or   half x half -> float
+
+With Tensor Cores or PTX instructions like mma.sync, the pattern is often
+
+half x half -> float (accumulate) -> rounded to half at the end
+
+so the GPU may retain intermediate precision in float, which can differ from
+a float round trip on the CPU.
+
+* On the CPU:
+CPUs usually don't have native half-precision units, so the operation goes
+like this:

+convert half to float -> multiply as float -> convert the result back to half
+
+This involves two rounding steps:
+
+from half to float (no loss)
+from float to half (can introduce significant rounding error)
+
+For example, multiplying 60000 x 2 as float gives 120000, but float16 cannot
+represent 120000, so the result becomes inf.
+*/
+// Note: CPU and GPU results therefore differ, especially for larger values.
 void matrix_multiplication_cpu(half *A, half *B, half *C, half *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
   for (int N_MAT = 0; N_MAT < CD_NUM_MAT; N_MAT++) {
     int A_OFFSET = (N_MAT % A_NUM_MAT) * (M * K);
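The overflow case called out in the new comment is easy to reproduce on the host. A minimal sketch (mine, not part of the commit) of the widen-multiply-narrow sequence the CPU reference performs:

#include <cuda_fp16.h>
#include <cstdio>

int main() {
  // 60000 is exactly representable in fp16: the spacing between adjacent
  // fp16 values in [32768, 65536) is 32, and 60000 is a multiple of 32.
  __half a = __float2half(60000.0f);
  __half b = __float2half(2.0f);

  // CPU path: widen to float, multiply, narrow back to half.
  float wide = __half2float(a) * __half2float(b); // 120000.0f, exact in float
  __half narrow = __float2half(wide);             // 120000 > 65504 (fp16 max) -> +inf

  std::printf("float product: %f\n", wide);
  std::printf("fp16 product:  %f\n", __half2float(narrow)); // prints inf
  return 0;
}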
@@ -480,6 +553,114 @@ __global__ void mma_kernel_m16n8k16_ptx_f16_f32(half *A, half *B, float *C, floa
 }


+template <int Shape_M, int Shape_N, int Shape_K>
+__global__ void mma_kernel_m16n8k16_ptx_f16_f16(half *A, half *B, half *C,
+                                                half *D, int M, int N, int K,
+                                                int A_NUM_MAT, int B_NUM_MAT,
+                                                int CD_NUM_MAT) {
+  const int THREAD_IDX = threadIdx.x + blockIdx.x * blockDim.x;
+  const int WARP_ID = THREAD_IDX / WARP_SIZE;
+  const int LANE_ID = THREAD_IDX % WARP_SIZE;
+
+  const int THREAD_ROW = LANE_ID / 4;
+  const int THREAD_COL = LANE_ID % 4;
+
+  int A_OFFSET = (WARP_ID % A_NUM_MAT) * Shape_M * Shape_K;
+  int B_OFFSET = (WARP_ID % B_NUM_MAT) * Shape_K * Shape_N;
+  int CD_OFFSET = (WARP_ID % CD_NUM_MAT) * Shape_M * Shape_N;
+
+  uint32_t a[4];
+  uint32_t b[2];
+  uint32_t c[2];
+  uint32_t d[2];
+
+  auto ra = reinterpret_cast<half *>(a);
+  auto rb = reinterpret_cast<half *>(b);
+  auto rc = reinterpret_cast<half *>(c);
+  auto rd = reinterpret_cast<half *>(d);
+
+  for (int i = 0; i < 8; i++) {
+    int r_off = 8;
+    if (i < 2 || (i >= 4 && i < 6)) {
+      r_off = 0;
+    }
+
+    int c_off = 0;
+    if (i >= 4) {
+      c_off = 8;
+    }
+
+    int load_offset =
+        A_OFFSET + OFFSET(THREAD_ROW + r_off,
+                          (THREAD_COL * 2) + (i & 0x1) + c_off, Shape_K);
+    if (IN_BOUND_A(load_offset)) {
+      ra[i] = A[load_offset];
+    }
+  }
+
+  for (int i = 0; i < 4; i++) {
+    int r_off = 0;
+    if (i >= 2) {
+      r_off = 8;
+    }
+
+    int load_offset = B_OFFSET + OFFSET((THREAD_COL * 2) + (i & 0x1) + r_off,
+                                        THREAD_ROW, Shape_N);
+    if (IN_BOUND_B(load_offset)) {
+      rb[i] = B[load_offset];
+    }
+  }
+
+  for (int i = 0; i < 4; i++) {
+    int load_offset;
+
+    if (i < 2) {
+      load_offset =
+          CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
+    } else {
+      load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
+                                       (THREAD_COL * 2) + (i & 0x1), Shape_N);
+    }
+
+    if (IN_BOUND_CD(load_offset)) {
+      rc[i] = C[load_offset];
+    }
+  }
+
+  asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
+      " { %0, %1 }, "
+      " { %2, %3, %4, %5 }, "
+      " { %6, %7 }, "
+      " { %8, %9 };"
+      : "+r"(d[0]), "+r"(d[1])
+      : "r"(*(reinterpret_cast<int *>(&a[0]))),
+        "r"(*(reinterpret_cast<int *>(&a[1]))),
+        "r"(*(reinterpret_cast<int *>(&a[2]))),
+        "r"(*(reinterpret_cast<int *>(&a[3]))),
+        "r"(*(reinterpret_cast<int *>(&b[0]))),
+        "r"(*(reinterpret_cast<int *>(&b[1]))),
+        "r"(c[0]), "r"(c[1]));
+
+  for (int i = 0; i < 4; i++) {
+    int load_offset;
+
+    if (i < 2) {
+      load_offset =
+          CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
+    } else {
+      load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
+                                       (THREAD_COL * 2) + (i & 0x1), Shape_N);
+    }
+
+    if (IN_BOUND_CD(load_offset)) {
+      D[load_offset] = rd[i];
+    }
+  }
+}
+
 template <int Shape_M, int Shape_N, int Shape_K>
 __global__ void mma_kernel_m16n8k16_ptx_bf16_f32(__nv_bfloat16 *A, __nv_bfloat16 *B, float *C, float *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
   const int THREAD_IDX = threadIdx.x + blockIdx.x * blockDim.x;
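One detail of the new kernel worth spelling out: the f16 fragments of mma.sync are .f16x2 operands, i.e. each 32-bit register carries two packed halves, which is why a, b, c, and d are declared as uint32_t arrays and filled through reinterpret_cast<half *> views. A standalone sketch of that packing (illustration only, not from the commit; assumes the little-endian layout CUDA platforms use):

#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

int main() {
  // One 32-bit fragment word holding two fp16 lanes:
  // bits 0..15 are lane 0, bits 16..31 are lane 1.
  uint32_t frag = 0;
  half *lanes = reinterpret_cast<half *>(&frag);
  lanes[0] = __float2half(1.5f);
  lanes[1] = __float2half(-2.0f);

  std::printf("packed word: 0x%08x\n", frag);
  std::printf("lane 0: %f, lane 1: %f\n",
              __half2float(lanes[0]), __half2float(lanes[1]));
  return 0;
}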
@@ -979,6 +1160,60 @@ bool run_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
   return correct;
 }

+bool run_test_mma_m16n8k16_f16_f16(const int M, const int N, const int K) {
+  int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
+  calculate_num_matrices<16, 8, 16>(M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+
+  if (A_NUM_MAT == 0 || B_NUM_MAT == 0 || CD_NUM_MAT == 0) {
+    std::cerr << "Matrix dimensions are not compatible with m16n8k16" << std::endl;
+    return false;
+  }
+
+  half *d_A, *d_B;
+  half *d_C, *d_D;
+  half h_A[M * K], h_B[K * N];
+  half h_C[M * N], h_D[M * N], h_D_ref[M * N];
+
+  initialize_matrices(h_A, h_B, h_C, h_D, 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+
+  matrix_multiplication_cpu(h_A, h_B, h_C, h_D_ref, 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+
+  cudaMalloc(&d_A, M * K * sizeof(half));
+  cudaMalloc(&d_B, K * N * sizeof(half));
+  cudaMalloc(&d_C, M * N * sizeof(half));
+  cudaMalloc(&d_D, M * N * sizeof(half));
+
+  cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_C, h_C, M * N * sizeof(half), cudaMemcpyHostToDevice);
+  cudaMemcpy(d_D, h_D, M * N * sizeof(half), cudaMemcpyHostToDevice);
+
+  int no_mat_blocks = 4;
+  int no_blocks = CD_NUM_MAT / no_mat_blocks;
+  int no_threads;
+  if (no_blocks) {
+    no_threads = WARP_SIZE * no_mat_blocks;
+  } else {
+    no_blocks = 1;
+    no_threads = WARP_SIZE * CD_NUM_MAT;
+  }
+
+  mma_kernel_m16n8k16_ptx_f16_f16<16, 8, 16><<<no_blocks, no_threads>>>(d_A, d_B, d_C, d_D, M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+  cudaDeviceSynchronize();
+  cudaMemcpy(h_D, d_D, M * N * sizeof(half), cudaMemcpyDeviceToHost);
+
+  bool correct = check_result(M, N, h_D, h_D_ref);
+
+  std::cout << "m16n8k16 (f16.f16.f16.f16): " << (correct ? "PASSED" : "FAILED") << std::endl;
+
+  cudaFree(d_A);
+  cudaFree(d_B);
+  cudaFree(d_C);
+  cudaFree(d_D);
+
+  return correct;
+}
+
 bool run_test_mma_m16n8k16_bf16_f32(const int M, const int N, const int K) {
   int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
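The launch configuration in run_test_mma_m16n8k16_f16_f16 assigns one warp per m16n8k16 output tile and packs up to four warps per block. A worked example of that arithmetic (assuming calculate_num_matrices yields CD_NUM_MAT = (M / 16) * (N / 8); that helper is defined elsewhere in the file, so the formula is my assumption):

#include <cstdio>

int main() {
  const int WARP_SIZE = 32;

  // Assumed tile count for M = N = K = 32: (32 / 16) * (32 / 8) = 8 C/D tiles.
  int CD_NUM_MAT = 8;

  int no_mat_blocks = 4; // tiles (warps) handled per block
  int no_blocks = CD_NUM_MAT / no_mat_blocks;
  int no_threads;
  if (no_blocks) {
    no_threads = WARP_SIZE * no_mat_blocks; // full blocks of four warps
  } else {
    no_blocks = 1; // fewer tiles than one full block
    no_threads = WARP_SIZE * CD_NUM_MAT;
  }

  std::printf("blocks=%d, threads=%d\n", no_blocks, no_threads); // blocks=2, threads=128
  return 0;
}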
@@ -1181,6 +1416,18 @@ bool launch_test_mma_m16n8k8_f16_f32(const int M, const int N, const int K) {
   return true;
 }

+bool launch_test_mma_m16n8k16_f16_f16(const int M, const int N, const int K) {
+  bool correct = run_test_mma_m16n8k16_f16_f16(M, N, K);
+
+  if (!correct) {
+    std::cerr << "m16n8k16 (f16.f16.f16.f16) failed for dims: " << M << ", " << N
+              << ", " << K << std::endl;
+    return false;
+  }
+
+  return true;
+}
+
 bool launch_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
   bool correct = run_test_mma_m16n8k16_f16_f32(M, N, K);

@@ -1266,6 +1513,16 @@ bool mma_m16n8k16_f16_f32() {
   return true;
 }

+bool mma_m16n8k16_f16_f16() {
+  LAUNCH_TEST(mma_m16n8k16_f16_f16(16, 8, 16));
+  LAUNCH_TEST(mma_m16n8k16_f16_f16(32, 16, 32));
+  LAUNCH_TEST(mma_m16n8k16_f16_f16(16, 16, 16));
+  LAUNCH_TEST(mma_m16n8k16_f16_f16(16, 16, 32));
+  LAUNCH_TEST(mma_m16n8k16_f16_f16(32, 32, 32));
+
+  return true;
+}
+
 bool mma_m16n8k16_bf16_f32() {
   LAUNCH_TEST(mma_m16n8k16_bf16_f32(16, 8, 16));
   LAUNCH_TEST(mma_m16n8k16_bf16_f32(32, 16, 32));
@@ -1304,6 +1561,6 @@ int main() {
   TEST(mma_m16n8k16_bf16_f32);
   TEST(mma_m16n8k16_s8_s32);
   TEST(mma_m16n8k32_s8_s32);
-
+  TEST(mma_m16n8k16_f16_f16);
   return 0;
 }
