WingEdge777/vitamin-cuda


Vitamin-CUDA

A hands-on CUDA dev learning path from novice to expert.

One Kernel a Day, Keeps High Latency Away. 🚀

Welcome to your daily dose of CUDA programming! Vitamin-CUDA is a curated collection of hands-on CUDA practices, designed to take you from Hello World to High Performance. Whether you are a beginner looking to understand the grid-stride loop or an enthusiast diving into warp-level primitives, there's a kernel here for you.
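The grid-stride loop mentioned above is the canonical starting point. A minimal sketch (an illustration, not a kernel taken from this repo):

```cuda
// Grid-stride elementwise add: each thread strides over the array by the
// total number of launched threads, so any grid size covers any n.
__global__ void elementwise_add(const float* a, const float* b, float* c, int n) {
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += gridDim.x * blockDim.x) {
        c[i] = a[i] + b[i];
    }
}
```

Because the loop decouples grid size from problem size, the same launch configuration works for any input length.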

💻 Let's get started, and happy coding! ⌨️

News

  • [2026.03.10] sgemm_tf32: a TF32 Tensor Core kernel outperforming cuBLAS (cp.async + double smem buffers + swizzle + ldmatrix + mma) 🚀 (stay tuned!)
  • [2026.02.27] sgemm: a SIMT kernel outperforming cuBLAS (smem + swizzle + double buffering + coalesced r/w) 🚀

Contents 📖

Prerequisites 🛠️

  • NVIDIA GPU (Compute Capability 6.0+)
  • CUDA Toolkit 11.0+
  • C++ Compiler (GCC/Clang/MSVC)
  • CMake 3.18+ (optional, but recommended)
  • PyTorch (for the extension examples, Python bindings, and performance comparisons)

For a quick start, I recommend the NVIDIA PyTorch NGC Docker images; see https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch

Kernels (100+)

All kernels were tested on an RTX 5060 GPU (unless otherwise specified) and benchmarked against PyTorch 2.9.

Easy (🌟~🌟🌟)

  • elementwise: elementwise add
    • elementwise_add fp32/fp16 versions
    • elementwise_add_fp16x2 (fp16 vectorized)
    • elementwise_add_fp16x8 (fp16 vectorized)
    • elementwise_add_fp16x8 (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • sigmoid
    • sigmoid fp32/fp16 versions
    • sigmoid_fp16x2 (fp16 vectorized)
    • sigmoid_fp16x8 (fp16 vectorized)
    • sigmoid_fp16x8 (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • swish
    • swish fp32/fp16 versions
    • swish_fp16x2 (fp16 vectorized)
    • swish_fp16x8 (fp16 vectorized)
    • swish_fp16x8 (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • relu
    • relu fp32/fp16 versions
    • relu_fp16x2 (fp16 vectorized)
    • relu_fp16x8 (fp16 vectorized)
    • relu_fp16x8 (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • relu6
    • relu6 fp32/fp16 versions
    • relu6_fp16x2 (fp16 vectorized)
    • relu6_fp16x8 (fp16 vectorized)
    • relu6_fp16x8 (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • elu
    • elu fp32/fp16 versions
    • elu_fp16x2 (fp16 vectorized)
    • elu_fp16x8 (fp16 vectorized)
    • elu_fp16x8 (fp16 vectorized, packed r/w; nearly 2× speedup from half2)
    • pytorch op bindings && diff check
  • gelu
    • gelu fp32/fp16 versions
    • gelu_fp16x2 (fp16 vectorized)
    • gelu_fp16x8 (fp16 vectorized)
    • gelu_fp16x8 (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • hardswish
    • hardswish fp32/fp16 versions
    • hardswish_fp16x2 (fp16 vectorized)
    • hardswish_fp16x8 (fp16 vectorized)
    • hardswish_fp16x8 (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • embedding
    • embedding fp32/fp16 versions
    • embedding_fp32x4 (fp32 vectorized)
    • embedding_fp32x4 (fp32 vectorized, packed r/w)
    • embedding_fp16x2 (fp16 vectorized)
    • embedding_fp16x8 (fp16 vectorized)
    • embedding_fp16x8 (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • rope
    • pytorch naive rope
    • pytorch rope with cos/sin table
    • rope fp32 version (an order of magnitude faster than the naive PyTorch implementation)
    • rope fp32x4 version (fp32 vectorized; tens of times faster at larger sizes)
    • pytorch op bindings && diff check
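The "fp16x8, packed r/w" variants that recur above share one trick: move eight halfs per thread in a single 128-bit transaction and compute with half2 math. A minimal sketch of the pattern, using relu as the example (an illustration, not the repo's exact kernel; `__hmax2` requires compute capability 8.0+, and tail handling for n not divisible by 8 is omitted for brevity):

```cuda
#include <cuda_fp16.h>

// Packed fp16x8 relu: one float4 load/store moves 8 halfs (128 bits),
// and the math runs on 4 half2 pairs.
__global__ void relu_fp16x8_packed(const __half* x, __half* y, int n) {
    int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
    if (idx + 7 < n) {
        float4 pack = *reinterpret_cast<const float4*>(x + idx);  // packed read
        __half2* h2 = reinterpret_cast<__half2*>(&pack);
        const __half2 zero = __float2half2_rn(0.0f);
        #pragma unroll
        for (int i = 0; i < 4; ++i)
            h2[i] = __hmax2(h2[i], zero);                          // half2 relu
        *reinterpret_cast<float4*>(y + idx) = pack;                // packed write
    }
}
```

The packed load is legal because `idx` advances in units of 8 halfs (16 bytes), which keeps the float4 accesses aligned.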

Medium (🌟🌟~🌟🌟🌟)

  • reduce: based on warp shuffle add
    • reduce_sum fp32/fp16 versions
    • reduce_sum_fp16x2 (fp16 vectorized)
    • reduce_sum_fp16x8_packed (fp16 vectorized, packed r/w)
    • reduce_sum int8 version
    • reduce_sum_i8x16_packed (int8 vectorized, packed r/w)
    • reduce_sum_i8x16_packed (int8 vectorized, packed r/w, dp4a; tens of times faster than the naive torch implementation)
    • reduce_sum_i8x64_packed (int8 vectorized, packed r/w, dp4a)
    • pytorch op bindings && diff check
  • dot_product
    • dot_product fp32/fp16 versions
    • dot_product_fp32x4 (fp32 vectorized)
    • dot_product_fp16x2 (fp16 vectorized)
    • dot_product_fp16x8 (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • softmax
    • safe online softmax fp32/fp16 versions
    • safe online softmax fp32x4 version (fp32 vectorized)
    • safe online softmax fp16x8 version (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • rmsnorm
    • naive torch rmsnorm
    • rmsnorm fp32/fp16 versions
    • rmsnorm fp32x4 version (fp32 vectorized)
    • rmsnorm_fp32x4_smem
    • rmsnorm fp16x8 version (fp16 vectorized, packed r/w)
    • rmsnorm_fp16x8_smem version (fp16 vectorized, packed r/w)
    • pytorch op bindings && diff check
  • transpose
    • transpose_coalesced_read (coalesced reads, from the input's perspective)
    • transpose_coalesced_write (coalesced writes, from the output's perspective)
    • transpose_smem (shared-memory tiles, blocked r/w)
    • transpose_smem_bcf (bank-conflict-free shared memory)
    • transpose_smem_packed_bcf (bank-conflict-free shared memory, float4 vectorized r/w)
    • transpose_smem_swizzled_packed (bank-conflict-free shared memory via swizzling, float4 vectorized r/w)
    • pytorch op bindings && diff check
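The reduce, dot_product, softmax, and rmsnorm kernels above all lean on the same warp-shuffle building block: a butterfly reduction that sums 32 lanes without touching shared memory. A minimal sketch of the idea (not the repo's code; `out` is assumed to be zero-initialized before launch):

```cuda
// XOR-butterfly warp reduction: each step halves the distance between
// exchanged lanes, and after 5 steps every lane holds the full warp sum.
__device__ float warp_reduce_sum(float v) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        v += __shfl_xor_sync(0xffffffff, v, offset);
    return v;
}

__global__ void reduce_sum(const float* x, float* out, int n) {
    float sum = 0.0f;
    // Grid-stride accumulation into a per-thread partial sum.
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
         i += gridDim.x * blockDim.x)
        sum += x[i];
    sum = warp_reduce_sum(sum);
    if ((threadIdx.x & 31) == 0)   // one atomic per warp, not per thread
        atomicAdd(out, sum);
}
```

Reducing within registers first means only one atomicAdd per warp reaches global memory, which is what makes these kernels competitive.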

Hard (🌟🌟🌟~🌟🌟🌟🌟)

  • sgemv
    • gemv fp32 version
    • gemv fp32x4 (vectorized loads)
    • pytorch op bindings && diff check
  • sgemm
    • sgemm_cublas fp32 version
    • sgemm_tiling (vectorized r/w + block tiling with shared memory)
    • sgemm_at_tiling (vectorized r/w + A transposed into smem with 4-way write conflicts, inner-loop float4 loads)
    • sgemm_at_bcf_swizzling (vectorized r/w + transposed A + swizzling, bank-conflict free)
    • sgemm_at_bcf_swizzling_rw (vectorized r/w + transposed A + swizzling + coalesced C write-back transactions)
    • sgemm_at_bcf_swizzling_dbf_rw (vectorized r/w + transposed A + swizzling + coalesced C write-back + double-buffered pipeline; beats cuBLAS!)
    • pytorch op bindings && diff check
  • sgemm_tf32
    • sgemm_cublas tf32 version
    • sgemm_tf32_bt (vectorized A/B loads, B transposed into smem, ldmatrix + mma)
    • sgemm_tf32_bt_swizzle (vectorized A/B loads, B transposed into smem, ldmatrix + mma, conflict-free As)
    • sgemm_tf32_bt_swizzle_dbf (vectorized A/B loads, B transposed into smem, ldmatrix + mma, conflict-free As, grid swizzling; 97~102% of cuBLAS performance)
    • sgemm_tf32_swizzle_bcf (cp.async A/B loads, warp-shuffle register transpose of B, conflict-free As/Bs, grid swizzling)
    • sgemm_tf32_swizzle_bcf_dbf (cp.async A/B loads, warp-shuffle register transpose of B, conflict-free As/Bs, grid swizzling, double buffering; beats cuBLAS)
    • pytorch op bindings && diff check
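All of the sgemm variants above grow out of one baseline: block tiling through shared memory. A minimal sketch of that starting point (an illustration only; square row-major matrices with N a multiple of TILE, launched with a TILE×TILE thread block; the repo's kernels layer transposition, swizzling, register tiling, and double buffering on top of this):

```cuda
#define TILE 32

// C = A * B via shared-memory block tiling: each block computes one
// TILE x TILE tile of C, staging TILE-wide slabs of A and B in smem so
// each global element is loaded once per tile instead of once per MAC.
__global__ void sgemm_tiled(const float* A, const float* B, float* C, int N) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];
    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float acc = 0.0f;
    for (int k0 = 0; k0 < N; k0 += TILE) {
        As[threadIdx.y][threadIdx.x] = A[row * N + k0 + threadIdx.x];
        Bs[threadIdx.y][threadIdx.x] = B[(k0 + threadIdx.y) * N + col];
        __syncthreads();                          // tiles fully staged
        for (int k = 0; k < TILE; ++k)
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();                          // done reading this tile
    }
    C[row * N + col] = acc;
}
```

Each refinement in the list attacks a bottleneck this baseline still has: smem bank conflicts (swizzling), load latency (double buffering, cp.async), and write-back efficiency (coalesced C transactions).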

Samples

Reference
