Skip to content

Commit 404e277

Browse files
authored
[XPU] Implemented 8bit optimizers in triton (#1692)
* Implemented 8-bit optimizers
* Add interface
* Commented out torch checks
* Merged
* Updated kernels
* Reused code for quant/dequant
* Removed empty line
* Changed README
1 parent 4b02574 commit 404e277

File tree

6 files changed

+994
-197
lines changed

6 files changed

+994
-197
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ bitsandbytes has the following minimum requirements for all platforms:
141141
</td>
142142
<td>✅</td>
143143
<td>✅</td>
144-
<td>🚧</td>
144+
<td>〰️</td>
145145
</tr>
146146
<tr>
147147
<td colspan="6">🍎 <strong>macOS 14+</strong></td>

bitsandbytes/backends/triton/triton_kernels.py renamed to bitsandbytes/backends/triton/kernels_4bit.py

Lines changed: 2 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -4,167 +4,6 @@
44
import triton.language as tl
55

66

7-
# @triton.autotune(
8-
# configs=[
9-
# # triton.Config({'SPLIT_SIZE': 64}),
10-
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
11-
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
12-
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
13-
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
14-
# # triton.Config({'SPLIT_SIZE': 128}),
15-
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
16-
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
17-
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
18-
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
19-
# triton.Config({"SPLIT_SIZE": 256}),
20-
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
21-
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
22-
# triton.Config({"SPLIT_SIZE": 512}),
23-
# # triton.Config({'SPLIT_SIZE': 1024}),
24-
# ],
25-
# key=["num_paired_elements", "QUANT_BLOCK"],
26-
# )
27-
@triton.jit
28-
def dequant_8bit_kernel(
29-
a_ptr,
30-
c_ptr,
31-
quant_ptr,
32-
absmax_ptr,
33-
num_paired_elements,
34-
QUANT_BLOCK: tl.constexpr,
35-
SPLIT_SIZE: tl.constexpr,
36-
):
37-
pid = tl.program_id(axis=0)
38-
block_start = pid * SPLIT_SIZE
39-
offsets = block_start + tl.arange(0, SPLIT_SIZE)
40-
mask = offsets < num_paired_elements
41-
42-
a = tl.load(a_ptr + offsets, mask)
43-
a = a.to(tl.uint8)
44-
45-
# apply conversion
46-
scaled_int8 = tl.load(quant_ptr + a, mask)
47-
48-
abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK
49-
abs_offsets = offsets // QUANT_BLOCK
50-
mask_blocked = offsets < abs_blocks_lim
51-
52-
absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked)
53-
# apply scales
54-
out_dq = scaled_int8 * absmax
55-
56-
offs = block_start + tl.arange(0, SPLIT_SIZE)
57-
mask = offs < num_paired_elements
58-
tl.store(c_ptr + offs, out_dq, mask)
59-
60-
61-
def dequant_int8_blockwise(
62-
A_nf4: torch.Tensor,
63-
quant_state_code: torch.Tensor,
64-
absmax: torch.Tensor,
65-
out: torch.Tensor,
66-
quant_blocksize: int = 64,
67-
):
68-
number_of_paired_elements = A_nf4.numel()
69-
70-
SPLIT_SIZE = 256
71-
# grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
72-
grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
73-
dequant_8bit_kernel[grid](
74-
A_nf4,
75-
out,
76-
quant_state_code,
77-
absmax,
78-
number_of_paired_elements,
79-
quant_blocksize,
80-
SPLIT_SIZE,
81-
)
82-
return out
83-
84-
85-
# @triton.autotune(
86-
# configs=[
87-
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
88-
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
89-
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
90-
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
91-
# ],
92-
# key=["n_elements"],
93-
# )
94-
@triton.jit
95-
def quantize_blockwise_kernel(
96-
A_ptr,
97-
code_ptr,
98-
absmax_ptr,
99-
out_ptr,
100-
n_elements,
101-
BLOCK_SIZE: tl.constexpr,
102-
CODE_SIZE: tl.constexpr,
103-
SPLIT_NUM_BLOCKS: tl.constexpr,
104-
):
105-
block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
106-
thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
107-
108-
offsets = block_start_idx * BLOCK_SIZE + thread_idx
109-
mask = offsets < n_elements
110-
111-
A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
112-
113-
# To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
114-
A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE))
115-
116-
# Calculating absamax for each block
117-
absmax = tl.max(tl.abs(A_reshaped), axis=1)
118-
tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
119-
120-
A_normalized = A_reshaped / absmax[:, None]
121-
A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
122-
123-
lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
124-
upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
125-
126-
for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter
127-
pivot = (lower_pivot + upper_pivot) // 2
128-
val = tl.load(code_ptr + pivot)
129-
is_higher = A_normalized > val # code[pivot]
130-
lower_pivot = tl.where(is_higher, pivot, lower_pivot)
131-
upper_pivot = tl.where(is_higher, upper_pivot, pivot)
132-
133-
# Choose closest level
134-
lower_val = tl.load(code_ptr + lower_pivot)
135-
upper_val = tl.load(code_ptr + upper_pivot)
136-
lower_dist = tl.abs(A_normalized - lower_val)
137-
upper_dist = tl.abs(A_normalized - upper_val)
138-
quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
139-
140-
# too slow approach
141-
# diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
142-
# quantized = tl.argmin(diff, axis=2).to(tl.uint8)
143-
144-
quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
145-
tl.store(out_ptr + offsets, quantized_flat, mask=mask)
146-
147-
148-
def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out):
149-
n = A.numel()
150-
151-
split_num_blocks = 1
152-
grid = (triton.cdiv(blocks, split_num_blocks),)
153-
# grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
154-
quantize_blockwise_kernel[grid](
155-
A_ptr=A,
156-
code_ptr=code,
157-
absmax_ptr=absmax,
158-
out_ptr=quantized_out,
159-
n_elements=n,
160-
BLOCK_SIZE=blocksize,
161-
CODE_SIZE=code.numel(),
162-
SPLIT_NUM_BLOCKS=split_num_blocks,
163-
)
164-
165-
return quantized_out, absmax
166-
167-
1687
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4
1698
# @triton.autotune(
1709
# configs=[
@@ -587,7 +426,7 @@ def dequant_nf4_kernel(
587426
tl.store(c_ptr + offs, out_dq, mask)
588427

589428

590-
def _dequantize_4bit_impl(
429+
def dequantize_4bit_impl(
591430
A: torch.Tensor,
592431
absmax: torch.Tensor,
593432
blocksize: int,
@@ -611,7 +450,7 @@ def _dequantize_4bit_impl(
611450
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
612451

613452

614-
def _dequantize_4bit_impl_passing_code(
453+
def dequantize_4bit_impl_passing_code(
615454
A: torch.Tensor,
616455
absmax: torch.Tensor,
617456
blocksize: int,
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import torch
2+
3+
import triton
4+
import triton.language as tl
5+
6+
7+
# @triton.autotune(
8+
# configs=[
9+
# # triton.Config({'SPLIT_SIZE': 64}),
10+
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
11+
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
12+
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
13+
# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
14+
# # triton.Config({'SPLIT_SIZE': 128}),
15+
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
16+
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
17+
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
18+
# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
19+
# triton.Config({"SPLIT_SIZE": 256}),
20+
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
21+
# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
22+
# triton.Config({"SPLIT_SIZE": 512}),
23+
# # triton.Config({'SPLIT_SIZE': 1024}),
24+
# ],
25+
# key=["num_paired_elements", "QUANT_BLOCK"],
26+
# )
27+
@triton.jit
def dequant_8bit_kernel(
    a_ptr,
    out_ptr,
    code_ptr,
    absmax_ptr,
    n,
    QUANT_BLOCK: tl.constexpr,
    SPLIT_SIZE: tl.constexpr,
):
    """Dequantize one SPLIT_SIZE-wide tile of a blockwise 8-bit quantized buffer.

    Each program instance handles the SPLIT_SIZE consecutive element indices
    starting at ``program_id(0) * SPLIT_SIZE``; indices at or beyond ``n`` are
    masked out of both the loads (inside the helper) and the final store.

    Args:
        a_ptr: pointer to the quantized input elements.
        out_ptr: pointer to the output buffer, indexed like ``a_ptr``.
        code_ptr: pointer to the lookup table indexed by the quantized values.
        absmax_ptr: pointer to the per-block scales.
        n: number of elements to process.
        QUANT_BLOCK: number of consecutive elements that share one scale.
        SPLIT_SIZE: number of elements handled by one program instance.
    """
    tile_first = tl.program_id(axis=0) * SPLIT_SIZE
    idx = tile_first + tl.arange(0, SPLIT_SIZE)
    in_bounds = idx < n

    # The shared helper performs the table lookup and the per-block scaling.
    values = dequant_8bit_blockwise_kernel_util(a_ptr, idx, code_ptr, absmax_ptr, in_bounds, QUANT_BLOCK)
    tl.store(out_ptr + idx, values, in_bounds)
43+
44+
45+
def dequant_8bit_blockwise(
    a: torch.Tensor,
    absmax: torch.Tensor,
    quant_state_code: torch.Tensor,
    quant_blocksize: int = 64,
    dtype: torch.dtype = None,
    out: torch.Tensor = None,
):
    """Launch ``dequant_8bit_kernel`` over every element of ``a``.

    Args:
        a: quantized input tensor; one output element is produced per input element.
        absmax: per-block scales, passed to the kernel as ``absmax_ptr``.
        quant_state_code: lookup table, passed to the kernel as ``code_ptr``.
        quant_blocksize: number of consecutive elements sharing one scale.
        dtype: dtype of the freshly allocated result; only consulted when
            ``out`` is None.
        out: optional pre-allocated destination. When given it is written to
            and returned as-is.

    Returns:
        The tensor holding the dequantized values (``out`` if it was supplied).

    Raises:
        ValueError: if neither ``out`` nor ``dtype`` is provided.
    """
    n = a.numel()

    # Without a destination we need to know which dtype to allocate.
    if out is None and dtype is None:
        raise ValueError("If out is None, dtype must be specified")
    if out is None:
        out = torch.empty_like(a, dtype=dtype, device=a.device)

    # Fixed tile width: each program instance handles this many elements.
    elements_per_program = 256
    launch_grid = (triton.cdiv(n, elements_per_program),)
    dequant_8bit_kernel[launch_grid](
        a,
        out,
        quant_state_code,
        absmax,
        n,
        quant_blocksize,
        elements_per_program,
    )
    return out
72+
73+
74+
# @triton.autotune(
75+
# configs=[
76+
# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
77+
# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
78+
# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
79+
# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
80+
# ],
81+
# key=["n_elements"],
82+
# )
83+
@triton.jit
def quantize_8bit_blockwise_kernel(
    A_ptr,
    code_ptr,
    absmax_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    CODE_SIZE: tl.constexpr,
    SPLIT_NUM_BLOCKS: tl.constexpr,
):
    """Quantize SPLIT_NUM_BLOCKS consecutive quantization blocks per program.

    Each program instance loads ``SPLIT_NUM_BLOCKS * BLOCK_SIZE`` input
    elements (out-of-range positions read as 0.0), delegates the per-block
    absmax computation and code lookup to
    ``quantize_8bit_blockwise_kernel_util``, then stores one absmax per block
    and one uint8 code per element.

    Args:
        A_ptr: pointer to the input values.
        code_ptr: pointer to the quantization code table.
        absmax_ptr: pointer to the per-block absmax output.
        out_ptr: pointer to the uint8 quantized output.
        n_elements: number of valid input elements.
        BLOCK_SIZE: number of elements per quantization block.
        CODE_SIZE: number of entries in the code table.
        SPLIT_NUM_BLOCKS: number of quantization blocks handled by one program.
    """
    block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
    thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)

    offsets = block_start_idx * BLOCK_SIZE + thread_idx
    mask = offsets < n_elements

    A = tl.load(A_ptr + offsets, mask=mask, other=0.0)

    quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS)

    # A block exists only if its first element lies inside the input. Masking
    # the absmax store keeps the last program from writing past the end of the
    # absmax buffer when the block count is not a multiple of SPLIT_NUM_BLOCKS.
    block_ids = block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS)
    block_mask = block_ids * BLOCK_SIZE < n_elements
    tl.store(absmax_ptr + block_ids, absmax, mask=block_mask)
    tl.store(out_ptr + offsets, quantized, mask=mask)
105+
106+
107+
def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
    """Blockwise 8-bit quantization of ``A`` via ``quantize_8bit_blockwise_kernel``.

    Args:
        A: input tensor to quantize.
        code: quantization code table; its element count is passed to the
            kernel as ``CODE_SIZE``.
        blocksize: number of consecutive elements per quantization block.
        absmax: optional destination for the per-block absmax values. When
            None, a 1-D tensor with one entry per block is allocated on
            ``A``'s device with ``A``'s dtype.
        out: optional destination for the uint8 codes. When None, a flat
            uint8 tensor with one entry per element of ``A`` is allocated.

    Returns:
        Tuple ``(out, absmax)`` where ``out`` has been reshaped to ``A.shape``.
    """
    n = A.numel()
    # Ceil-division: a trailing partial block still needs its own absmax slot.
    num_blocks = -(-n // blocksize)

    absmax = torch.empty((num_blocks,), device=A.device, dtype=A.dtype) if absmax is None else absmax
    out = torch.empty_like(A.flatten(), dtype=torch.uint8) if out is None else out

    # One quantization block per program instance.
    blocks_per_program = 1
    launch_grid = (triton.cdiv(num_blocks, blocks_per_program),)
    quantize_8bit_blockwise_kernel[launch_grid](
        A_ptr=A,
        code_ptr=code,
        absmax_ptr=absmax,
        out_ptr=out,
        n_elements=n,
        BLOCK_SIZE=blocksize,
        CODE_SIZE=code.numel(),
        SPLIT_NUM_BLOCKS=blocks_per_program,
    )

    return out.reshape(A.shape), absmax
134+
135+
136+
@triton.jit
def quantize_8bit_blockwise_kernel_util(
    a,
    code_ptr,
    CODE_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    N_PER_TH: tl.constexpr,
):
    """Quantize ``N_PER_TH`` blocks of ``BLOCK_SIZE`` values to uint8 code indices.

    Args:
        a: flat block of ``N_PER_TH * BLOCK_SIZE`` already-loaded input values.
        code_ptr: pointer to the code table (presumably sorted ascending, since
            it is binary-searched below; verify against the caller).
        CODE_SIZE: number of entries in the code table.
        BLOCK_SIZE: number of elements per quantization block.
        N_PER_TH: number of quantization blocks contained in ``a``.

    Returns:
        Tuple ``(quantized_flat, absmax)``: the flat uint8 code indices, one per
        input element, and the per-block maximum absolute value.
    """
    # View the data as (N_PER_TH, BLOCK_SIZE) so each row is one quantization block.
    rows = tl.reshape(a, (N_PER_TH, BLOCK_SIZE))

    # Per-block scale, then normalization clamped to [-1, 1].
    absmax = tl.max(tl.abs(rows), axis=1)
    scaled = tl.clamp(rows / absmax[:, None], -1.0, 1.0)

    lo = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32)
    hi = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)

    # Binary search over the code table. The step count is fixed at
    # 8 = ceil(log2(256)); a general version would take it as a parameter.
    for _ in range(8):
        mid = (lo + hi) // 2
        above_mid = scaled > tl.load(code_ptr + mid)
        lo = tl.where(above_mid, mid, lo)
        hi = tl.where(above_mid, hi, mid)

    # Of the two bracketing entries pick the closer one; ties go to the lower index.
    dist_lo = tl.abs(scaled - tl.load(code_ptr + lo))
    dist_hi = tl.abs(scaled - tl.load(code_ptr + hi))
    nearest = tl.where(dist_lo <= dist_hi, lo, hi).to(tl.uint8)

    return tl.reshape(nearest, (BLOCK_SIZE * N_PER_TH,)), absmax
177+
178+
179+
@triton.jit
def dequant_8bit_blockwise_kernel_util(
    a_ptr,
    offsets,
    code_ptr,
    absmax_ptr,
    mask,
    BLOCK_SIZE: tl.constexpr,
):
    """Look up and rescale the quantized elements at ``offsets``.

    Args:
        a_ptr: pointer to the quantized input elements.
        offsets: element indices to process.
        code_ptr: pointer to the lookup table indexed by the quantized values.
        absmax_ptr: pointer to the per-block scales; element ``i`` uses entry
            ``i // BLOCK_SIZE``.
        mask: lanes to process; masked-off input and scale loads yield 0.
        BLOCK_SIZE: number of consecutive elements sharing one scale.

    Returns:
        The table value of each element multiplied by its block's scale.
    """
    # Quantized values act as indices into the code table.
    code_idx = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8)
    code_val = tl.load(code_ptr + code_idx, mask)

    # One scale per BLOCK_SIZE consecutive elements.
    scale_idx = offsets // BLOCK_SIZE
    scale = tl.load(absmax_ptr + scale_idx, mask=mask, other=0.0, eviction_policy="evict_last")

    return code_val * scale

0 commit comments

Comments
 (0)