CUDA GEMM Implementation

This repository contains a collection of General Matrix Multiplication (GEMM) kernels implemented in CUDA C++. The project explores a range of optimization techniques, from a naive baseline to highly optimized kernels that use NVIDIA's Tensor Cores.

Kernels Implemented

The following GEMM kernels are implemented in this project:

  • Naive GEMM: A basic, unoptimized implementation where each thread computes one element of the output matrix.
  • Tiled GEMM: An optimized version that loads tiles of the input matrices into shared memory, so that each value read from global memory is reused by many threads.
  • CUTE GEMM: A kernel implemented with CuTe, the tensor layout library that ships with NVIDIA CUTLASS.
  • MMA GEMM: A kernel that uses the mma.sync PTX instruction for Tensor Core acceleration.
  • WMMA GEMM: A series of kernels that use the nvcuda::wmma API for Tensor Core acceleration, with progressive optimizations:
    • Naive WMMA
    • Increased work per block, with a 4x2 arrangement of WMMA tiles per block (mma4x2)
    • Increased work per warp, with each warp computing a 2x4 grid of WMMA tiles (warp2x4)
    • Double buffering and asynchronous data transfers (cp.async)
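Below is a minimal sketch of the naive approach. It is not taken from the repository. It assumes row-major float matrices, with A of size M x K, B of size K x N, and C of size M x N. The names sgemm_naive and launch_sgemm_naive are hypothetical.

```cuda
#include <cuda_runtime.h>

// Naive GEMM: one thread computes one element C[row][col] = sum_k A[row][k] * B[k][col].
// Assumes row-major storage: A is M x K, B is K x N, C is M x N.
__global__ void sgemm_naive(const float* A, const float* B, float* C,
                            int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
            // Every iteration reads A and B directly from global memory.
            acc += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = acc;
    }
}

// Host-side launcher: 16x16 threads per block, with enough blocks to cover C.
void launch_sgemm_naive(const float* A, const float* B, float* C,
                        int M, int N, int K) {
    dim3 block(16, 16);
    dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y);
    sgemm_naive<<<grid, block>>>(A, B, C, M, N, K);
}
```

Consecutive `threadIdx.x` values map to consecutive columns, so the reads of B and the writes of C are coalesced. However, every thread re-reads a full row of A and a full column of B from global memory. The tiled kernel is designed to remove that redundancy.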
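The shared-memory tiling idea can be sketched as follows, again assuming row-major float matrices. TILE = 16 and the kernel and launcher names are illustrative choices, not values from this repository. Out-of-range elements are zero-padded, so M, N, and K do not need to be multiples of TILE.

```cuda
#include <cuda_runtime.h>

constexpr int TILE = 16;  // tile edge; each block has TILE x TILE threads

// Shared-memory tiled GEMM: each block computes one TILE x TILE tile of C.
// At each K-step the block cooperatively stages one tile of A and one tile of B
// in shared memory, and each thread then reads them TILE times.
__global__ void sgemm_tiled(const float* A, const float* B, float* C,
                            int M, int N, int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];
    int tx = threadIdx.x, ty = threadIdx.y;
    int row = blockIdx.y * TILE + ty;
    int col = blockIdx.x * TILE + tx;
    float acc = 0.0f;
    for (int k0 = 0; k0 < K; k0 += TILE) {
        // Each thread loads one element of the A tile and one of the B tile,
        // writing 0 for elements that fall outside the matrices.
        As[ty][tx] = (row < M && k0 + tx < K) ? A[row * K + k0 + tx] : 0.0f;
        Bs[ty][tx] = (k0 + ty < K && col < N) ? B[(k0 + ty) * N + col] : 0.0f;
        __syncthreads();  // both tiles are fully loaded before anyone reads them
        for (int k = 0; k < TILE; ++k) acc += As[ty][k] * Bs[k][tx];
        __syncthreads();  // everyone is done reading before the next overwrite
    }
    if (row < M && col < N) C[row * N + col] = acc;
}

// Host-side launcher covering C with TILE x TILE blocks.
void launch_sgemm_tiled(const float* A, const float* B, float* C,
                        int M, int N, int K) {
    dim3 block(TILE, TILE);
    dim3 grid((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
    sgemm_tiled<<<grid, block>>>(A, B, C, M, N, K);
}
```

With TILE = 16, each element staged in shared memory is read by 16 threads. Compared with the naive kernel, this cuts global-memory loads of A and B by roughly a factor of 16. Threads outside C still take part in the loads and barriers, which is why the bounds check is applied only at the final store.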
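In the MMA GEMM approach, each warp places its operands in registers using the fragment layout defined by the PTX ISA and then issues mma.sync through inline assembly. The sketch below is one minimal variant and makes several assumptions, which may not match what this repository uses:

- It uses the m16n8k16 shape with fp16 inputs and fp32 accumulation, which requires compute capability 8.0.
- Matrices are row-major, with M, N, and K multiples of 16, 8, and 16.
- One warp computes each 16x8 output tile.
- All names are hypothetical.

For clarity, operands are read element by element from global memory. A real kernel would typically stage them through shared memory first.

```cuda
#include <cuda_fp16.h>
#include <cstdint>
#include <cstring>

// Pack two halves into one 32-bit register. The lower-index element goes in
// the low 16 bits, matching how mma.sync reads .f16x2 operands.
__device__ inline uint32_t pack_half2(half lo, half hi) {
    __half2 h = __halves2half2(lo, hi);
    uint32_t r;
    memcpy(&r, &h, sizeof(r));
    return r;
}

// One warp (one block of 32 threads) computes one 16x8 tile of C.
// Assumes row-major A (M x K, half), B (K x N, half), C (M x N, float).
__global__ void hgemm_mma_m16n8k16(const half* A, const half* B, float* C,
                                   int M, int N, int K) {
    int lane = threadIdx.x;
    int g = lane >> 2;  // "groupID" in the PTX fragment layout tables
    int t = lane & 3;   // "threadID_in_group"
    int row0 = blockIdx.y * 16, col0 = blockIdx.x * 8;
    float d[4] = {0.f, 0.f, 0.f, 0.f};

    for (int k0 = 0; k0 < K; k0 += 16) {
        const half* a = A + row0 * K + k0;  // 16x16 tile of A, row stride K
        const half* b = B + k0 * N + col0;  // 16x8 tile of B, row stride N
        // A fragment: rows g and g+8; columns 2t, 2t+1 and 2t+8, 2t+9.
        uint32_t a0 = pack_half2(a[g * K + 2 * t],           a[g * K + 2 * t + 1]);
        uint32_t a1 = pack_half2(a[(g + 8) * K + 2 * t],     a[(g + 8) * K + 2 * t + 1]);
        uint32_t a2 = pack_half2(a[g * K + 2 * t + 8],       a[g * K + 2 * t + 9]);
        uint32_t a3 = pack_half2(a[(g + 8) * K + 2 * t + 8], a[(g + 8) * K + 2 * t + 9]);
        // B fragment: column g; k indices 2t, 2t+1 and 2t+8, 2t+9.
        uint32_t b0 = pack_half2(b[(2 * t) * N + g],     b[(2 * t + 1) * N + g]);
        uint32_t b1 = pack_half2(b[(2 * t + 8) * N + g], b[(2 * t + 9) * N + g]);
        // D = A * B + D, accumulated in place in d[0..3].
        asm volatile(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
            : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
            : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1));
    }
    // C fragment: d[0], d[1] go to row g and d[2], d[3] to row g+8, in columns 2t, 2t+1.
    C[(row0 + g) * N + col0 + 2 * t]         = d[0];
    C[(row0 + g) * N + col0 + 2 * t + 1]     = d[1];
    C[(row0 + g + 8) * N + col0 + 2 * t]     = d[2];
    C[(row0 + g + 8) * N + col0 + 2 * t + 1] = d[3];
}

void launch_hgemm_mma(const half* A, const half* B, float* C, int M, int N, int K) {
    dim3 grid(N / 8, M / 16);
    hgemm_mma_m16n8k16<<<grid, 32>>>(A, B, C, M, N, K);
}
```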
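The naive WMMA variant maps one warp to one 16x16 output tile and lets the nvcuda::wmma fragment API handle the Tensor Core data layout. This is a minimal sketch with the following assumptions:

- A and B are half-precision and row-major.
- The output C is float. The repository's hgemm kernels may store half instead.
- M, N, and K are multiples of 16.
- The names are hypothetical.

```cuda
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

constexpr int WM = 16, WN = 16, WK = 16;  // WMMA tile shape (m16n16k16)

// Naive WMMA GEMM: one warp per block, and block (x, y) computes tile (y, x) of C.
// Assumes row-major half A (M x K), row-major half B (K x N), row-major float C.
__global__ void hgemm_wmma_naive(const half* A, const half* B, float* C,
                                 int M, int N, int K) {
    int tileRow = blockIdx.y * WM;
    int tileCol = blockIdx.x * WN;

    wmma::fragment<wmma::matrix_a, WM, WN, WK, half, wmma::row_major> a;
    wmma::fragment<wmma::matrix_b, WM, WN, WK, half, wmma::row_major> b;
    wmma::fragment<wmma::accumulator, WM, WN, WK, float> acc;
    wmma::fill_fragment(acc, 0.0f);

    for (int k0 = 0; k0 < K; k0 += WK) {
        // The leading dimension is the row stride of each row-major matrix.
        wmma::load_matrix_sync(a, A + tileRow * K + k0, K);
        wmma::load_matrix_sync(b, B + k0 * N + tileCol, N);
        wmma::mma_sync(acc, a, b, acc);  // acc += a * b on Tensor Cores
    }
    wmma::store_matrix_sync(C + tileRow * N + tileCol, acc, N, wmma::mem_row_major);
}

void launch_hgemm_wmma_naive(const half* A, const half* B, float* C,
                             int M, int N, int K) {
    dim3 grid(N / WN, M / WM);
    hgemm_wmma_naive<<<grid, 32>>>(A, B, C, M, N, K);
}
```

Each warp here loads fresh A and B fragments for every 16x16 output tile. Giving a block or warp several output tiles, as the mma4x2 and warp2x4 variants do, lets each loaded fragment contribute to more output.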
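The double-buffering pattern can be shown on the simpler float tiled kernel rather than on the repository's WMMA kernel, so the sketch below is an assumption, not the actual implementation. It uses the CUDA pipeline primitives from `<cuda_pipeline.h>`, which are backed by cp.async on compute capability 8.0+. While the block computes on tile t in one shared-memory buffer, the copy of tile t+1 into the other buffer is already in flight. The sketch assumes row-major float matrices with M, N, and K positive multiples of TILE, and the names are hypothetical.

```cuda
#include <cuda_runtime.h>
#include <cuda_pipeline.h>

constexpr int TILE = 16;

// Double-buffered tiled GEMM with asynchronous global-to-shared copies.
// Buffer t % 2 holds tile t while tile t + 1 is fetched into the other buffer.
__global__ void sgemm_tiled_dbuf_async(const float* A, const float* B, float* C,
                                       int M, int N, int K) {
    __shared__ float As[2][TILE][TILE];
    __shared__ float Bs[2][TILE][TILE];
    int tx = threadIdx.x, ty = threadIdx.y;
    int row = blockIdx.y * TILE + ty;
    int col = blockIdx.x * TILE + tx;
    int numTiles = K / TILE;

    // Prologue: start fetching tile 0 into buffer 0.
    __pipeline_memcpy_async(&As[0][ty][tx], &A[row * K + tx], sizeof(float));
    __pipeline_memcpy_async(&Bs[0][ty][tx], &B[ty * N + col], sizeof(float));
    __pipeline_commit();

    float acc = 0.0f;
    for (int t = 0; t < numTiles; ++t) {
        int cur = t & 1, nxt = cur ^ 1;
        if (t + 1 < numTiles) {  // prefetch tile t + 1 into the other buffer
            int k0 = (t + 1) * TILE;
            __pipeline_memcpy_async(&As[nxt][ty][tx], &A[row * K + k0 + tx], sizeof(float));
            __pipeline_memcpy_async(&Bs[nxt][ty][tx], &B[(k0 + ty) * N + col], sizeof(float));
        }
        // Commit every iteration (possibly an empty batch), so that waiting for
        // "all but the newest batch" always means "tile t has arrived".
        __pipeline_commit();
        __pipeline_wait_prior(1);
        __syncthreads();  // make every thread's copies of tile t visible to the block

        for (int k = 0; k < TILE; ++k) acc += As[cur][ty][k] * Bs[cur][k][tx];
        __syncthreads();  // buffer cur is refilled (with tile t + 2) at iteration t + 1
    }
    C[row * N + col] = acc;
}

void launch_sgemm_tiled_dbuf_async(const float* A, const float* B, float* C,
                                   int M, int N, int K) {
    dim3 block(TILE, TILE);
    dim3 grid(N / TILE, M / TILE);
    sgemm_tiled_dbuf_async<<<grid, block>>>(A, B, C, M, N, K);
}
```

The trailing `__syncthreads()` makes the buffer reuse safe. Without it, a fast thread could start overwriting buffer `cur` with tile t + 2 while slower threads are still reading tile t from it.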

For a detailed explanation of each kernel, please refer to the docs.md file.

Performance Benchmark

[Figure: Benchmark Plot]

The implemented kernels were benchmarked against NVIDIA's cuBLAS library on an NVIDIA GPU with Tensor Cores.

The plot above shows the TFLOPS achieved by each kernel for different matrix sizes.

The hgemm_wmma_mnk16_m4x2_dbuf_async kernel comes close to cuBLAS performance, suggesting that the optimization techniques described above are effective.

Building and Running

Prerequisites

  • NVIDIA GPU with Compute Capability 8.0 or higher (Ampere or newer). Tensor Cores and the WMMA API are available from Compute Capability 7.0, but cp.async requires 8.0.
  • CUDA Toolkit
  • CMake (version 3.18 or higher)
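Before running the benchmark, you can check whether a GPU meets the compute-capability requirement by querying the CUDA runtime's `cudaGetDeviceProperties`. This is a minimal sketch. The function name is hypothetical, and the 8.0 threshold comes from the prerequisites listed here.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Returns true if the device reports compute capability 8.0 or higher.
// Also returns false if the device index is invalid or the query fails.
bool device_supports_sm80(int device) {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) return false;
    std::printf("%s: compute capability %d.%d\n", prop.name, prop.major, prop.minor);
    return prop.major >= 8;
}
```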

Running the Benchmark

After building the project, you can run the benchmark from the build directory:

./benchmark

This will run all the implemented kernels and print the performance results to the console.