
Commit a13736a

Reused code for quant/dequant
1 parent c9e1908 commit a13736a

File tree: 2 files changed, +46 -89 lines

bitsandbytes/backends/triton/kernels_8bit_quant.py

Lines changed: 30 additions & 73 deletions
```diff
@@ -27,35 +27,19 @@
 @triton.jit
 def dequant_8bit_kernel(
     a_ptr,
-    c_ptr,
-    quant_ptr,
+    out_ptr,
+    code_ptr,
     absmax_ptr,
-    num_paired_elements,
+    n,
     QUANT_BLOCK: tl.constexpr,
     SPLIT_SIZE: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
     block_start = pid * SPLIT_SIZE
     offsets = block_start + tl.arange(0, SPLIT_SIZE)
-    mask = offsets < num_paired_elements
-
-    a = tl.load(a_ptr + offsets, mask)
-    a = a.to(tl.uint8)
-
-    # apply conversion
-    scaled_int8 = tl.load(quant_ptr + a, mask)
-
-    abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK
-    abs_offsets = offsets // QUANT_BLOCK
-    mask_blocked = offsets < abs_blocks_lim
-
-    absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked)
-    # apply scales
-    out_dq = scaled_int8 * absmax
-
-    offs = block_start + tl.arange(0, SPLIT_SIZE)
-    mask = offs < num_paired_elements
-    tl.store(c_ptr + offs, out_dq, mask)
+    mask = offsets < n
+    out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK)
+    tl.store(out_ptr + offsets, out_dq, mask)
 
 
 def dequant_8bit_blockwise(
@@ -66,21 +50,21 @@ def dequant_8bit_blockwise(
     dtype: torch.dtype = None,
     out: torch.Tensor = None,
 ):
-    number_of_paired_elements = a.numel()
+    n = a.numel()
     if out is None:
         if dtype is None:
             raise ValueError("If out is None, dtype must be specified")
         out = torch.empty_like(a, dtype=dtype, device=a.device)
 
     SPLIT_SIZE = 256
     # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
-    grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
+    grid = (triton.cdiv(n, SPLIT_SIZE),)
     dequant_8bit_kernel[grid](
         a,
         out,
         quant_state_code,
         absmax,
-        number_of_paired_elements,
+        n,
         quant_blocksize,
         SPLIT_SIZE,
     )
```
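For orientation, a call to the refactored wrapper might look like the sketch below. The hunk above elides the leading parameters of `dequant_8bit_blockwise`, so the keyword names (`a`, `quant_state_code`, `absmax`, `quant_blocksize`) are inferred from the wrapper body and should be read as assumptions, not the confirmed signature.

```python
import torch

# Hypothetical usage sketch -- keyword names are inferred from the wrapper
# body above, since the hunk elides the full parameter list.
codes = torch.randint(0, 256, (4096,), dtype=torch.uint8, device="cuda")  # 8-bit codes
code = torch.linspace(-1.0, 1.0, 256, device="cuda")                      # 256-entry codebook
absmax = torch.rand(4096 // 256, device="cuda")                           # one scale per block

out = dequant_8bit_blockwise(
    a=codes,
    quant_state_code=code,
    absmax=absmax,
    quant_blocksize=256,
    dtype=torch.float32,  # per the wrapper, dtype is required when out is None
)
```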
```diff
@@ -115,39 +99,9 @@ def quantize_8bit_blockwise_kernel(
 
     A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
 
-    # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
-    A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE))
-
-    # Calculating absamax for each block
-    absmax = tl.max(tl.abs(A_reshaped), axis=1)
+    quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS)
     tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
-
-    A_normalized = A_reshaped / absmax[:, None]
-    A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
-
-    lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
-    upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
-
-    for _ in range(8):  # ceil(log2(code_size)) = 8, actually, in general case should be input parameter
-        pivot = (lower_pivot + upper_pivot) // 2
-        val = tl.load(code_ptr + pivot)
-        is_higher = A_normalized > val  # code[pivot]
-        lower_pivot = tl.where(is_higher, pivot, lower_pivot)
-        upper_pivot = tl.where(is_higher, upper_pivot, pivot)
-
-    # Choose closest level
-    lower_val = tl.load(code_ptr + lower_pivot)
-    upper_val = tl.load(code_ptr + upper_pivot)
-    lower_dist = tl.abs(A_normalized - lower_val)
-    upper_dist = tl.abs(A_normalized - upper_val)
-    quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
-
-    # too slow approach
-    # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
-    # quantized = tl.argmin(diff, axis=2).to(tl.uint8)
-
-    quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
-    tl.store(out_ptr + offsets, quantized_flat, mask=mask)
+    tl.store(out_ptr + offsets, quantized, mask=mask)
 
 
 def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
@@ -180,17 +134,17 @@ def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
 
 
 @triton.jit
-def quantize_8bit_blockwise_core(
+def quantize_8bit_blockwise_kernel_util(
     a,
-    qmap_ptr,
+    code_ptr,
     CODE_SIZE: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     N_PER_TH: tl.constexpr,
 ):
     # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
     a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE))
 
-    # Calculating absamax for each block
+    # Calculating absmax for each block
     absmax = tl.max(tl.abs(a_reshaped), axis=1)
 
     a_normalized = a_reshaped / absmax[:, None]
@@ -202,37 +156,40 @@ def quantize_8bit_blockwise_core(
     # ceil(log2(code_size)) = 8, actually, in general case should be input parameter
     for _ in range(8):
         pivot = (lower_pivot + upper_pivot) // 2
-        val = tl.load(qmap_ptr + pivot)
+        val = tl.load(code_ptr + pivot)
         is_higher = a_normalized > val  # code[pivot]
         lower_pivot = tl.where(is_higher, pivot, lower_pivot)
         upper_pivot = tl.where(is_higher, upper_pivot, pivot)
 
     # Choose closest level
-    lower_val = tl.load(qmap_ptr + lower_pivot)
-    upper_val = tl.load(qmap_ptr + upper_pivot)
+    lower_val = tl.load(code_ptr + lower_pivot)
+    upper_val = tl.load(code_ptr + upper_pivot)
     lower_dist = tl.abs(a_normalized - lower_val)
     upper_dist = tl.abs(a_normalized - upper_val)
     quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
 
+    # too slow approach
+    # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
+    # quantized = tl.argmin(diff, axis=2).to(tl.uint8)
+
     quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,))
     return quantized_flat, absmax
 
 
 @triton.jit
-def dequant_8bit_kernel_util(
-    codes_ptr,
+def dequant_8bit_blockwise_kernel_util(
+    a_ptr,
     offsets,
-    qmap_ptr,
+    code_ptr,
     absmax_ptr,
     mask,
     BLOCK_SIZE: tl.constexpr,
 ):
-    codes = tl.load(codes_ptr + offsets, mask, other=0).to(tl.uint8)
-    abs_offsets = offsets // BLOCK_SIZE
-    absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=0.0, eviction_policy="evict_last")
-
-    # apply conversion
-    scaled_int8 = tl.load(qmap_ptr + codes, mask)
-    # apply scales
+    a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8)
+    scaled_int8 = tl.load(code_ptr + a, mask)
+    # Load scales
+    absmax_offsets = offsets // BLOCK_SIZE
+    absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last")
+    # Apply scales
     out_dq = scaled_int8 * absmax
     return out_dq
```
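Since these two shared utilities now carry the whole quant/dequant logic, a plain-PyTorch mirror is a handy cross-check. The sketch below is an illustration, not part of the commit; it assumes a sorted 256-entry codebook and an input length that is a multiple of the block size.

```python
import torch

def quantize_8bit_blockwise_ref(a: torch.Tensor, code: torch.Tensor, block_size: int):
    """Pure-PyTorch mirror of quantize_8bit_blockwise_kernel_util (illustration only)."""
    blocks = a.reshape(-1, block_size)
    absmax = blocks.abs().amax(dim=1)
    normalized = (blocks / absmax[:, None]).clamp(-1.0, 1.0)

    # Same 8-step binary search over the sorted codebook as the Triton util
    lower = torch.zeros_like(blocks, dtype=torch.long)
    upper = torch.full_like(lower, code.numel() - 1)
    for _ in range(8):  # ceil(log2(256)) = 8
        pivot = (lower + upper) // 2
        is_higher = normalized > code[pivot]
        lower = torch.where(is_higher, pivot, lower)
        upper = torch.where(is_higher, upper, pivot)

    # Choose the closer of the two bracketing levels
    use_lower = (normalized - code[lower]).abs() <= (normalized - code[upper]).abs()
    quantized = torch.where(use_lower, lower, upper).to(torch.uint8)
    return quantized.reshape(-1), absmax

def dequant_8bit_blockwise_ref(codes, code, absmax, block_size: int):
    """Pure-PyTorch mirror of dequant_8bit_blockwise_kernel_util."""
    scaled = code[codes.long()]  # codebook lookup, values in [-1, 1]
    block_ids = torch.arange(codes.numel(), device=codes.device) // block_size
    return scaled * absmax[block_ids]

# Round trip: dequantized values should sit within one codebook step of the input
x = torch.randn(1024)
code = torch.linspace(-1.0, 1.0, 256)
q, absmax = quantize_8bit_blockwise_ref(x, code, block_size=256)
x_dq = dequant_8bit_blockwise_ref(q, code, absmax, block_size=256)
assert (x - x_dq).abs().max() <= absmax.max() * (code[1] - code[0])
```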

bitsandbytes/backends/triton/kernels_optim.py

Lines changed: 16 additions & 16 deletions
```diff
@@ -9,8 +9,8 @@
 # from triton.language.extra import libdevice
 from .kernels_8bit_quant import (
     dequant_8bit_blockwise,
-    dequant_8bit_kernel_util,
-    quantize_8bit_blockwise_core,
+    dequant_8bit_blockwise_kernel_util,
+    quantize_8bit_blockwise_kernel_util,
     quantize_blockwise_triton,
 )
 
@@ -445,7 +445,7 @@ def _optimizer_update_1state_8bit_blockwise_triton_kernel(
     # 2. Load and dequantize tensors
     g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale
     p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
-    s1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
+    s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
 
     # 3. Optimizer-specific updates
     # LION
@@ -482,7 +482,7 @@ def _optimizer_update_1state_8bit_blockwise_triton_kernel(
 
     # 4. Store updated parameter and requantized state
     tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
-    s1_codes, new_absmax1 = quantize_8bit_blockwise_core(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+    s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
     tl.store(state1_ptr + offsets, s1_codes, mask=mask)
     tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1)
 
@@ -533,8 +533,8 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
 
     # 3. Optimizer-specific updates
     if OPTIMIZER_ID == 3:  # ADAM
-        s1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
-        s2 = dequant_8bit_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
+        s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
+        s2 = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
 
         s1 = s1 * beta1 + (1.0 - beta1) * g
         s2 = s2 * beta2 + (1.0 - beta2) * g * g
@@ -556,26 +556,26 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
         tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
 
         # Requantize and store states
-        s1_codes, new_absmax1 = quantize_8bit_blockwise_core(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+        s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
         tl.store(state1_ptr + offsets, s1_codes, mask=mask)
         tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1)
 
-        s2_codes, new_absmax2 = quantize_8bit_blockwise_core(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+        s2_codes, new_absmax2 = quantize_8bit_blockwise_kernel_util(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
         tl.store(state2_ptr + offsets, s2_codes, mask=mask)
         tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax2)
 
     elif OPTIMIZER_ID == 5:  # ADEMAMIX
         # AdEMAMix has a stacked state1 (m1, m2) and state2 (nu)
-        m1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
-        m2 = dequant_8bit_kernel_util(
+        m1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
+        m2 = dequant_8bit_blockwise_kernel_util(
             state1_ptr + n_elements,
             offsets,
             qmap1_ptr,
             absmax1_ptr + n_elements // BLOCK_SIZE_N,
             mask,
             BLOCK_SIZE_N,
         )
-        nu = dequant_8bit_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
+        nu = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
 
         m1 = m1 * beta1 + (1.0 - beta1) * g
         m2 = m2 * beta3 + (1.0 - beta3) * g
@@ -599,18 +599,18 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
         tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
 
         # Requantize and store all three states
-        m1_codes, new_absmax_m1 = quantize_8bit_blockwise_core(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+        m1_codes, new_absmax_m1 = quantize_8bit_blockwise_kernel_util(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
         tl.store(state1_ptr + offsets, m1_codes, mask=mask)
         tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_m1)
 
-        m2_codes, new_absmax_m2 = quantize_8bit_blockwise_core(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+        m2_codes, new_absmax_m2 = quantize_8bit_blockwise_kernel_util(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
         tl.store(state1_ptr + n_elements + offsets, m2_codes, mask=mask)
         tl.store(
             absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N,
             new_absmax_m2,
         )
 
-        nu_codes, new_absmax_nu = quantize_8bit_blockwise_core(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+        nu_codes, new_absmax_nu = quantize_8bit_blockwise_kernel_util(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
         tl.store(state2_ptr + offsets, nu_codes, mask=mask)
         tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_nu)
 
@@ -625,7 +625,7 @@ def _optimizer_update_2state_8bit_blockwise_triton_kernel(
 }
 
 
-def optimizer_update_8bit_blockwise_triton_impl(
+def optimizer_update_8bit_blockwise_impl(
     optimizer_name: str,
     g: torch.Tensor,
     p: torch.Tensor,
@@ -703,4 +703,4 @@ def optimizer_update_8bit_blockwise_triton_impl(
 # optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_pytorch_impl)
 # optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_quant
 # optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_triton_quant)
-optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_impl
+optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_impl
```
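After this change, the optimizer kernels all follow one pattern: dequantize state with the shared util, update in float32, requantize with the shared util. Below is a rough pure-PyTorch analogue of the ADAM branch (`OPTIMIZER_ID == 3`), reusing the reference helpers sketched earlier; the moment updates match the diff, while the parameter step is the textbook Adam rule and only an assumed stand-in for the kernel's exact arithmetic, which this commit does not show.

```python
def adam_8bit_blockwise_step_ref(p, g, s1_codes, s2_codes, absmax1, absmax2,
                                 code1, code2, beta1, beta2, eps, lr,
                                 block_size=256):
    """Rough analogue of the OPTIMIZER_ID == 3 (ADAM) path, built on the
    quantize_8bit_blockwise_ref / dequant_8bit_blockwise_ref helpers above.
    The parameter step is an assumption, not the kernel's exact arithmetic.
    """
    # 1. Dequantize both optimizer states (shared dequant util in the kernel)
    s1 = dequant_8bit_blockwise_ref(s1_codes, code1, absmax1, block_size)
    s2 = dequant_8bit_blockwise_ref(s2_codes, code2, absmax2, block_size)

    # 2. Moment updates, exactly as in the diff
    s1 = s1 * beta1 + (1.0 - beta1) * g
    s2 = s2 * beta2 + (1.0 - beta2) * g * g

    # 3. Parameter update (assumed textbook Adam step, no bias correction)
    p = p - lr * s1 / (s2.sqrt() + eps)

    # 4. Requantize both states (shared quantize util in the kernel)
    s1_codes, absmax1 = quantize_8bit_blockwise_ref(s1, code1, block_size)
    s2_codes, absmax2 = quantize_8bit_blockwise_ref(s2, code2, block_size)
    return p, s1_codes, s2_codes, absmax1, absmax2
```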
