Skip to content

Commit 683f37c

Browse files
authored
Merge pull request #10 from xiaolil1/jiqing
Remove ipex entirely
2 parents aa0cf92 + 005a63c commit 683f37c

File tree

9 files changed

+94
-276
lines changed

9 files changed

+94
-276
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ jobs:
162162
- name: Run tests
163163
run: pytest --durations=100
164164

165-
test-cpu-ipex:
165+
test-cpu-intel:
166166
if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
167167
needs: build-cpu
168168
runs-on: banb-aws-general-8-plus-use1-public-80
@@ -186,7 +186,6 @@ jobs:
186186
- name: Install dependencies
187187
run: |
188188
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu
189-
pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
190189
pip install -e ".[test]"
191190
pip install pytest-cov
192191
@@ -196,9 +195,6 @@ jobs:
196195
- name: Show environment information
197196
run: python -m torch.utils.collect_env
198197

199-
- name: IPEX smoke test
200-
run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);"
201-
202198
- name: Run tests
203199
run: pytest --durations=100
204200

@@ -286,15 +282,6 @@ jobs:
286282
fail-fast: false
287283
matrix:
288284
torch_version: ["2.7.1"] #["2.6.0", "2.7.1"]
289-
ipex: [false]
290-
# ipex: [true, false]
291-
# include:
292-
# - torch_version: "2.6.0"
293-
# ipex: true
294-
# ipex_version: "2.6.10+xpu"
295-
# - torch_version: "2.7.1"
296-
# ipex: true
297-
# ipex_version: "2.7.10+xpu"
298285
runs-on:
299286
group: bandb-itac-bmsprpvc1550-8-1gpu
300287
env:
@@ -330,10 +317,6 @@ jobs:
330317
- name: Install PyTorch
331318
run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu
332319

333-
- name: Install IPEX
334-
if: matrix.ipex == true
335-
run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
336-
337320
- name: Install dependencies
338321
run: |
339322
pip install -e ".[test]"

bitsandbytes/_ops.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
import torch
66

7-
from .utils import ipex_cpu
8-
97
_IS_TORCH_GTE_24 = False
108

119
if hasattr(torch.library, "register_fake"):
@@ -329,22 +327,3 @@ def _(
329327
)
330328
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
331329
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
332-
333-
334-
if ipex_cpu:
335-
# Register the dequantize_nf4_ipex implementation
336-
torch.library.define(
337-
"bitsandbytes::dequantize_nf4_ipex",
338-
"(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
339-
)
340-
341-
@register_fake("bitsandbytes::dequantize_nf4_ipex")
342-
def _(
343-
A: torch.Tensor,
344-
absmax: torch.Tensor,
345-
blocksize: int,
346-
shape: Sequence[int],
347-
dtype: torch.dtype,
348-
) -> torch.Tensor:
349-
torch._check_is_size(blocksize)
350-
return torch.empty(shape, dtype=dtype, device=A.device)

bitsandbytes/autograd/_functions.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -422,9 +422,9 @@ def matmul(
422422
if threshold > 0.0:
423423
state.threshold = threshold
424424
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
425-
if state.is_training:
426-
if A.device.type in ("cpu", "xpu"):
427-
return MatMul8bitFp.apply(A, B, out, bias, state)
425+
if state.is_training and A.device.type in ("cpu", "xpu"):
426+
return MatMul8bitFp.apply(A, B, out, bias, state)
427+
428428
return MatMul8bitLt.apply(A, B, out, bias, state)
429429

430430

@@ -437,16 +437,6 @@ def matmul_4bit(
437437
):
438438
assert quant_state is not None
439439

440-
if A.device.type == "cpu" and A.requires_grad == False:
441-
if getattr(quant_state, "ipex", False):
442-
# IPEX CPU will change weight to 4D so don't need transpose
443-
B = B.t() if B.dim() == 2 else B
444-
out = F.gemv_4bit(A, B, out, state=quant_state)
445-
if bias is not None:
446-
out += bias
447-
return out
448-
else:
449-
return MatMul4Bit.apply(A, B, out, bias, quant_state)
450440
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
451441
if A.shape[-1] % quant_state.blocksize != 0:
452442
warn(

bitsandbytes/backends/cpu/ops.py

Lines changed: 79 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
from collections.abc import Sequence
21
import ctypes as ct
2+
import logging
33

44
import torch
55

66
from bitsandbytes.functional import get_ptr
77

88
from ..._ops import register_kernel
9-
from ...cextension import lib
10-
from ...utils import ipex_cpu
9+
from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
10+
11+
logger = logging.getLogger(__name__)
1112

1213
# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
1314
# However, we can overflow if we use this without AVX512_VNNI support.
@@ -24,97 +25,80 @@ def _(A: torch.Tensor, B: torch.Tensor):
2425
).reshape(*A.shape[:-1], B.shape[0])
2526

2627

27-
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
28-
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
29-
torch._check_is_size(blocksize)
30-
31-
n = A.numel()
32-
33-
# Only FP32 has c++ kernrl
34-
if A.dtype == torch.float32:
35-
blocks = -(n // -blocksize)
36-
37-
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
38-
out = torch.empty_like(A, dtype=torch.uint8)
39-
40-
lib.cquantize_blockwise_cpu_fp32(
41-
get_ptr(code),
42-
get_ptr(A),
43-
get_ptr(absmax),
44-
get_ptr(out),
45-
ct.c_longlong(blocksize),
46-
ct.c_longlong(n),
47-
)
48-
else:
49-
rem = n % blocksize
50-
has_rem = rem > 0
51-
blocks = n // blocksize + has_rem
52-
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
53-
A_reshaped = A.reshape(n)
54-
A_com = A_reshaped[: n - rem]
55-
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
56-
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
57-
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
58-
scaled_A = scaled_A.reshape(-1)
59-
if has_rem:
60-
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
61-
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
62-
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
63-
64-
diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
65-
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
66-
67-
return out, absmax
68-
69-
70-
@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
71-
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
72-
torch._check_is_size(blocksize)
73-
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
74-
75-
# Only FP32 has c++ kernrl
76-
if dtype == torch.float32:
77-
out = torch.empty_like(A, dtype=dtype)
78-
79-
lib.cdequantize_blockwise_cpu_fp32(
80-
get_ptr(code),
81-
get_ptr(A),
82-
get_ptr(absmax),
83-
get_ptr(out),
84-
ct.c_longlong(blocksize),
85-
ct.c_longlong(A.numel()),
86-
)
87-
else:
88-
out = code[A.reshape(-1).int()]
89-
blocks = out.shape[-1] // blocksize
90-
res = out.shape[-1] % blocksize
91-
if res != 0:
92-
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
93-
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
94-
out = out[: blocks * blocksize + res]
95-
out = out.reshape(A.shape)
96-
97-
return out
98-
99-
100-
if ipex_cpu:
101-
from bitsandbytes.utils import _reverse_4bit_compress_format
102-
103-
@register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
28+
if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
29+
logger.info("Loading C++ bitsandbytes kernels for CPU")
30+
31+
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
32+
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
33+
torch._check_is_size(blocksize)
34+
35+
n = A.numel()
36+
37+
# Only FP32 has c++ kernrl
38+
if A.dtype == torch.float32:
39+
blocks = -(n // -blocksize)
40+
41+
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
42+
out = torch.empty_like(A, dtype=torch.uint8)
43+
44+
lib.cquantize_blockwise_cpu_fp32(
45+
get_ptr(code),
46+
get_ptr(A),
47+
get_ptr(absmax),
48+
get_ptr(out),
49+
ct.c_longlong(blocksize),
50+
ct.c_longlong(n),
51+
)
52+
else:
53+
rem = n % blocksize
54+
has_rem = rem > 0
55+
blocks = n // blocksize + has_rem
56+
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
57+
A_reshaped = A.reshape(n)
58+
A_com = A_reshaped[: n - rem]
59+
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
60+
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
61+
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
62+
scaled_A = scaled_A.reshape(-1)
63+
if has_rem:
64+
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
65+
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
66+
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
67+
68+
diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
69+
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
70+
71+
return out, absmax
72+
73+
@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
10474
def _(
105-
A: torch.Tensor,
106-
absmax: torch.Tensor,
107-
blocksize: int,
108-
shape: Sequence[int],
109-
dtype: torch.dtype,
75+
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
11076
) -> torch.Tensor:
111-
ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
112-
A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
113-
return torch.ops.bitsandbytes.dequantize_4bit.default(
114-
A,
115-
absmax,
116-
blocksize,
117-
"nf4",
118-
shape,
119-
dtype,
120-
)
77+
torch._check_is_size(blocksize)
78+
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
79+
80+
# Only FP32 has c++ kernrl
81+
if dtype == torch.float32:
82+
out = torch.empty_like(A, dtype=dtype)
83+
84+
lib.cdequantize_blockwise_cpu_fp32(
85+
get_ptr(code),
86+
get_ptr(A),
87+
get_ptr(absmax),
88+
get_ptr(out),
89+
ct.c_longlong(blocksize),
90+
ct.c_longlong(A.numel()),
91+
)
92+
else:
93+
out = code[A.reshape(-1).int()]
94+
blocks = out.shape[-1] // blocksize
95+
res = out.shape[-1] % blocksize
96+
if res != 0:
97+
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
98+
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
99+
out = out[: blocks * blocksize + res]
100+
out = out.reshape(A.shape)
101+
102+
return out
103+
else:
104+
logger.warning("Loading pytorch bitsandbytes kernels for CPU because no native library found.")

bitsandbytes/functional.py

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch import Tensor
1414
from typing_extensions import deprecated
1515

16-
from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict
16+
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
1717

1818
from .cextension import HIP_ENVIRONMENT, lib
1919

@@ -1055,16 +1055,6 @@ def dequantize_4bit(
10551055
if absmax.dtype != torch.float32:
10561056
absmax = absmax.float()
10571057

1058-
# IPEX format is different, we need extra process.
1059-
if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
1060-
return torch.ops.bitsandbytes.dequantize_nf4_ipex(
1061-
A,
1062-
absmax,
1063-
quant_state.blocksize,
1064-
quant_state.shape,
1065-
quant_state.dtype,
1066-
)
1067-
10681058
if out is not None:
10691059
torch.ops.bitsandbytes.dequantize_4bit.out(
10701060
A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
@@ -1633,25 +1623,6 @@ def gemv_4bit(
16331623
if state.nested:
16341624
absmax = dequantize_blockwise(absmax, state.state2) + state.offset
16351625

1636-
if getattr(state, "ipex", False) and state.quant_type == "nf4":
1637-
# compute_dtype: 1 indicates fp16, 2 indicates bf16
1638-
compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
1639-
out = torch.ops.torch_ipex.woq_linear(
1640-
A,
1641-
B,
1642-
"nf4",
1643-
state.shape,
1644-
state.new_scales,
1645-
state.new_zeros,
1646-
None,
1647-
None,
1648-
state.blocksize,
1649-
compute_dtype,
1650-
1,
1651-
state.compensation,
1652-
)
1653-
return out
1654-
16551626
if out is not None:
16561627
torch.ops.bitsandbytes.gemv_4bit.out(
16571628
A,
@@ -2338,37 +2309,3 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
23382309

23392310

23402311
C = 127.0
2341-
2342-
2343-
def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
2344-
quant_state = linear.weight.quant_state
2345-
2346-
if quant_state.nested:
2347-
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
2348-
absmax += quant_state.offset
2349-
if absmax.dtype != torch.float32:
2350-
absmax = absmax.float()
2351-
2352-
quant_state.absmax = absmax
2353-
quant_state.nested = False
2354-
delattr(quant_state, "state2")
2355-
2356-
assert x.device.type == "cpu"
2357-
converted_weight = _reverse_4bit_compress_format(linear.weight.data)
2358-
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
2359-
converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
2360-
"nf4",
2361-
quant_state.shape, # weight shape
2362-
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
2363-
None, # zero_points
2364-
None, # bias
2365-
None, # batch_size
2366-
quant_state.blocksize,
2367-
2,
2368-
)
2369-
2370-
linear.weight.data = new_weight.data
2371-
linear.weight.quant_state.ipex = True
2372-
linear.weight.quant_state.new_scales = new_scales
2373-
linear.weight.quant_state.new_zeros = new_zeros
2374-
linear.weight.quant_state.compensation = compensation

0 commit comments

Comments
 (0)