229 changes: 229 additions & 0 deletions gpu/bitnet_kernels/bitgemm.cu
@@ -0,0 +1,229 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <stdint.h>
#include <iostream>
#include <stdlib.h>
#include <string.h>
#include <time.h>

template <typename T1, typename T2>
__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16) {
// convert 16 int2b_t (one packed 32-bit word) to 16 int8b_t -> 4 int32
uint *i8s = reinterpret_cast<uint *>(_i8s);

// _i2s packs 16 2-bit elements into one 32-bit word, interleaved so that
// loop iteration i below gathers elements {i, i+4, i+8, i+12}
uint const i2s = *_i2s;

// Extract each 2-bit element into its own byte with lop3, then map the
// unsigned range {0..3} down to the signed range {-2..1}.
static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // lop3 LUT for (a & b) | c
static constexpr uint BOTTOM_MASK = 0x03030303; // low 2 bits of every byte
static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000; // OR'd in by lop3; zero here

#pragma unroll
for (int i = 0; i < (N / 4); i++) {
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(i8s[i])
: "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK),
"n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
i8s[i] = __vsubss4(i8s[i], 0x02020202);
}
}


template <int N, int K>
__global__ void int8_int2_gemm_tensor_core(
const int8_t *__restrict__ A, // M x K matrix, row-major
const int32_t *__restrict__ B_compressed, // Compressed int2 data for N x K matrix, column-major
int32_t *__restrict__ C, // M x N output matrix, row-major
int M)
{
// Define WMMA dimensions - all constant
constexpr int WMMA_M = 16;
constexpr int WMMA_N = 16;
constexpr int WMMA_K = 16;

// Define block tile dimensions - all constant
constexpr int BLOCK_SIZE_M = 64; // Multiple of WMMA_M
constexpr int BLOCK_SIZE_N = 64; // Multiple of WMMA_N
constexpr int BLOCK_SIZE_K = 32; // K dimension as requested

// Calculate thread block position
const int blockM = blockIdx.y * BLOCK_SIZE_M;
const int blockN = blockIdx.x * BLOCK_SIZE_N;

// Calculate thread ID and warp IDs
const int warpM = threadIdx.y; // 0-1 (2 warps in M dimension)
const int warpN = threadIdx.z; // 0-1 (2 warps in N dimension)
const int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;

// Add padding to shared memory to avoid bank conflicts
constexpr int PAD_A = 16; // Padding for A matrix
constexpr int PAD_B = 16; // Padding for B matrix

// Allocate shared memory for A and B matrices with padding
__shared__ int8_t shared_A[BLOCK_SIZE_M][BLOCK_SIZE_K + PAD_A];
__shared__ int8_t shared_B[BLOCK_SIZE_N][BLOCK_SIZE_K + PAD_B];

// Define fragments for all tiles this warp will handle - static allocation
nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, int32_t> c_frags[2][2];
nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, int8_t, nvcuda::wmma::row_major> a_frag;
nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, int8_t, nvcuda::wmma::col_major> b_frag;

// Initialize all accumulator fragments to zero (unrolled)
#pragma unroll
for (int m_iter = 0; m_iter < 2; m_iter++) {
#pragma unroll
for (int n_iter = 0; n_iter < 2; n_iter++) {
nvcuda::wmma::fill_fragment(c_frags[m_iter][n_iter], 0);
}
}

// Only check M bounds at the beginning
const bool m_valid = blockM < M;

// Loop over K dimension in chunks of BLOCK_SIZE_K
#pragma unroll 4 // Partial unroll of K-dimension loop
for (int k_block = 0; k_block < K; k_block += BLOCK_SIZE_K) {
// Wait for the previous iteration to finish using shared memory before overwriting it
__syncthreads();

// Load A matrix tiles into shared memory using vectorized loads
// Each thread handles multiple elements based on its ID
for (int load_idx = tid; load_idx < (BLOCK_SIZE_M * BLOCK_SIZE_K / 16); load_idx += blockDim.x * blockDim.y * blockDim.z) {
int local_m = (load_idx * 16) / BLOCK_SIZE_K;
int local_k = (load_idx * 16) % BLOCK_SIZE_K;

int global_m = blockM + local_m;
int global_k = k_block + local_k;

// Use vector loads for A - 16 bytes at a time (int4 = 4 integers = 16 bytes)
if (m_valid && global_m < M) {
// Vector load from A to shared memory
*((int4*)&shared_A[local_m][local_k]) = *((int4*)&A[global_m * K + global_k]);
} else {
// Zero out if M is out of bounds
*((int4*)&shared_A[local_m][local_k]) = {0};
}
}

// Load B matrix tiles into shared memory (always in bounds for N and K)
// Calculate which 16-element chunk this thread is responsible for
int chunk_n = (tid * 16 / BLOCK_SIZE_K);
int chunk_k = (tid * 16) % BLOCK_SIZE_K;

if (chunk_n < BLOCK_SIZE_N) {
int global_n = blockN + chunk_n;
int global_k = k_block + chunk_k;

// Calculate which compressed block this belongs to
int n_block = global_n / 16;
int k_block_32 = global_k / 32;
int k_offset_in_block = chunk_k % 32;

// Get the specific compressed tile within the 16x32 block
int in_block_n = chunk_n % 16;
int compressed_block_idx = n_block * (K / 32) + k_block_32;

// Calculate which tile within the compressed block
int tile_idx;
tile_idx = in_block_n / 8 * 16 + in_block_n % 8 + (k_offset_in_block / 16) * 8;

// Extract and decompress the int2 values
int32_t compressed = B_compressed[compressed_block_idx * 32 + tile_idx];
int8_t decompressed[16];
decode_i2s_to_i8s(&compressed, decompressed);
Collaborator:
Many threads will dequantize the same int2 weights. Could we add a pre-processing step that caches the dequantized result in shared memory?

Author:
Yes, blocks with the same blockN but different blockM dequantize the same weights. However, shared memory is only accessible to threads within the same thread block. If we want to cache the dequantized result, I think we either dequantize all weights into global memory up front, or loop over M inside each block to reuse the weights, which may then require split-K to keep enough parallelism. How do you think we should implement this optimization?
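
A minimal sketch of the first option mentioned in the reply (pre-dequantizing all weights into global memory), for illustration only; the kernel name, launch shape, and flat int8 output layout are assumptions rather than part of this PR:

// Hypothetical one-off kernel: expand the packed int2 weights into an int8
// buffer in global memory, so GEMM blocks can load decoded weights directly
// instead of each block re-running decode_i2s_to_i8s.
__global__ void predequant_int2_to_int8(const int32_t *__restrict__ B_compressed,
                                        int8_t *__restrict__ B_int8,
                                        int num_words) {
    // One thread per compressed 32-bit word (each word packs 16 int2 values).
    int word = blockIdx.x * blockDim.x + threadIdx.x;
    if (word >= num_words) return;

    int32_t compressed = B_compressed[word];
    int8_t decoded[16];
    decode_i2s_to_i8s(&compressed, decoded);

    // Byte-wise store keeps the sketch alignment-agnostic; a consumer kernel
    // would index this buffer with the same tile mapping used in the GEMM above.
#pragma unroll
    for (int i = 0; i < 16; ++i) {
        B_int8[word * 16 + i] = decoded[i];
    }
}

// Possible launch for an N x K weight matrix (N * K / 16 packed words):
//   int num_words = N * K / 16;
//   predequant_int2_to_int8<<<(num_words + 255) / 256, 256, 0, stream>>>(
//       B_compressed, B_int8, num_words);

This trades extra global-memory traffic and a 4x larger weight buffer for removing the redundant per-block decode; the loop-over-M / split-K alternative mentioned in the reply would instead keep the weights packed and reuse the decoded tiles from shared memory within a block.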


// Vector store to shared memory
*((int4*)&shared_B[chunk_n][chunk_k]) = *((int4*)decompressed);
}

// Make sure all threads have finished loading into shared memory
__syncthreads();

// Process the 2x2 WMMA tiles for this K block
#pragma unroll
for (int m_iter = 0; m_iter < 2; m_iter++) {
#pragma unroll
for (int n_iter = 0; n_iter < 2; n_iter++) {
// Calculate the starting positions for this WMMA tile
#pragma unroll
for (int wmma_k = 0; wmma_k < BLOCK_SIZE_K; wmma_k += WMMA_K) {
// Fully unroll the m and n iterations
const int tile_m = (warpM * 2 + m_iter) * WMMA_M;
const int tile_n = (warpN * 2 + n_iter) * WMMA_N;

// Load matrix A fragment from shared memory with padding
nvcuda::wmma::load_matrix_sync(
a_frag, &shared_A[tile_m][wmma_k], BLOCK_SIZE_K + PAD_A);

// Load matrix B fragment from shared memory with padding
nvcuda::wmma::load_matrix_sync(
b_frag, &shared_B[tile_n][wmma_k], BLOCK_SIZE_K + PAD_B);

// Perform matrix multiplication
nvcuda::wmma::mma_sync(c_frags[m_iter][n_iter], a_frag, b_frag, c_frags[m_iter][n_iter]);
}
}
}
}

// Store results back to global memory - only check M bounds
#pragma unroll
for (int m_iter = 0; m_iter < 2; m_iter++) {
const int tile_m = (warpM * 2 + m_iter) * WMMA_M;
const int global_tile_m = blockM + tile_m;

if (m_valid && global_tile_m < M) {
#pragma unroll
for (int n_iter = 0; n_iter < 2; n_iter++) {
const int tile_n = (warpN * 2 + n_iter) * WMMA_N;
const int global_tile_n = blockN + tile_n;

// No need to check N bounds as it's always aligned
nvcuda::wmma::store_matrix_sync(
&C[global_tile_m * N + global_tile_n],
c_frags[m_iter][n_iter], N, nvcuda::wmma::mem_row_major);
}
}
}
}

extern "C" void bitlinear_int8xint2(int8_t *input0, int8_t *input1,
int32_t *output0, int M, int N, int K,
cudaStream_t stream = 0) {
if (N == 3840 && K == 2560) {
int8_int2_gemm_tensor_core<3840, 2560>
<<<dim3(60, (M + 63) / 64, 1), dim3(32, 2, 2), 0, stream>>>(
input0, (int32_t *)input1, (int32_t *)output0, M);
} else if (N == 2560 && K == 2560) {
int8_int2_gemm_tensor_core<2560, 2560>
<<<dim3(40, (M + 63) / 64, 1), dim3(32, 2, 2), 0, stream>>>(
input0, (int32_t *)input1, (int32_t *)output0, M);
} else if (N == 13824 && K == 2560) {
int8_int2_gemm_tensor_core<13824, 2560>
<<<dim3(216, (M + 63) / 64, 1), dim3(32, 2, 2), 0, stream>>>(
input0, (int32_t *)input1, (int32_t *)output0, M);
} else if (N == 2560 && K == 6912) {
int8_int2_gemm_tensor_core<2560, 6912>
<<<dim3(40, (M + 63) / 64, 1), dim3(32, 2, 2), 0, stream>>>(
input0, (int32_t *)input1, (int32_t *)output0, M);
} else {
std::cerr << "Error: Unsupported matrix dimensions for bitlinear_int8xint2. "
<< "Required kernel: M=" << M << ", N=" << N << ", K=" << K << std::endl;
std::cerr << "Supported configurations:" << std::endl;
std::cerr << " - N=3840, K=2560" << std::endl;
std::cerr << " - N=2560, K=2560" << std::endl;
std::cerr << " - N=13824, K=2560" << std::endl;
std::cerr << " - N=2560, K=6912" << std::endl;
throw std::runtime_error("Unsupported matrix dimensions for bitlinear_int8xint2");
}

// Check for CUDA launch errors
cudaError_t launch_error = cudaGetLastError();
if (launch_error != cudaSuccess) {
std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(launch_error) << std::endl;
throw std::runtime_error("CUDA kernel launch failed");
}
}
2 changes: 1 addition & 1 deletion gpu/bitnet_kernels/bitnet_kernels.cu
@@ -2,7 +2,7 @@

extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
if (M == 1 && N == 3840 && K == 2560){
ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<<dim3(240, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
ladder_int8xint2_kernel<1, 3840, 2560, 6, 8, 16><<<dim3(240, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
}
else if (M == 1 && N == 2560 && K == 2560){
ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
1 change: 1 addition & 0 deletions gpu/bitnet_kernels/compile.sh
@@ -1,3 +1,4 @@
nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitnet_kernels.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libbitnet.so
nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitgemm.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libgemm.so


6 changes: 3 additions & 3 deletions gpu/convert_checkpoint.py
@@ -47,7 +47,7 @@ def convert_int8_to_int2(weight):
wk_weight, wb_scale = quant_weight_int8(wk)
wv_weight, wc_scale = quant_weight_int8(wv)
wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0)
wqkv_scale = torch.cat([wa_scale, wa_scale, wa_scale, wa_scale, wb_scale, wc_scale], dim=0)
int2_result[key] = convert_int8_to_int2(wqkv_weight)
int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale

@@ -62,7 +62,7 @@ def convert_int8_to_int2(weight):
w1_weight, w1_scale = quant_weight_int8(w1)
w3_weight, w3_scale = quant_weight_int8(w3)
w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0)
w13_scale = torch.cat([w1_scale, w3_scale, zero, zero, zero, zero], dim=0)
int2_result[key] = convert_int8_to_int2(w13_weight)
int2_result[key.replace('weight', 'weight_scale')] = w13_scale

@@ -72,7 +72,7 @@ def convert_int8_to_int2(weight):
fp16_result[key] = w13_weight
elif 'w2' in key or 'wo' in key:
weight, scale = quant_weight_int8(value)
scale = torch.cat([scale, zero, zero, zero], dim=0)
scale = torch.cat([scale, zero, zero, zero, zero, zero], dim=0)
int2_result[key] = convert_int8_to_int2(weight)
int2_result[key.replace('weight', 'weight_scale')] = scale

6 changes: 2 additions & 4 deletions gpu/generate.py
@@ -53,7 +53,7 @@ def build(
"""
start_time = time.time()

model_args_prefill = fast.ModelArgs(use_kernel=False)
model_args_prefill = fast.ModelArgs(use_kernel=True)
model_args_decode = fast.ModelArgs(use_kernel=True)
tokenizer = Tokenizer("./tokenizer.model")

@@ -63,11 +63,9 @@
prefill_model = fast.Transformer(model_args_prefill)
decode_model = fast.Transformer(model_args_decode)

fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
prefill_model.load_state_dict(fp16_checkpoint, strict=True)
prefill_model.load_state_dict(int2_checkpoint, strict=True)
decode_model.load_state_dict(int2_checkpoint, strict=True)

torch.cuda.synchronize()
48 changes: 45 additions & 3 deletions gpu/model.py
@@ -17,8 +17,11 @@

import ctypes
bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
gemm_lib = ctypes.CDLL('bitnet_kernels/libgemm.so')

def bitnet_int8xint2_linear(input0, input1, s, ws):
import numpy as np

def bitnet_int8xint2_linear_gemv(input0, input1, s, ws):
out_shape = list(input0.shape)
out_shape[-1] = input1.shape[0]

@@ -36,6 +39,42 @@ def bitnet_int8xint2_linear(input0, input1, s, ws):

return ret

def bitnet_int8xint2_linear_gemm(input0, input1, s, ws):
out_shape = list(input0.shape)
out_shape[-1] = input1.shape[0]

stream = torch.cuda.current_stream()

M = input0.shape[0]
if len(out_shape) == 3:
M *= input0.shape[1]
N = input1.shape[0]
K = input1.shape[1] * 4

ret = torch.zeros(*out_shape, dtype=torch.int32, device=input0.device)

gemm_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
ret = ret.to(torch.bfloat16)
ret = ret / s
if N == 3840 and K == 2560:
# split the last dim into 6 equal parts
ret = ret.reshape(*ret.shape[:-1], 6, -1)
# scale each part by the corresponding entry of the first 6 weight scales
ret = ret * ws[:6].reshape(1, 6, 1)
elif (N == 2560 and K == 2560):
# 1 part
ret = ret * ws[:1].reshape(1, 1, 1, 1)
elif (N == 13824 and K == 2560):
# 2 parts
ret = ret.reshape(*ret.shape[:-1], 2, -1)
# scale each part by the corresponding entry of the first 2 weight scales
ret = ret * ws[:2].reshape(1, 1, 2, 1)
elif (N == 2560 and K == 6912):
# 1 part
ret = ret * ws[:1].reshape(1, 1, 1, 1)

return ret.reshape(*out_shape)

@dataclass
class ModelArgs:
dim: int = 2560
@@ -63,7 +102,7 @@ def __init__(self, in_features: int, out_features: int, bias: bool = False):
self.out_features = out_features

self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False)
self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False)
self.weight_scale = torch.nn.Parameter(torch.zeros(6, dtype=torch.bfloat16), requires_grad=False)

@torch.compile
def quant_input(self, input):
@@ -72,7 +111,10 @@ def quant_input(self, input):

def forward(self, input):
input, s = self.quant_input(input)
return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale)
if input.shape[0] == 1:
return bitnet_int8xint2_linear_gemv(input, self.weight, s, self.weight_scale)
else:
return bitnet_int8xint2_linear_gemm(input, self.weight, s, self.weight_scale)

class BitLinear(nn.Linear):
@torch.compile