259 changes: 258 additions & 1 deletion features/feature_case/asm/asm_mma.cu
@@ -96,6 +96,42 @@ __host__ void initialize_matrices(int8_t *A, int8_t *B, int *C, int *D, int M, i
}


__host__ void initialize_matrices(half *A, half *B, half *C, half *D, int M,
int N, int K, int A_NUM_MAT, int B_NUM_MAT,
int CD_NUM_MAT) {
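// Fill A and B with a small ramp scaled by 1/4; presumably this keeps the
// values and products small enough to stay representable in half precision
// (max finite value 65504).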
for (int N_MAT = 0; N_MAT < A_NUM_MAT; N_MAT++) {
int A_OFFSET = N_MAT * M * K;

for (int i = 0; i < M; ++i) {
for (int j = 0; j < K; ++j) {
A[A_OFFSET + i * K + j] = __float2half(((i * K + j) / 4.0f));
}
}
}

for (int N_MAT = 0; N_MAT < B_NUM_MAT; N_MAT++) {
int B_OFFSET = N_MAT * K * N;

for (int i = 0; i < K; ++i) {
for (int j = 0; j < N; ++j) {
B[B_OFFSET + i * N + j] = __float2half(((i * N + j) / 4.0f));
}
}
}

for (int N_MAT = 0; N_MAT < CD_NUM_MAT; N_MAT++) {
int CD_OFFSET = N_MAT * M * N;

for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
C[CD_OFFSET + i * N + j] = __float2half((i * N + j) * 1.0f);
D[CD_OFFSET + i * N + j] = __float2half(0.0f);
}
}
}
}


template <typename ABType, typename CDType>
void matrix_multiplication_cpu(ABType *A, ABType *B, CDType *C, CDType *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
for (int N_MAT = 0; N_MAT < CD_NUM_MAT; N_MAT++) {
@@ -114,7 +150,44 @@ void matrix_multiplication_cpu(ABType *A, ABType *B, CDType *C, CDType *D, int M
}
}
}
/*
1. Properties of half (float16)
Limited precision: float16 has only 5 exponent bits and 10 mantissa bits,
giving roughly 3 decimal digits of precision.

Non-uniform precision: values near zero are represented more precisely;
values near the maximum finite value (65504) are represented more coarsely.

2. Differences between CPU and GPU half support
* On the GPU (e.g., Tensor Cores or native FP16 ALUs):
Many GPUs, especially NVIDIA ones, natively support half x half operations,
either:

half x half -> half or half x half -> float

With Tensor Cores or PTX instructions like mma.sync, the typical flow is:

half x half -> float (accumulate) -> rounded to half at the end

So the GPU may keep intermediate results in float precision, which can differ
from a float round trip on the CPU.

* On the CPU:
CPUs usually lack native half-precision units, so the operation goes like:

Convert half to float -> multiply as float -> convert the result back to half

This involves two conversion steps:

From half to float (lossless)

From float to half (can introduce significant rounding error, or overflow)

For example, multiplying 60000 x 2 as float gives 120000, but float16 cannot
represent 120000, so the result becomes inf.
*/
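
// A minimal, hypothetical sketch (not called by any test) of the double
// rounding described above: 60000 * 2 is exact in float, but rounding 120000
// back to half overflows, since half's largest finite value is 65504.
__host__ void demo_half_double_rounding() {
  half a = __float2half(60000.0f);
  half b = __float2half(2.0f);
  float product = __half2float(a) * __half2float(b); // 120000.0f, exact in float
  half result = __float2half(product);               // 120000 > 65504 -> +inf
  std::cout << __half2float(result) << std::endl;    // prints "inf"
}
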
// Note: CPU and GPU results can differ, especially for large values.
void matrix_multiplication_cpu(half *A, half *B, half *C, half *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
for (int N_MAT = 0; N_MAT < CD_NUM_MAT; N_MAT++) {
int A_OFFSET = (N_MAT % A_NUM_MAT) * (M * K);
@@ -480,6 +553,114 @@ __global__ void mma_kernel_m16n8k16_ptx_f16_f32(half *A, half *B, float *C, floa
}


template <int Shape_M, int Shape_N, int Shape_K>
__global__ void mma_kernel_m16n8k16_ptx_f16_f16(half *A, half *B, half *C,
half *D, int M, int N, int K,
int A_NUM_MAT, int B_NUM_MAT,
int CD_NUM_MAT) {
const int THREAD_IDX = threadIdx.x + blockIdx.x * blockDim.x;
const int WARP_ID = THREAD_IDX / WARP_SIZE;
const int LANE_ID = THREAD_IDX % WARP_SIZE;

const int THREAD_ROW = LANE_ID / 4;
const int THREAD_COL = LANE_ID % 4;
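
// Per the PTX layout for mma.m16n8k16 with f16 operands, each lane owns
// 8 halves of A (four .b32 registers), 4 halves of B (two registers), and
// 4 halves of C/D (two registers). Lanes work in groups of four: THREAD_ROW
// picks the row group and THREAD_COL picks the pair of columns a lane
// loads and stores.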

int A_OFFSET = (WARP_ID % A_NUM_MAT) * Shape_M * Shape_K;
int B_OFFSET = (WARP_ID % B_NUM_MAT) * Shape_K * Shape_N;
int CD_OFFSET = (WARP_ID % CD_NUM_MAT) * Shape_M * Shape_N;

uint32_t a[4];
uint32_t b[2];
uint32_t c[2];
uint32_t d[2];

auto ra = reinterpret_cast<half *>(a);
auto rb = reinterpret_cast<half *>(b);
auto rc = reinterpret_cast<half *>(c);
auto rd = reinterpret_cast<half *>(d);

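// Gather this lane's eight A elements: each pair (i, i+1) covers columns
// THREAD_COL*2 and THREAD_COL*2+1; r_off toggles between the upper and lower
// 8-row halves of the tile, and c_off selects the low or high 8 columns of K.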
for (int i = 0; i < 8; i++) {
int r_off = 8;
if (i < 2 || (i >= 4 && i < 6)) {
r_off = 0;
}

int c_off = 0;
if (i >= 4) {
c_off = 8;
}

int load_offset =
A_OFFSET + OFFSET(THREAD_ROW + r_off,
(THREAD_COL * 2) + (i & 0x1) + c_off, Shape_K);
if (IN_BOUND_A(load_offset)) {
ra[i] = A[load_offset];
}
}

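// Gather this lane's four B elements: rows THREAD_COL*2 + (i & 1), plus 8
// for the second pair, of column THREAD_ROW in the row-major K x N array.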
for (int i = 0; i < 4; i++) {
int r_off = 0;
if (i >= 2) {
r_off = 8;
}

int load_offset = B_OFFSET + OFFSET((THREAD_COL * 2) + (i & 0x1) + r_off,
THREAD_ROW, Shape_N);
if (IN_BOUND_B(load_offset)) {
rb[i] = B[load_offset];
}
}

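// Load the four C accumulator elements: c[0..1] from row THREAD_ROW and
// c[2..3] from row THREAD_ROW + 8, at columns THREAD_COL*2 and THREAD_COL*2+1.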
for (int i = 0; i < 4; i++) {
int load_offset;

if (i < 2) {
load_offset =
CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
} else {
load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
(THREAD_COL * 2) + (i & 0x1), Shape_N);
}

if (IN_BOUND_CD(load_offset)) {
rc[i] = C[load_offset];
}
}


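// D = A * B + C, with multiplicands, accumulator, and result all in f16;
// each 32-bit register packs two halves.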
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
" { %0, %1 }, "
" { %2, %3, %4, %5 }, "
" { %6, %7 }, "
" { %8, %9 };"
: "=r"(d[0]), "=r"(d[1])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
"r"(b[0]), "r"(b[1]),
"r"(c[0]), "r"(c[1]));

for (int i = 0; i < 4; i++) {
int load_offset;

if (i < 2) {
load_offset =
CD_OFFSET + OFFSET(THREAD_ROW, (THREAD_COL * 2) + (i & 0x1), Shape_N);
} else {
load_offset = CD_OFFSET + OFFSET(THREAD_ROW + 8,
(THREAD_COL * 2) + (i & 0x1), Shape_N);
}

if (IN_BOUND_CD(load_offset)) {
D[load_offset] = rd[i];
}
}
}


template <int Shape_M, int Shape_N, int Shape_K>
__global__ void mma_kernel_m16n8k16_ptx_bf16_f32(__nv_bfloat16 *A, __nv_bfloat16 *B, float *C, float *D, int M, int N, int K, int A_NUM_MAT, int B_NUM_MAT, int CD_NUM_MAT) {
const int THREAD_IDX = threadIdx.x + blockIdx.x * blockDim.x;
@@ -979,6 +1160,60 @@ bool run_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
return correct;
}

bool run_test_mma_m16n8k16_f16_f16(const int M, const int N, const int K) {
int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
calculate_num_matrices<16, 8, 16>(M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);

if (A_NUM_MAT == 0 || B_NUM_MAT == 0 || CD_NUM_MAT == 0) {
std::cerr << "Matrix dimensions are not compatible with m16n8k16" << std::endl;
return false;
}

half *d_A, *d_B;
half *d_C, *d_D;
half h_A[M * K], h_B[K * N];
half h_C[M * N], h_D[M * N], h_D_ref[M * N];

initialize_matrices(h_A, h_B, h_C, h_D, 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);

matrix_multiplication_cpu(h_A, h_B, h_C, h_D_ref, 16, 8, 16, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);

cudaMalloc(&d_A, M * K * sizeof(half));
cudaMalloc(&d_B, K * N * sizeof(half));
cudaMalloc(&d_C, M * N * sizeof(half));
cudaMalloc(&d_D, M * N * sizeof(half));

cudaMemcpy(d_A, h_A, M * K * sizeof(half), cudaMemcpyHostToDevice);
cudaMemcpy(d_B, h_B, K * N * sizeof(half), cudaMemcpyHostToDevice);
cudaMemcpy(d_C, h_C, M * N * sizeof(half), cudaMemcpyHostToDevice);
cudaMemcpy(d_D, h_D, M * N * sizeof(half), cudaMemcpyHostToDevice);

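// One warp computes one m16n8k16 tile; pack up to four warps (tiles) per
// block, falling back to a single block when there are fewer tiles.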
int no_mat_blocks = 4;
int no_blocks = CD_NUM_MAT / no_mat_blocks;
int no_threads;
if (no_blocks) {
no_threads = WARP_SIZE * no_mat_blocks;
} else {
no_blocks = 1;
no_threads = WARP_SIZE * CD_NUM_MAT;
}

mma_kernel_m16n8k16_ptx_f16_f16<16, 8, 16><<<no_blocks, no_threads>>>(d_A, d_B, d_C, d_D, M, N, K, A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT);
cudaDeviceSynchronize();
cudaMemcpy(h_D, d_D, M * N * sizeof(half), cudaMemcpyDeviceToHost);

bool correct = check_result(M, N, h_D, h_D_ref);

std::cout << "m16n8k16 (f16.f16.f16.f16): " << (correct ? "PASSED" : "FAILED") << std::endl;

cudaFree(d_A);
cudaFree(d_B);
cudaFree(d_C);
cudaFree(d_D);

return correct;
}


bool run_test_mma_m16n8k16_bf16_f32(const int M, const int N, const int K) {
int A_NUM_MAT, B_NUM_MAT, CD_NUM_MAT;
@@ -1181,6 +1416,18 @@ bool launch_test_mma_m16n8k8_f16_f32(const int M, const int N, const int K) {
return true;
}

bool launch_test_mma_m16n8k16_f16_f16(const int M, const int N, const int K) {
bool correct = run_test_mma_m16n8k16_f16_f16(M, N, K);

if (!correct) {
std::cerr << "m16n8k16 (f16.f16.f16.f16) failed for dims: " << M << ", " << N
<< ", " << K << std::endl;
return false;
}

return true;
}

bool launch_test_mma_m16n8k16_f16_f32(const int M, const int N, const int K) {
bool correct = run_test_mma_m16n8k16_f16_f32(M, N, K);

@@ -1266,6 +1513,16 @@ bool mma_m16n8k16_f16_f32() {
return true;
}

bool mma_m16n8k16_f16_f16() {
LAUNCH_TEST(mma_m16n8k16_f16_f16(16, 8, 16));
LAUNCH_TEST(mma_m16n8k16_f16_f16(32, 16, 32));
LAUNCH_TEST(mma_m16n8k16_f16_f16(16, 16, 16));
LAUNCH_TEST(mma_m16n8k16_f16_f16(16, 16, 32));
LAUNCH_TEST(mma_m16n8k16_f16_f16(32, 32, 32));

return true;
}

bool mma_m16n8k16_bf16_f32() {
LAUNCH_TEST(mma_m16n8k16_bf16_f32(16, 8, 16));
LAUNCH_TEST(mma_m16n8k16_bf16_f32(32, 16, 32));
@@ -1304,6 +1561,6 @@ int main() {
TEST(mma_m16n8k16_bf16_f32);
TEST(mma_m16n8k16_s8_s32);
TEST(mma_m16n8k32_s8_s32);

TEST(mma_m16n8k16_f16_f16);
return 0;
}