Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions bitsandbytes/backends/triton/kernels_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def quantize_fp4_blockwise_kernel(

packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
out_mask = out_offsets < n_elements // 2
# Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
out_mask = out_offsets < (n_elements - n_elements // 2)
tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)


Expand Down Expand Up @@ -148,7 +149,8 @@ def quantize_nf4_blockwise_kernel(

packed_flat = tl.reshape(packed, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
out_offsets = block_start_idx * BLOCK_SIZE // 2 + tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
out_mask = out_offsets < n_elements // 2
# Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
out_mask = out_offsets < (n_elements - n_elements // 2)
tl.store(out_ptr + out_offsets, packed_flat, mask=out_mask)


Expand Down Expand Up @@ -330,7 +332,14 @@ def dequant_nf4_body_util(a, offsets, absmax_ptr, n_elems, QUANT_BLOCK: tl.const
# )
@triton.jit
def dequant_4bit_kernel(
a_ptr, c_ptr, quant_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
a_ptr,
c_ptr,
quant_ptr,
absmax_ptr,
num_paired_elements,
num_output_elements,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
block_start = pid * SPLIT_SIZE
Expand All @@ -350,7 +359,7 @@ def dequant_4bit_kernel(

out_block_start = pid * SPLIT_SIZE * 2
offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
mask = offs < num_paired_elements * 2
mask = offs < num_output_elements
tl.store(c_ptr + offs, out_dq, mask)


Expand All @@ -367,7 +376,13 @@ def dequant_4bit_kernel(
# )
@triton.jit
def dequant_fp4_kernel(
a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
a_ptr,
c_ptr,
absmax_ptr,
num_paired_elements,
num_output_elements,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
block_start = pid * SPLIT_SIZE
Expand All @@ -386,7 +401,7 @@ def dequant_fp4_kernel(

out_block_start = pid * SPLIT_SIZE * 2
offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
mask = offs < num_paired_elements * 2
mask = offs < num_output_elements
tl.store(c_ptr + offs, out_dq, mask)


Expand All @@ -403,7 +418,13 @@ def dequant_fp4_kernel(
# )
@triton.jit
def dequant_nf4_kernel(
a_ptr, c_ptr, absmax_ptr, num_paired_elements, QUANT_BLOCK: tl.constexpr, SPLIT_SIZE: tl.constexpr
a_ptr,
c_ptr,
absmax_ptr,
num_paired_elements,
num_output_elements,
QUANT_BLOCK: tl.constexpr,
SPLIT_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.
block_start = pid * SPLIT_SIZE
Expand All @@ -422,7 +443,7 @@ def dequant_nf4_kernel(

out_block_start = pid * SPLIT_SIZE * 2
offs = out_block_start + tl.arange(0, SPLIT_SIZE * 2)
mask = offs < num_paired_elements * 2
mask = offs < num_output_elements
tl.store(c_ptr + offs, out_dq, mask)


Expand All @@ -439,15 +460,16 @@ def dequantize_4bit_impl(
# A is packed uint8 data: each byte carries two 4-bit values,
# so the dequantized output holds up to 2 * A.numel() elements
number_of_paired_elements = A.numel()
num_output_elements = out.numel()
# we assume that split_size > quant_blocksize

SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
if quant_type == "fp4":
dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
dequant_fp4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE)
else:
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE)


def dequantize_4bit_impl_passing_code(
Expand All @@ -459,12 +481,15 @@ def dequantize_4bit_impl_passing_code(
out: torch.Tensor,
) -> None:
number_of_paired_elements = A.numel()
num_output_elements = out.numel()
# we assume that split_size > quant_blocksize

SPLIT_SIZE = 256
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META['SPLIT_SIZE']), )
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
dequant_4bit_kernel[grid](A, out, code, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
dequant_4bit_kernel[grid](
A, out, code, absmax, number_of_paired_elements, num_output_elements, blocksize, SPLIT_SIZE
)


######################### Fallback dequantization functions #########################
Expand Down
3 changes: 2 additions & 1 deletion bitsandbytes/backends/triton/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def quantize_4bit(
blocks = -(n // -(blocksize * 2))

absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
# Use n - n//2 instead of (n+1)//2 to avoid integer overflow for large n
out = torch.empty((n - n // 2, 1), device=A.device, dtype=torch.uint8)

with torch_accelerator_module.device(A.device):
kernels_4bit.quantize_4bit_blockwise_triton(
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/backends/xpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _dequantize_4bit_impl(
get_ptr(absmax),
get_ptr(out),
ct.c_int(blocksize),
ct.c_int(out.numel()),
ct.c_int64(out.numel()),
_get_tensor_stream(A),
)
if dtype == torch.bfloat16:
Expand Down
36 changes: 18 additions & 18 deletions csrc/pythonInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,57 +303,57 @@ void spmm_coo_very_sparse_naive_int8(
#if BUILD_XPU

void dequantizeBlockwise_fp16(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise<sycl::half, General8bit>(code, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp16_fp4(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise<sycl::half, FP4>(nullptr, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp16_nf4(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise<sycl::half, NF4>(nullptr, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp32(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp32_fp4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise<float, FP4>(nullptr, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_fp32_nf4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise<float, NF4>(nullptr, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_bf16(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n,
sycl::queue* stream
) {
dequantizeBlockwise<sycl::ext::oneapi::bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_bf16_fp4(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n,
sycl::queue* stream
) {
dequantizeBlockwise<sycl::ext::oneapi::bfloat16, FP4>(nullptr, A, absmax, out, blocksize, n, stream);
}

void dequantizeBlockwise_bf16_nf4(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n,
sycl::queue* stream
) {
dequantizeBlockwise<sycl::ext::oneapi::bfloat16, NF4>(nullptr, A, absmax, out, blocksize, n, stream);
Expand Down Expand Up @@ -728,57 +728,57 @@ void cgemm_4bit_inference_naive_fp32(
#if BUILD_XPU

void cdequantize_blockwise_fp16_fp4(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream);
}

void cdequantize_blockwise_fp16(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream);
}

void cdequantize_blockwise_fp16_nf4(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream);
}

void cdequantize_blockwise_fp32(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream);
}

void cdequantize_blockwise_fp32_fp4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream);
}

void cdequantize_blockwise_fp32_nf4(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream
) {
dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream);
}

void cdequantize_blockwise_bf16(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n,
sycl::queue* stream
) {
dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream);
}

void cdequantize_blockwise_bf16_fp4(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n,
sycl::queue* stream
) {
dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream);
}

void cdequantize_blockwise_bf16_nf4(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n,
sycl::queue* stream
) {
dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream);
Expand Down
14 changes: 7 additions & 7 deletions csrc/xpu_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,20 +95,20 @@ inline float dDequantizeNF4(unsigned char val) {

template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE>
SYCL_EXTERNAL void kDequantizeBlockwise<T, TILE_SIZE, NUM_PER_TH, DATA_TYPE>::operator()(sycl::nd_item<1> item) const {
const int base_idx = item.get_group(0) * TILE_SIZE;
size_t local_idx = item.get_local_id(0) * NUM_PER_TH;
const int64_t base_idx = static_cast<int64_t>(item.get_group(0)) * TILE_SIZE;
int64_t local_idx = static_cast<int64_t>(item.get_local_id(0)) * NUM_PER_TH;
float local_abs_max = -FLT_MAX;
int local_load_idx = 0;
int local_store_idx = 0;
int64_t local_load_idx = 0;
int64_t local_store_idx = 0;

uint8_t qvals[NUM_PER_TH];
T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)];

if (DATA_TYPE > 0) {
local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx);
local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2);
local_load_idx = sycl::min(static_cast<int64_t>(TILE_SIZE), (n + 1) / 2 - base_idx);
local_store_idx = sycl::min(static_cast<int64_t>(TILE_SIZE * 2), n - base_idx * 2);
} else {
local_load_idx = sycl::min(TILE_SIZE, n - base_idx);
local_load_idx = sycl::min(static_cast<int64_t>(TILE_SIZE), n - base_idx);
local_store_idx = local_load_idx;
}

Expand Down
4 changes: 2 additions & 2 deletions csrc/xpu_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE> class kDequa
public:
SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const;

kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int n_)
kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int64_t n_)
: code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), n(n_) {}

private:
Expand All @@ -17,7 +17,7 @@ template <typename T, int TILE_SIZE, int NUM_PER_TH, int DATA_TYPE> class kDequa
float* absmax;
T* out;
const int blocksize;
const int n;
const int64_t n;
};

template <typename T, size_t GROUP_SIZE, size_t NUM_PER_THREAD, size_t SUBG_SIZE, int BITS> class kgemv_4bit_inference {
Expand Down
20 changes: 10 additions & 10 deletions csrc/xpu_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int64_t n, sycl::queue* stream
) {
auto& queue = *stream;
const int workgroup_size = 128;
Expand Down Expand Up @@ -55,35 +55,35 @@ void gemv_4bit_inference(
//==============================================================

template void dequantizeBlockwise<float, General8bit>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream
);
template void dequantizeBlockwise<float, FP4>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream
);
template void dequantizeBlockwise<float, NF4>(
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int64_t n, sycl::queue* stream
);

template void dequantizeBlockwise<sycl::half, General8bit>(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream
);
template void dequantizeBlockwise<sycl::half, FP4>(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream
);
template void dequantizeBlockwise<sycl::half, NF4>(
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int64_t n, sycl::queue* stream
);

template void dequantizeBlockwise<sycl::ext::oneapi::bfloat16, General8bit>(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n,
sycl::queue* stream
);
template void dequantizeBlockwise<sycl::ext::oneapi::bfloat16, FP4>(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n,
sycl::queue* stream
);
template void dequantizeBlockwise<sycl::ext::oneapi::bfloat16, NF4>(
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int64_t n,
sycl::queue* stream
);

Expand Down
2 changes: 1 addition & 1 deletion csrc/xpu_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ static inline void sycl_comp_kernel_submit(sycl::nd_range<dim> range, sycl::queu

template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int n, sycl::queue* stream
float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int64_t n, sycl::queue* stream
);
template <typename T, int BITS>
void gemv_4bit_inference(
Expand Down