Commit 2b85100

Cleanup; rename int8_linear_dequant -> int8_scaled_mm
1 parent: 25368bc

6 files changed: 47 additions & 23 deletions


bitsandbytes/_ops.py

Lines changed: 3 additions & 3 deletions

@@ -17,12 +17,12 @@

 # Higher level op: int8 matmul + dequant + bias
 torch.library.define(
-    "bitsandbytes::int8_linear_dequant",
+    "bitsandbytes::int8_scaled_mm",
     "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
 )


-@register_fake("bitsandbytes::int8_linear_dequant")
+@register_fake("bitsandbytes::int8_scaled_mm")
 def _(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -35,7 +35,7 @@ def _(
     return torch.empty(shapeC, device=A.device, dtype=dtype)


-@register_kernel("bitsandbytes::int8_linear_dequant", None)
+@register_kernel("bitsandbytes::int8_scaled_mm", None)
 def _(
     A: torch.Tensor,
     B: torch.Tensor,

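For context, a minimal call sketch of the renamed op, mirroring the shapes exercised in tests/test_ops.py further down. It assumes a bitsandbytes build that already contains this commit, so the op schema above and a backend kernel for the chosen device are registered; it is not code from the diff.

```python
import torch
import bitsandbytes  # noqa: F401  -- importing registers the bitsandbytes torch.library ops

# int8 operands plus float32 row/column statistics, matching the schema above
A = torch.randint(-128, 127, (10, 20), dtype=torch.int8)
B = torch.randint(-128, 127, (30, 20), dtype=torch.int8)
row_stats = torch.randn(10, dtype=torch.float32)
col_stats = torch.randn(30, dtype=torch.float32)

out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, dtype=torch.float16)
print(out.shape, out.dtype)  # torch.Size([10, 30]) torch.float16
```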
bitsandbytes/autograd/_functions.py

Lines changed: 2 additions & 2 deletions

@@ -355,7 +355,7 @@ def forward(
         subA = None

         # 3. Int8 Matmul + Dequant + Bias
-        output = torch.ops.bitsandbytes.int8_linear_dequant(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
+        output = torch.ops.bitsandbytes.int8_scaled_mm(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)

         # 4. Mixed-precision decomposition matmul
         if subA is not None and state.subB is not None:
@@ -405,7 +405,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
         if req_gradB:
             Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))

-            grad_B = torch.ops.bitsandbytes.int8_linear_dequant(
+            grad_B = torch.ops.bitsandbytes.int8_scaled_mm(
                 Cgrad.t().contiguous(),
                 CAt.t(),
                 SCgradt,

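For orientation, the grad_B call above is the int8 path for the weight gradient of a linear layer computed as out = A @ B.t(). A small float sketch of the quantity it approximates, with illustrative names that are not taken from the diff:

```python
import torch

# Float reference for the weight gradient that the int8 grad_B path approximates.
A = torch.randn(4, 8)        # activations, shape (batch, in_features)
B = torch.randn(16, 8)       # weight, stored as (out_features, in_features)
out = A @ B.t()              # forward matmul, shape (4, 16)

grad_output = torch.randn_like(out)
grad_B_ref = grad_output.t() @ A   # shape (16, 8), same layout as B
assert grad_B_ref.shape == B.shape
```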
bitsandbytes/backends/cpu/ops.py

Lines changed: 23 additions & 0 deletions

@@ -28,6 +28,29 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[tor
     return result


+@register_kernel("bitsandbytes::int8_mm_dequant", "cpu")
+def _(
+    A: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    dtype=torch.float16,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
+    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
+    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
+
+    A_calc = A.view(-1, A.shape[-1])
+    row_stats = row_stats.reshape(-1).unsqueeze(-1)
+    col_stats = col_stats.reshape(-1).unsqueeze(0)
+
+    out = A_calc * (row_stats * col_stats) * 6.200124e-05
+    if bias is not None:
+        out += bias
+
+    return out.to(dtype)
+
+
 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)

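The new CPU kernel scales the int32 accumulator by row_stats * col_stats times 6.200124e-05, which is roughly 1 / (127 * 127). Assuming row_stats and col_stats are per-row and per-column absmax values from symmetric int8 quantization (an interpretation, not stated in the diff), the end-to-end math looks like this sketch:

```python
import torch

# Rough reference for absmax int8 quantization followed by the dequant in the new kernel.
x = torch.randn(4, 8)   # left operand (e.g. activations)
w = torch.randn(5, 8)   # right operand (rows of w become output columns)

row_absmax = x.abs().amax(dim=1)   # one scale per row of x
col_absmax = w.abs().amax(dim=1)   # one scale per output column
x_q = torch.round(x * 127.0 / row_absmax[:, None]).to(torch.int8)
w_q = torch.round(w * 127.0 / col_absmax[:, None]).to(torch.int8)

# Integer accumulation (the real kernels accumulate in int32), then rescale.
acc = (x_q.to(torch.int32)[:, None, :] * w_q.to(torch.int32)[None, :, :]).sum(-1)
approx = acc.float() * (row_absmax[:, None] * col_absmax[None, :]) / (127.0 * 127.0)

print((approx - x @ w.t()).abs().max())  # small quantization error
```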
bitsandbytes/backends/cuda/ops.py

Lines changed: 1 addition & 3 deletions

@@ -348,7 +348,7 @@ def _(


 @register_kernel("bitsandbytes::dequantize_4bit.out", "cuda")
-def _dequantize_4bit_impl(
+def _(
     A: torch.Tensor,
     absmax: torch.Tensor,
     blocksize: int,
@@ -358,7 +358,6 @@ def _dequantize_4bit_impl(
     out: torch.Tensor,
 ) -> None:
     torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
-    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
     _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)

@@ -430,7 +429,6 @@ def _(
         out.shape == (*A.shape[:-1], shapeB[0]),
         lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
     )
-    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
     _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)

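Renaming the registered wrapper to `_` appears to be more than cosmetic: while the wrapper was itself named `_dequantize_4bit_impl`, it shadowed the module-level helper of the same name that its last line calls, so that call pointed back at the wrapper. A toy illustration of the pitfall (illustrative code, not from the repository):

```python
# Toy illustration of the name-shadowing pitfall addressed by the rename above.
def step(x):
    return x + 1

def step(x):          # redefining the same name shadows the first definition...
    return step(x)    # ...so this now calls itself and would raise RecursionError

# step(1)  # would recurse forever; giving the wrapper a different name avoids the collision
```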
bitsandbytes/functional.py

Lines changed: 14 additions & 5 deletions

@@ -1540,12 +1540,21 @@ def optimizer_update_8bit_blockwise(

     is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])

+    print(
+        f"{p.device} {g.device} {state1.device} {state2.device} {qmap1.device} {qmap2.device} {absmax1.device} {absmax2.device} \n\n"
+        f"{p.dtype} {g.dtype} {state1.dtype} {state2.dtype} {qmap1.dtype} {qmap2.dtype} {absmax1.dtype} {absmax2.dtype} \n\n"
+        f"{p.__class__} {g.__class__} {state1.__class__} {state2.__class__} {qmap1.__class__} {qmap2.__class__} {absmax1.__class__} {absmax2.__class__} \n\n"
+        f"{p.data_ptr()} {g.data_ptr()} {state1.data_ptr()} {state2.data_ptr()} {qmap1.data_ptr()} {qmap2.data_ptr()} {absmax1.data_ptr()} {absmax2.data_ptr()} \n\n"
+    )
+
+    print(p, g, state1, state2)
+
     with _cuda_device_of(g):
         optim_func(
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
+            get_ptr(p.to_local()),
+            get_ptr(g.to_local()),
+            get_ptr(state1.to_local()),
+            get_ptr(state2.to_local()),
             ct.c_float(beta1),
             ct.c_float(beta2),
             ct.c_float(beta3),
@@ -1560,7 +1569,7 @@ def optimizer_update_8bit_blockwise(
             ct.c_float(weight_decay),
             ct.c_float(gnorm_scale),
             ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
+            ct.c_int32(g.to_local().numel()),
         )

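The .to_local() calls above suggest these tensors can arrive as torch.distributed.tensor.DTensor instances, where only the local shard exposes a data pointer the C kernel can consume. That reading is an assumption, and the helper below is a hypothetical illustration of the pattern, not code from the commit:

```python
import torch

def _local_view(t: torch.Tensor) -> torch.Tensor:
    """Return the local shard for DTensor inputs, otherwise the tensor unchanged."""
    try:
        from torch.distributed.tensor import DTensor  # available in recent PyTorch versions
    except ImportError:
        return t
    return t.to_local() if isinstance(t, DTensor) else t
```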
tests/test_ops.py

Lines changed: 4 additions & 10 deletions

@@ -66,9 +66,6 @@ def test_int8_vectorwise_quant(self, threshold, device):

     @pytest.mark.parametrize("device", ["cpu", "cuda"])
     def test_int8_mm_dequant(self, device):
-        if device == "cpu":
-            pytest.skip("CPU implementation is not available")
-
         A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
         row_stats = torch.randn(256, dtype=torch.float32, device=device)
         col_stats = torch.randn(256, dtype=torch.float32, device=device)
@@ -83,22 +80,19 @@ def test_int8_mm_dequant(self, device):
     @pytest.mark.parametrize("device", ["cpu", "cuda"])
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
     @pytest.mark.parametrize("has_bias", TRUE_FALSE)
-    def test_int8_linear_dequant(self, device, dtype, has_bias):
-        if device == "cpu":
-            pytest.skip("CPU implementation is not available")
-
+    def test_int8_scaled_mm(self, device, dtype, has_bias):
         A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
         B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
         row_stats = torch.randn(10, dtype=torch.float32, device=device)
-        col_stats = torch.randn(20, dtype=torch.float32, device=device)
+        col_stats = torch.randn(30, dtype=torch.float32, device=device)
         bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None
-        out = torch.ops.bitsandbytes.int8_linear_dequant(A, B, row_stats, col_stats, bias=bias, dtype=dtype)
+        out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, bias=bias, dtype=dtype)

         assert out.shape == (10, 30)
         assert out.dtype == dtype
         assert out.device == A.device

-        torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_dequant, (A, B, row_stats, col_stats, bias, dtype))
+        torch.library.opcheck(torch.ops.bitsandbytes.int8_scaled_mm, (A, B, row_stats, col_stats, bias, dtype))


 class TestInt8BlockwiseQuantOps:

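To exercise just the tests touched here, a pytest keyword filter along these lines should work from the repository root; the -k expression is an assumption about how you want to slice the suite:

```python
import pytest

# Run the renamed int8_scaled_mm test and the now-enabled CPU int8_mm_dequant test.
pytest.main(["tests/test_ops.py", "-k", "int8_scaled_mm or int8_mm_dequant"])
```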