Commit 8f88cef

Fix #1588 - torch compatibility for <=2.4 (#1590)

Parent: d2fe0e3

4 files changed (+22, -17 lines)

bitsandbytes/_ops.py

Lines changed: 7 additions & 7 deletions
@@ -19,7 +19,7 @@
 # Higher level op: int8 matmul + dequant + bias
 torch.library.define(
     "bitsandbytes::int8_scaled_mm",
-    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
+    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType? dtype=None) -> Tensor",
 )


@@ -30,10 +30,10 @@ def _(
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     shapeC = (*A.shape[:-1], B.shape[0])
-    return torch.empty(shapeC, device=A.device, dtype=dtype)
+    return torch.empty(shapeC, device=A.device, dtype=dtype or torch.float16)


 torch.library.define(
@@ -98,15 +98,15 @@ def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:


 # Default PyTorch-native implementation
-@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
+@register_kernel("bitsandbytes::int8_vectorwise_dequant", "default")
 def _(A: torch.Tensor, stats: torch.Tensor):
     # To dequantize we divide by 127, or multiply by the reciprocal.
     return A * stats.view(-1, 1) * 7.874015718698502e-3


 torch.library.define(
     "bitsandbytes::int8_mm_dequant",
-    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType? dtype=None, Tensor? bias=None) -> Tensor",
 )

@@ -115,11 +115,11 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: "A must be int32")
-    return torch.empty_like(A, dtype=dtype)
+    return torch.empty_like(A, dtype=dtype or torch.float16)


 torch.library.define(
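The schema change is the heart of the fix: torch releases at or below 2.4 apparently reject a concrete ScalarType default such as float16 inside a torch.library.define schema string, so the operator now declares ScalarType? dtype=None and each kernel resolves the fp16 default at call time. Below is a minimal sketch of the same pattern; the op name mylib::scaled_cast is hypothetical and not part of bitsandbytes.

import torch

# An optional ScalarType with a None default parses on old and new torch
# alike; a concrete schema default like `ScalarType dtype=float16` is the
# form older releases reportedly choked on.
torch.library.define(
    "mylib::scaled_cast",
    "(Tensor A, ScalarType? dtype=None) -> Tensor",
)

@torch.library.impl("mylib::scaled_cast", "CPU")
def _(A: torch.Tensor, dtype=None):
    # Resolve the fp16 default at call time instead of in the schema.
    return A.to(dtype or torch.float16)

print(torch.ops.mylib.scaled_cast(torch.ones(2)).dtype)  # torch.float16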

bitsandbytes/backends/cpu/ops.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
@@ -43,7 +43,7 @@ def _(
     if bias is not None:
         out += bias

-    return out.to(dtype)
+    return out.to(dtype or torch.float16)


 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
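One subtlety of the dtype or torch.float16 idiom used throughout this commit: torch.dtype objects are always truthy, so the fallback fires only for None, and an explicitly passed dtype always wins. A quick sanity check (plain Python, no bitsandbytes required):

import torch

def resolve(dtype=None):
    # torch.dtype instances define no __bool__, so they are truthy;
    # `or` therefore falls back to fp16 only when dtype is None.
    return dtype or torch.float16

assert resolve() is torch.float16
assert resolve(torch.bfloat16) is torch.bfloat16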

bitsandbytes/backends/cuda/ops.py

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
@@ -121,7 +121,7 @@ def _(
     if bias is not None and bias.dtype != torch.float16:
         out.add_(bias)

-    return out.to(dtype)
+    return out.to(dtype or torch.float16)


 @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
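Note that Tensor.to returns self unchanged when the target dtype already matches, so in the common case where the CUDA dequant path has already produced fp16 output, the new out.to(dtype or torch.float16) adds no copy.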

bitsandbytes/backends/default/ops.py

Lines changed: 11 additions & 6 deletions
@@ -5,26 +5,31 @@
 from ..._ops import register_kernel


-@register_kernel("bitsandbytes::int8_scaled_mm", None)
+@register_kernel("bitsandbytes::int8_scaled_mm", "default")
 def _(
     A: torch.Tensor,
     B: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
-    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
-    return out
+    return torch.ops.bitsandbytes.int8_mm_dequant.default(
+        out_i32,
+        row_stats,
+        col_stats,
+        dtype=dtype or torch.float16,
+        bias=bias,
+    )


-@register_kernel("bitsandbytes::int8_linear_matmul", None)
+@register_kernel("bitsandbytes::int8_linear_matmul", "default")
 def _(A: torch.Tensor, B: torch.Tensor):
     return _int8_linear_matmul_impl(A, B)


-@register_kernel("bitsandbytes::int8_linear_matmul.out", None)
+@register_kernel("bitsandbytes::int8_linear_matmul.out", "default")
 def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     torch._check(out.dtype == torch.int32)
     _int8_linear_matmul_impl(A, B, out)
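The companion change swaps the device key None for the literal string "default" when registering fallback kernels. Without tracing bitsandbytes' actual register_kernel internals (they live in bitsandbytes/_ops.py and may differ), a registry of roughly this shape illustrates why a string key is the safer choice on older torch, whose library APIs expect device or dispatch names as strings. This sketch is hypothetical, not the library's implementation:

from typing import Callable, Dict, Tuple

_KERNELS: Dict[Tuple[str, str], Callable] = {}

def register_kernel(op: str, device: str):
    # Keying on the literal "default" instead of None keeps the table
    # homogeneous (str keys only) and avoids passing None through to
    # torch.library calls that may not accept it on torch <= 2.4.
    def decorator(fn: Callable) -> Callable:
        _KERNELS[(op, device)] = fn
        return fn
    return decorator

def get_kernel(op: str, device: str) -> Callable:
    # Prefer a device-specific kernel, then fall back to "default".
    return _KERNELS.get((op, device)) or _KERNELS[(op, "default")]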
