@@ -96,6 +96,42 @@ __host__ void initialize_matrices(int8_t *A, int8_t *B, int *C, int *D, int M, i
}


+ __host__ void initialize_matrices(half *A, half *B, half *C, half *D, int M,
+                                   int N, int K, int A_NUM_MAT, int B_NUM_MAT,
+                                   int CD_NUM_MAT) {
+   for (int N_MAT = 0; N_MAT < A_NUM_MAT; N_MAT++) {
+     int A_OFFSET = N_MAT * M * K;
+
+     for (int i = 0; i < M; ++i) {
+       for (int j = 0; j < K; ++j) {
+         A[A_OFFSET + i * K + j] = __float2half((i * K + j) / 4.0f);
+       }
+     }
+   }
+
+   for (int N_MAT = 0; N_MAT < B_NUM_MAT; N_MAT++) {
+     int B_OFFSET = N_MAT * K * N;
+
+     for (int i = 0; i < K; ++i) {
+       for (int j = 0; j < N; ++j) {
+         B[B_OFFSET + i * N + j] = __float2half((i * N + j) / 4.0f);
+       }
+     }
+   }
+
+   for (int N_MAT = 0; N_MAT < CD_NUM_MAT; N_MAT++) {
+     int CD_OFFSET = N_MAT * M * N;
+
+     for (int i = 0; i < M; ++i) {
+       for (int j = 0; j < N; ++j) {
+         C[CD_OFFSET + i * N + j] = __float2half((i * N + j) * 1.0f);
+         D[CD_OFFSET + i * N + j] = 0.0f;
+       }
+     }
+   }
+ }
+
+
template <typename ABType, typename CDType>
void matrix_multiplication_cpu(ABType *A, ABType *B, CDType *C, CDType *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
  for (int N_MAT = 0; N_MAT < CD_NUM_MAT; N_MAT++) {
@@ -114,7 +150,44 @@ void matrix_multiplication_cpu(ABType *A, ABType *B, CDType *C, CDType *D, int M
    }
  }
}
+ /*
+ 1. Properties of half (float16)
+ Limited precision: float16 has only 5 exponent bits and 10 mantissa bits,
+ giving roughly 3 decimal digits of precision.
+
+ Non-uniform precision: values near zero are represented more precisely; the
+ spacing between representable values grows with magnitude, so precision is
+ lowest near the maximum value.
+
+ 2. Differences in half support between CPU and GPU
+ * On the GPU (e.g., Tensor Cores or native FP16 ALUs):
+ Many GPUs, especially NVIDIA ones, support native half x half operations,
+ either:
+
+ half x half -> half or half x half -> float
+
+ With Tensor Cores or PTX instructions like mma.sync, the pattern is often:
+
+ half x half -> float (accumulate) -> rounded to half at the end
+
+ So the GPU may keep intermediate results in float, which can differ from the
+ float round-trip performed on the CPU.
+
+ * On the CPU:
+ CPUs usually don't have native half-precision units, so the operation goes
+ like:
+
+ Convert half to float -> multiply as float -> convert the result back to half
+
+ This involves two conversion steps:
+
+ From half to float (lossless)
+
+ From float to half (which can introduce significant rounding error, or overflow)
+
+ For example, multiplying 60000 x 2 as float gives 120000, but float16 cannot
+ represent 120000 (its maximum is 65504), so the result becomes inf.
+ */
+ // Note: CPU and GPU results can therefore differ, especially for large values.
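The overflow described in the comment above is easy to reproduce on the host. The
standalone sketch below (not part of this patch, compiled as a .cu file with nvcc)
uses __float2half/__half2float from cuda_fp16.h to follow the CPU-style path:
convert to float, multiply in float, then round back to half.

    #include <cstdio>
    #include <cuda_fp16.h>

    int main() {
      // CPU-style half multiply: half -> float, multiply in float, round back to half.
      half a = __float2half(60000.0f);
      half b = __float2half(2.0f);
      float product = __half2float(a) * __half2float(b);  // 120000.0f, exact in float
      half rounded = __float2half(product);               // 120000 > 65504, rounds to +inf
      printf("float product = %f, as half = %f\n", product, __half2float(rounded));
      return 0;
    }

The printed half value comes out as inf because 120000 exceeds the float16 maximum
of 65504, which is why the CPU reference and the GPU result can disagree for large
inputs.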
void matrix_multiplication_cpu(half *A, half *B, half *C, half *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
  for (int N_MAT = 0; N_MAT < CD_NUM_MAT; N_MAT++) {
    int A_OFFSET = (N_MAT % A_NUM_MAT) * (M * K);
@@ -480,6 +553,114 @@ __global__ void mma_kernel_m16n8k16_ptx_f16_f32(half *A, half *B, float *C, floa
}


+ template <int Shape_M, int Shape_N, int Shape_K>
+ __global__ void mma_kernel_m16n8k16_ptx_f16_f16(half *A, half *B, half *C,
+                                                 half *D, int M, int N, int K,
+                                                 int A_NUM_MAT, int B_NUM_MAT,
+                                                 int CD_NUM_MAT) {
+   const int THREAD_IDX = threadIdx.x + blockIdx.x * blockDim.x;
+   const int WARP_ID = THREAD_IDX / WARP_SIZE;
+   const int LANE_ID = THREAD_IDX % WARP_SIZE;
+
+   const int THREAD_ROW = LANE_ID / 4;
+   const int THREAD_COL = LANE_ID % 4;
+
+   int A_OFFSET = (WARP_ID % A_NUM_MAT) * Shape_M * Shape_K;
+   int B_OFFSET = (WARP_ID % B_NUM_MAT) * Shape_K * Shape_N;
+   int CD_OFFSET = (WARP_ID % CD_NUM_MAT) * Shape_M * Shape_N;
+
+   // Per-lane register fragments: 8 halves of A, 4 of B, 4 each of C and D.
+   uint32_t a[4];
+   uint32_t b[2];
+   uint32_t c[2];
+   uint32_t d[2];
+
+   auto ra = reinterpret_cast<half *>(a);
+   auto rb = reinterpret_cast<half *>(b);
+   auto rc = reinterpret_cast<half *>(c);
+   auto rd = reinterpret_cast<half *>(d);
+
+   // Load this lane's share of the 16x16 A tile into registers.
+   for (int i = 0; i < 8; i++) {
+     int r_off = 8;
+     if (i < 2 || (i >= 4 && i < 6)) {
+       r_off = 0;
+     }
+
+     int c_off = 0;
+     if (i >= 4) {
+       c_off = 8;
+     }
+
+     int load_offset =
+         A_OFFSET + OFFSET(THREAD_ROW + r_off,
+                           (THREAD_COL * 2) + (i & 0x1) + c_off, Shape_K);
+     if (IN_BOUND_A(load_offset)) {
+       ra[i] = A[load_offset];
+     }
+   }
+
+   // Load this lane's share of the 16x8 B tile into registers.
+   for (int i = 0; i < 4; i++) {
+     int r_off = 0;
+     if (i >= 2) {
+       r_off = 8;
+     }
+
+     int load_offset = B_OFFSET + OFFSET((THREAD_COL * 2) + (i & 0x1) + r_off,
+                                         THREAD_ROW, Shape_N);
+     if (IN_BOUND_B(load_offset)) {
+       rb[i] = B[load_offset];
+     }
+   }
+
+   // Load this lane's share of the 16x8 C accumulator tile.
+   for (int i = 0; i < 4; i++) {
+     int load_offset;
+
+     if (i < 2) {
+       load_offset =
+           CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
+     } else {
+       load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
+                                        (THREAD_COL * 2) + (i & 0x1), Shape_N);
+     }
+
+     if (IN_BOUND_CD(load_offset)) {
+       rc[i] = C[load_offset];
+     }
+   }
+
+   asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
+       "{ %0, %1 }, "
+       "{ %2, %3, %4, %5 }, "
+       "{ %6, %7 }, "
+       "{ %8, %9 };"
+       : "+r"(d[0]), "+r"(d[1])
+       : "r"(*(reinterpret_cast<int *>(&a[0]))),
+         "r"(*(reinterpret_cast<int *>(&a[1]))),
+         "r"(*(reinterpret_cast<int *>(&a[2]))),
+         "r"(*(reinterpret_cast<int *>(&a[3]))),
+         "r"(*(reinterpret_cast<int *>(&b[0]))),
+         "r"(*(reinterpret_cast<int *>(&b[1]))),
+         "r"(c[0]), "r"(c[1]));
+
+   // Store the D fragment back using the same per-lane layout as C.
+   for (int i = 0; i < 4; i++) {
+     int load_offset;
+
+     if (i < 2) {
+       load_offset =
+           CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
+     } else {
+       load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
+                                        (THREAD_COL * 2) + (i & 0x1), Shape_N);
+     }
+
+     if (IN_BOUND_CD(load_offset)) {
+       D[load_offset] = rd[i];
+     }
+   }
+ }
+
+
template <int Shape_M, int Shape_N, int Shape_K>
__global__ void mma_kernel_m16n8k16_ptx_bf16_f32(__nv_bfloat16 *A, __nv_bfloat16 *B, float *C, float *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
  const int THREAD_IDX = threadIdx.x + blockIdx.x * blockDim.x;
@@ -979,6 +1160,60 @@ bool run_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
  return correct;
}

+ bool run_test_mma_m16n8k16_f16_f16(const int M, const int N, const int K) {
+   int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
+   calculate_num_matrices<16, 8, 16>(M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+
+   if (A_NUM_MAT == 0 || B_NUM_MAT == 0 || CD_NUM_MAT == 0) {
+     std::cerr << "Matrix dimensions are not compatible with m16n8k16" << std::endl;
+     return false;
+   }
+
+   half *d_A, *d_B;
+   half *d_C, *d_D;
+   half h_A[M * K], h_B[K * N];
+   half h_C[M * N], h_D[M * N], h_D_ref[M * N];
+
+   initialize_matrices(h_A, h_B, h_C, h_D, 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+
+   matrix_multiplication_cpu(h_A, h_B, h_C, h_D_ref, 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+
+   cudaMalloc(&d_A, M * K * sizeof(half));
+   cudaMalloc(&d_B, K * N * sizeof(half));
+   cudaMalloc(&d_C, M * N * sizeof(half));
+   cudaMalloc(&d_D, M * N * sizeof(half));
+
+   cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice);
+   cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice);
+   cudaMemcpy(d_C, h_C, M * N * sizeof(half), cudaMemcpyHostToDevice);
+   cudaMemcpy(d_D, h_D, M * N * sizeof(half), cudaMemcpyHostToDevice);
+
+   int no_mat_blocks = 4;
+   int no_blocks = CD_NUM_MAT / no_mat_blocks;
+   int no_threads;
+   if (no_blocks) {
+     no_threads = WARP_SIZE * no_mat_blocks;
+   } else {
+     no_blocks = 1;
+     no_threads = WARP_SIZE * CD_NUM_MAT;
+   }
+
+   mma_kernel_m16n8k16_ptx_f16_f16<16, 8, 16><<<no_blocks, no_threads>>>(d_A, d_B, d_C, d_D, M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
+   cudaDeviceSynchronize();
+   cudaMemcpy(h_D, d_D, M * N * sizeof(half), cudaMemcpyDeviceToHost);
+
+   bool correct = check_result(M, N, h_D, h_D_ref);
+
+   std::cout << "m16n8k16 (f16.f16.f16.f16): " << (correct ? "PASSED" : "FAILED") << std::endl;
+
+   cudaFree(d_A);
+   cudaFree(d_B);
+   cudaFree(d_C);
+   cudaFree(d_D);
+
+   return correct;
+ }
+

bool run_test_mma_m16n8k16_bf16_f32(const int M, const int N, const int K) {
  int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
@@ -1181,6 +1416,18 @@ bool launch_test_mma_m16n8k8_f16_f32(const int M, const int N, const int K) {
  return true;
}

+ bool launch_test_mma_m16n8k16_f16_f16(const int M, const int N, const int K) {
+   bool correct = run_test_mma_m16n8k16_f16_f16(M, N, K);
+
+   if (!correct) {
+     std::cerr << "m16n8k16 (f16.f16.f16.f16) failed for dims: " << M << ", " << N
+               << ", " << K << std::endl;
+     return false;
+   }
+
+   return true;
+ }
+

bool launch_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
  bool correct = run_test_mma_m16n8k16_f16_f32(M, N, K);

@@ -1266,6 +1513,16 @@ bool mma_m16n8k16_f16_f32() {
  return true;
}

+ bool mma_m16n8k16_f16_f16() {
+   LAUNCH_TEST(mma_m16n8k16_f16_f16(16, 8, 16));
+   LAUNCH_TEST(mma_m16n8k16_f16_f16(32, 16, 32));
+   LAUNCH_TEST(mma_m16n8k16_f16_f16(16, 16, 16));
+   LAUNCH_TEST(mma_m16n8k16_f16_f16(16, 16, 32));
+   LAUNCH_TEST(mma_m16n8k16_f16_f16(32, 32, 32));
+
+   return true;
+ }
+

bool mma_m16n8k16_bf16_f32() {
  LAUNCH_TEST(mma_m16n8k16_bf16_f32(16, 8, 16));
  LAUNCH_TEST(mma_m16n8k16_bf16_f32(32, 16, 32));
@@ -1304,6 +1561,6 @@ int main() {
  TEST(mma_m16n8k16_bf16_f32);
  TEST(mma_m16n8k16_s8_s32);
  TEST(mma_m16n8k32_s8_s32);
-
+   TEST(mma_m16n8k16_f16_f16);
  return 0;
}