Commit 15c8c48

Support LLM.int8() inference with torch.compile
1 parent d2fe0e3

3 files changed: +84 -25 lines
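The practical effect of this commit is that an LLM.int8() layer can run under torch.compile. Below is a minimal usage sketch, assuming a CUDA device and the public bnb.nn.Linear8bitLt API; the layer sizes and wiring are illustrative, not taken from this commit. threshold=6.0 is the outlier threshold recommended in the LLM.int8() paper.

```python
import torch
import bitsandbytes as bnb

# Build an int8 inference layer from fp16 weights; quantization happens on .cuda().
fp16 = torch.nn.Linear(4096, 4096, dtype=torch.float16)
int8 = bnb.nn.Linear8bitLt(4096, 4096, has_fp16_weights=False, threshold=6.0)
int8.load_state_dict(fp16.state_dict())
int8 = int8.cuda()

# Before this commit, the data-dependent outlier handling broke the graph.
compiled = torch.compile(int8)

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
with torch.inference_mode():
    y = compiled(x)
```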

bitsandbytes/_ops.py

Lines changed: 26 additions & 0 deletions
@@ -15,6 +15,32 @@
 register_fake = torch.library.impl_abstract
 register_kernel = torch.library.impl
 
+# Int8 mixed precision matmul + dequant + bias
+torch.library.define(
+    "bitsandbytes::int8_mixed_scaled_mm",
+    "(Tensor A, Tensor CA, Tensor CB, Tensor SCA, Tensor SCB, Tensor? outlier_cols=None, Tensor? bias=None) -> (Tensor, Tensor?)",
+)
+
+
+@register_fake("bitsandbytes::int8_mixed_scaled_mm")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    shapeC = (*CA.shape[:-1], CB.shape[0])
+
+    out = torch.empty(shapeC, device=A.device, dtype=A.dtype)
+
+    outlier_cols = torch.library.get_ctx().new_dynamic_size()
+    subA = A.new_empty(outlier_cols, dtype=torch.int64)
+
+    return out, subA
+
 
 # Higher level op: int8 matmul + dequant + bias
 torch.library.define(
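The register_fake implementation above is what makes the op traceable: torch.compile never executes the real CUDA kernel at trace time, so the fake must report output metadata, and new_dynamic_size() declares the data-dependent outlier count as a symbolic size. Here is a self-contained toy sketch of the same mechanism; the demo:: op is hypothetical and not part of bitsandbytes, and the dynamo config flag may or may not be needed depending on the PyTorch version.

```python
import torch

# An op whose output length depends on the input's values, analogous to the
# outlier-column handling above. The "demo::" namespace is hypothetical.
torch.library.define("demo::nonzero_cols", "(Tensor A) -> Tensor")

@torch.library.impl("demo::nonzero_cols", "cpu")
def _(A: torch.Tensor) -> torch.Tensor:
    # Real kernel: indices of columns that contain any nonzero entry.
    return torch.argwhere(A.any(dim=0)).view(-1)

@torch.library.impl_abstract("demo::nonzero_cols")
def _(A: torch.Tensor) -> torch.Tensor:
    # Fake kernel: the result length is unknown at trace time, so report a
    # symbolic (dynamic) size rather than a concrete one.
    n = torch.library.get_ctx().new_dynamic_size()
    return A.new_empty(n, dtype=torch.int64)

# May be required so the compiler accepts data-dependent output shapes.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

fn = torch.compile(lambda x: torch.ops.demo.nonzero_cols(x))
print(fn(torch.tensor([[0.0, 5.0], [0.0, 0.0]])))  # tensor([1])
```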

bitsandbytes/autograd/_functions.py

Lines changed: 16 additions & 25 deletions
@@ -210,37 +210,28 @@ def forward(
         # 2. Quantize B
         state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
 
-        # Handle sparse decomposition. In some instances, we may have not found any
-        # outlier columns at all. In that case, we'll skip this part completely.
-        if state.threshold > 0.0 and outlier_cols is not None and outlier_cols.numel():
+        # Handle sparse decomposition
+        if state.threshold > 0.0:
             state.idx = outlier_cols
 
-            # Zero out the outliers in the transposed 8bit inputs.
-            if CAt is not None:
-                CAt[:, state.idx] = 0
-
-            # Extract the input outliers in original precision
-            subA = A[:, state.idx].contiguous()
+            # Mixed Int8 Matmul + Dequant + Bias
+            output, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
+                A,
+                CA,
+                state.CB,
+                SCA,
+                state.SCB,
+                outlier_cols,
+                bias,
+            )
 
-            # Extract the corresponding weights
-            if state.has_fp16_weights:
-                state.subB = B[:, state.idx].t()
-            else:
-                # To dequantize our weights associated with the input outliers,
-                # we want to divide by 127. It's however more performant to multiply
-                # by the reciprocal.
-                outliers = state.CB[:, state.idx]
-                state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t()
         else:
+            # Int8 Matmul + Dequant + Bias
+            output = torch.ops.bitsandbytes.int8_scaled_mm.default(
+                CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype
+            )
             subA = None
 
-        # 3. Int8 Matmul + Dequant + Bias
-        output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
-
-        # 4. Mixed-precision decomposition matmul
-        if subA is not None and state.subB is not None:
-            output = output.addmm(subA, state.subB)
-
         # 5. Save state
         ctx.state = state
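With the forward pass now routed through a single custom op, the entire decomposition is visible to the compiler as one node instead of a data-dependent Python branch. A hedged sketch of that path in isolation, using only ops that appear in this diff; shapes are illustrative, F is bitsandbytes.functional, and the threshold keyword follows the quant kernel signature in the third file:

```python
import torch
import bitsandbytes.functional as F  # importing bitsandbytes registers the ops

A = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
B = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")

# Row-wise int8 quantization; a threshold > 0 also reports outlier columns.
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)
CB, SCB, _ = F.int8_vectorwise_quant(B)

# Fused: int8 matmul + dequant + bias, plus the fp16 outlier correction.
out, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
    A, CA, CB, SCA, SCB, outlier_cols, None
)
assert out.shape == (8, 11008)
```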

bitsandbytes/backends/cuda/ops.py

Lines changed: 42 additions & 0 deletions
@@ -22,6 +22,45 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     _int8_linear_matmul_impl(A, B, out)
 
 
+@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    subB = None
+
+    if outlier_cols is not None and outlier_cols.numel():
+        # Extract the inputs with outliers in original precision
+        subA = A[:, outlier_cols].contiguous()
+
+        # Dequantize the corresponding weight columns
+        subB = (
+            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
+            .to(A.dtype)
+            .t()
+        )
+
+        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
+
+    else:
+        # Needed for torch.compile when there are no outliers.
+        subA = torch.empty(0, device=A.device, dtype=A.dtype)
+
+    # Int8 Matmul + Dequant + Bias
+    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
+
+    if subB is not None:
+        # Add the outlier columns back to the output
+        output = output.addmm(subA, subB)
+
+    return output, subA
+
+
 def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     A, B = B, A
 
@@ -143,6 +182,9 @@ def _(A: torch.Tensor, threshold=0.0):
 
     if outliers.any():
         outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+    else:
+        # Needed for torch.compile support.
+        outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)
 
     with _cuda_device_of(A):
         lib.cint8_vector_quant(
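Both new else branches return empty tensors rather than None, and the in-diff comments say this is for torch.compile. The reason, sketched below with illustrative code rather than library code: the traced graph needs one consistent contract, and a value that is sometimes None and sometimes a Tensor forces a graph break or recompile, while an always-Tensor contract with numel() == 0 as the "no outliers" signal keeps a single graph.

```python
import torch

def has_outliers(outlier_cols: torch.Tensor) -> bool:
    # Uniform check that works for the populated and the empty case alike.
    return outlier_cols.numel() > 0

assert has_outliers(torch.tensor([3, 17], dtype=torch.int64))
assert not has_outliers(torch.empty(0, dtype=torch.int64))

# The empty case also composes with tensor ops: addmm against 0-column
# factors is a shape-preserving no-op, so no branch on None is needed.
out = torch.zeros(4, 8)
correction = out.addmm(torch.empty(4, 0), torch.empty(0, 8))
assert correction.shape == (4, 8)
```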
