Skip to content

Commit 27d916b

Browse files
authored
Moved int8_mm_dequant to default backend (#1626)
1 parent 8b858e4 commit 27d916b

File tree

2 files changed

+23
-24
lines changed

2 files changed

+23
-24
lines changed

bitsandbytes/backends/cpu/ops.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from collections.abc import Sequence
22
import ctypes as ct
3-
from typing import Optional
43

54
import torch
65

@@ -24,29 +23,6 @@ def _(A: torch.Tensor, B: torch.Tensor):
2423
).reshape(*A.shape[:-1], B.shape[0])
2524

2625

27-
@register_kernel("bitsandbytes::int8_mm_dequant", "cpu")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Dequantize the int32 accumulator of an int8 matmul back to floating point.

    Args:
        A: int32 matmul result; last dim is the output-feature dim.
        row_stats: float32 per-row absmax statistics (flattened to a column vector).
        col_stats: float32 per-column absmax statistics (flattened to a row vector).
        dtype: output dtype; defaults to torch.float16 when None.
        bias: optional bias added after dequantization.

    Returns:
        Dequantized tensor with A's shape, cast to ``dtype or torch.float16``.
    """
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    # Collapse leading dims so the row/col scale broadcast is a plain 2-D outer product.
    A_calc = A.view(-1, A.shape[-1])
    row_stats = row_stats.reshape(-1).unsqueeze(-1)
    col_stats = col_stats.reshape(-1).unsqueeze(0)

    # NOTE(review): 6.200124e-05 ≈ 1/127**2, presumably undoing the two int8
    # [-127, 127] quantization scalings — it is not exactly 1/16129; confirm
    # the deviation is intentional before "fixing" it.
    out = A_calc * (row_stats * col_stats) * 6.200124e-05
    if bias is not None:
        out += bias

    # .view() above keeps A's element count, so out already has A's shape.
    return out.to(dtype or torch.float16)
48-
49-
5026
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
5127
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
5228
torch._check_is_size(blocksize)

bitsandbytes/backends/default/ops.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,29 @@
66
from ..._ops import register_kernel
77

88

9+
@register_kernel("bitsandbytes::int8_mm_dequant", "default")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Dequantize the int32 accumulator of an int8 matmul back to floating point.

    Applies the outer product of per-row and per-column absmax statistics
    (scaled by ~1/127**2), optionally adds a bias, and casts the result to
    ``dtype`` (float16 when not given).
    """
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    # Work on a 2-D view: (rows, out_features). The scale tensor below
    # broadcasts as column-vector * row-vector, i.e. an outer product.
    flattened = A.view(-1, A.shape[-1])
    scale = row_stats.reshape(-1).unsqueeze(-1) * col_stats.reshape(-1).unsqueeze(0)

    result = flattened * scale * 6.200124e-05
    if bias is not None:
        result += bias

    return result.to(torch.float16 if dtype is None else dtype)
30+
31+
932
@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
1033
def _(
1134
A: torch.Tensor,

0 commit comments

Comments (0)