Commit 8c0271d

Merge branch 'multi-backend-refactor' into xpu
2 parents: b0982fe + d3658c5

File tree

8 files changed: +134 additions, -44 deletions
benchmarking/generation_benchmark.py (new file)

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
import argparse

import torch
import torch.utils.benchmark as benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

parser = argparse.ArgumentParser()

parser.add_argument(
    "--model_name", default="meta-llama/Llama-3.1-8B-Instruct", required=False, type=str, help="model_name"
)
parser.add_argument("--quant_type", default="int8", type=str, help="quant type", choices=["int8", "nf4", "fp4"])
parser.add_argument("--device_map", default="cpu", type=str, help="device_map", choices=["cpu", "xpu", "cuda"])
args = parser.parse_args()

model_name = args.model_name
device_map = args.device_map
if args.quant_type == "int8":
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
else:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=args.quant_type,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map=device_map, quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

output = quantized_model.generate(**input_ids, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))


# benchmark the performance
def benchmark_fn(f, *args, **kwargs):
    # Manual warmup
    for _ in range(2):
        f(*args, **kwargs)

    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return t0.blocked_autorange().mean


MAX_NEW_TOKENS = 100

quantized_model_latency = benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)

bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, torch_dtype=torch.bfloat16)
bf16_model_latency = benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS)

print(f"bnb model latency: {quantized_model_latency:.3f}")
print(f"bf16 model latency: {bf16_model_latency:.3f}")
print(f"BNB vs. bf16 model speed-up: {(bf16_model_latency / quantized_model_latency):.3f}")

print(f"BNB model memory: {(quantized_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB")
print(f"bf16 model memory: {(bf16_model.get_memory_footprint() / 1024 / 1024 / 1024):.3f} GB")
print(
    f"BNB vs. bf16 model memory ratio: {(bf16_model.get_memory_footprint() / quantized_model.get_memory_footprint()):.3f}"
)
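For CPU benchmarking, the usage note added to docs/source/non_cuda_backends.mdx below recommends binding cores to a single socket, e.g. `numactl -C 0-55 -m 0 python generation_benchmark.py --quant_type nf4`; the `--device_map xpu` and `--device_map cuda` options select the other supported backends.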

bitsandbytes/backends/cpu.py

Lines changed: 1 addition & 2 deletions
@@ -140,8 +140,7 @@ def quantize_4bit(
         if blocksize is None:
             blocksize = 64
         assert_on_cpu([A, absmax, out])
-        assert quant_storage == torch.uint8, "CPU backend only supports uint8 quant_storage"
-        return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
+        return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type, quant_storage)

     def dequantize_4bit(
         self,

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 19 additions & 7 deletions
@@ -194,8 +194,10 @@ def int8_linear_matmul_impl(

     A_reshaped = A.reshape(m, k)

-    # torch._int_mm is available on CPU since torch 2.4
-    if _torch_version_prereq(2, 4) and A.device.type == "cpu":
+    # torch._int_mm is available on CPU since torch 2.4, XPU since torch 2.6
+    if (A.device.type == "cpu" and _torch_version_prereq(2, 4)) or (
+        A.device.type == "xpu" and _torch_version_prereq(2, 6)
+    ):
         C = torch._int_mm(A_reshaped, B.T).to(dtype)
     else:
         C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype)

@@ -296,6 +298,7 @@ def quantize_4bit_impl(
     blocksize=64,
     compress_statistics=False,
     quant_type="nf4",
+    quant_storage=torch.uint8,
 ) -> Tensor:
     """
     Quantize tensor A in blocks of 4-bit values.

@@ -314,6 +317,8 @@ def quantize_4bit_impl(
         The blocksize used in quantization.
     quant_type : str
         The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now
+    quant_storage: torch.dtype
+        We can use bytes to convert storage type.

     Returns
     -------

@@ -401,6 +406,10 @@ def quantize_4bit_impl(
         quant_type=quant_type,
     )

+    if quant_storage != torch.uint8:
+        bytes_value = out.cpu().numpy().tobytes()
+        out = torch.frombuffer(bytes_value, dtype=quant_storage).to(A.device)
+
     return out.reshape(-1, 1), state

@@ -418,7 +427,8 @@ def dequant_8bit(A, offset, quant_state):
     return absmax


-@_maybe_torch_compile
+# Compile will fail in torch.frombuffer
+# @_maybe_torch_compile
 def dequantize_4bit_impl(
     A: Tensor,
     quant_state=None,

@@ -428,8 +438,7 @@ def dequantize_4bit_impl(
     quant_type="nf4",
 ) -> Tensor:
     """
-    Dequantizes FP4 blockwise quantized values.
-
+    Dequantizes 4-bit blockwise quantized values.
     Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.

     Parameters

@@ -445,8 +454,7 @@ def dequantize_4bit_impl(
     blocksize : int
         The blocksize used in quantization.
     quant_type : str
-        The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now
-
+        The 4-bit quantization data type {fp4, nf4}

     Returns
     -------

@@ -455,6 +463,10 @@ def dequantize_4bit_impl(
     """
     transpose = True if A.shape[0] == 1 else False
     A = A.reshape(-1)
+    device = A.device
+    if A.dtype != torch.uint8:
+        bytes_value = A.cpu().numpy().tobytes()
+        A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device)

     if quant_state is None:
         assert absmax is not None and out is not None
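The new quant_storage handling in quantize_4bit_impl and dequantize_4bit_impl does not requantize anything; it only reinterprets the packed uint8 buffer under a different storage dtype via torch.frombuffer. A minimal self-contained sketch of that round trip (illustrative only, not bitsandbytes code; float16 stands in for any non-uint8 quant_storage):

    import torch

    # Packed 4-bit payload as produced by quantize_4bit_impl: each uint8 element holds two nibbles.
    packed = torch.arange(16, dtype=torch.uint8)

    # Reinterpret the raw bytes under a wider storage dtype, as quantize_4bit_impl now does when
    # quant_storage != torch.uint8 (the byte count must be divisible by the target element size).
    as_fp16 = torch.frombuffer(bytearray(packed.numpy().tobytes()), dtype=torch.float16)

    # Undo the reinterpretation before dequantizing, mirroring the new check in dequantize_4bit_impl.
    back = torch.frombuffer(bytearray(as_fp16.numpy().tobytes()), dtype=torch.uint8)
    assert torch.equal(back, packed)  # same bytes, only the storage dtype changed

This byte-level detour is also why the @_maybe_torch_compile decorator is disabled above: per the new comment in the diff, compilation fails on torch.frombuffer.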

bitsandbytes/backends/xpu.py

Lines changed: 1 addition & 2 deletions
@@ -158,8 +158,7 @@ def quantize_4bit(
         if blocksize is None:
             blocksize = 64
         assert_on_xpu([A, absmax, out])
-        assert quant_storage == torch.uint8, "XPU backend only supports uint8 quant_storage"
-        output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
+        output = quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type, quant_storage)
         return output

     def dequantize_4bit(

bitsandbytes/functional.py

Lines changed: 9 additions & 9 deletions
@@ -1076,7 +1076,7 @@ def dequantize_fp4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")

@@ -1086,7 +1086,7 @@ def dequantize_nf4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")

@@ -1096,8 +1096,8 @@ def dequantize_4bit(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
-    quant_type="fp4",
+    blocksize: Optional[int] = None,
+    quant_type: Optional[str] = "fp4",
 ) -> torch.Tensor:
     """Dequantizes a packed 4-bit quantized tensor.

@@ -1115,9 +1115,9 @@ def dequantize_4bit(
             Required if `quant_state` is not provided and ignored otherwise.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 64.
+            The size of the blocks. Defaults to 64 if not HIP_ENVIRONMENT else 128.
             Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
-        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
+        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to "fp4".

     Raises:
         ValueError: Raised when the input data type or blocksize is not supported.

@@ -1127,9 +1127,9 @@ def dequantize_4bit(
     """
     ensure_backend_is_available(A.device.type)
     if quant_state is not None:
-        absmax = absmax or quant_state.absmax
-        quant_type = quant_type or quant_state.quant_type
-        blocksize = blocksize or quant_state.blocksize
+        absmax = quant_state.absmax
+        quant_type = quant_state.quant_type
+        blocksize = quant_state.blocksize
     if blocksize is None:
         # Some AMD GPUs have warpsize 64
         # Set default blocksize to 128 (~warpsize 64 in kernel) for HIP
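The functional.py change makes blocksize default to None so that a stored QuantState, rather than the signature default, decides the effective block size. A hypothetical standalone helper (not part of the bitsandbytes API) restating that resolution order:

    from typing import Optional

    HIP_ENVIRONMENT = False  # assumption: True on ROCm/HIP builds, where the default blocksize is 128

    def resolve_blocksize(blocksize: Optional[int], quant_state_blocksize: Optional[int]) -> int:
        # Mirrors the updated dequantize_4bit: values carried by quant_state always win,
        # and only a still-unset blocksize falls back to the platform default.
        if quant_state_blocksize is not None:
            return quant_state_blocksize
        if blocksize is not None:
            return blocksize
        return 128 if HIP_ENVIRONMENT else 64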

bitsandbytes/nn/modules.py

Lines changed: 16 additions & 18 deletions
@@ -487,6 +487,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
             self.weight.data = reverse_4bit_compress_format(self.weight.data.reshape(1, -1))

             self.weight.quant_state.ipex = False
+            self.ipex_linear_is_set = False

         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias

@@ -496,14 +497,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):

     def set_ipex_linear(self, x: torch.Tensor):
         if (
-            (x.device.type in ("cpu", "xpu"))
-            and not getattr(self.weight.quant_state, "ipex", False)
+            not getattr(self.weight.quant_state, "ipex", False)
+            and self.weight.data.dtype == torch.uint8
             and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
             and self.weight.quant_state.quant_type == "nf4"
-            and not self.training
-            and x.requires_grad == False
         ):
-            enable_ipex_fusion(self, x)
+            if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False):
+                enable_ipex_fusion(self, x)

     def forward(self, x: torch.Tensor):
         # Check if ipex fusion can be used

@@ -700,26 +700,24 @@ def to(self, *args, **kwargs):
         elif device.type == "cpu":
             if self.data.dtype == torch.int8:
                 self.CB = self.data
-                return self
             else:
                 return self.cpu()
         elif device.type == "xpu":
             if self.data.dtype == torch.int8:
-                self.data = self.data.contiguous().xpu(device)
+                self.data = self.data.contiguous()
                 self.CB = self.data
-                return self
-            else:
+            if self.data.device.type == "cpu":
                 return self.xpu(device)
-        else:
-            new_param = Int8Params(
-                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
-                requires_grad=self.requires_grad,
-                has_fp16_weights=self.has_fp16_weights,
-            )
-            new_param.CB = self.CB
-            new_param.SCB = self.SCB

-            return new_param
+        new_param = Int8Params(
+            super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+            requires_grad=self.requires_grad,
+            has_fp16_weights=self.has_fp16_weights,
+        )
+        new_param.CB = self.CB
+        new_param.SCB = self.SCB
+
+        return new_param


 def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
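The rewritten set_ipex_linear splits the gate in two: weight/layout eligibility first, then a device-specific check that lets XPU fuse even during training while CPU still requires inference mode. A hypothetical predicate (an illustrative restatement, not the module's API) spelling out that logic:

    import torch

    def should_enable_ipex_fusion(quant_state, weight_dtype, device_type, training, requires_grad):
        # Weight/layout eligibility, as in the rewritten set_ipex_linear gate.
        eligible = (
            not getattr(quant_state, "ipex", False)
            and weight_dtype == torch.uint8
            and quant_state.shape[1] % quant_state.blocksize == 0
            and quant_state.quant_type == "nf4"
        )
        if not eligible:
            return False
        # XPU may fuse unconditionally; CPU only outside training and when the input needs no grad.
        return device_type == "xpu" or (device_type == "cpu" and not training and not requires_grad)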

docs/source/non_cuda_backends.mdx

Lines changed: 12 additions & 5 deletions
@@ -27,18 +27,25 @@ Thank you for your support!

 ### Intel

-The following performance data is collected from Intel 4th Gen Xeon (SPR) platform. The tables show speed-up and memory compared with different data types of [Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf).
+The below performance data is collected from the Intel 4th Gen Xeon (SPR) platform. The tables show speed-up and memory compared with different data types of [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct).
+
+You may run `benchmarking/generation_benchmark.py` to reproduce the below model memory and inference results. Please note that you need to bind cores if you are using the CPU to benchmark. For example, run `numactl -C 0-55 -m 0 python generation_benchmark.py --quant_type nf4` on Intel 4th Gen Xeon with single socket.
+
+The finetune results are selected from [peft](https://github.com/huggingface/peft/blob/main/examples/olora_finetuning/olora_finetuning.py).
+
+#### Model memory (CPU)
+| Data Type | BF16 | INT8 | NF4 | FP4 |
+|---|---|---|---|---|
+| Memory (GB) | 15.0 | 8.5 | 5.2 | 5.2 |

 #### Inference (CPU)

 | Data Type | BF16 | INT8 | NF4 | FP4 |
 |---|---|---|---|---|
-| Speed-Up (vs BF16) | 1.0x | 0.44x | 1.8x | 0.1x |
-| Memory (GB) | 13.1 | 7.6 | 5.0 | 4.6 |
+| Speed-Up (vs BF16) | 1.0x | 0.57x | 2.6x | 0.1x |

 #### Fine-Tuning (CPU)

 | Data Type | BF16 | INT8 | NF4 | FP4 |
 |---|---|---|---|---|
-| Speed-Up (vs BF16) | 1.0x | 0.38x | 0.1x | 0.1x |
-| Memory (GB) | 40 | 9 | 6.6 | 6.6 |
+| Speed-Up (vs BF16) | 1.0x | 0.91x | 1.0x | 1.0x |

setup.py

Lines changed: 9 additions & 1 deletion
@@ -1,11 +1,19 @@
 from setuptools import find_packages, setup
 from setuptools.dist import Distribution

+VERSION = "1.0.0"
+

 # Tested with wheel v0.45.1
 class BinaryDistribution(Distribution):
     def has_ext_modules(self):
         return True


-setup(packages=find_packages(), distclass=BinaryDistribution)
+def write_version_file(version, filepath="bitsandbytes/_version.py"):
+    with open(filepath, "w") as f:
+        f.write(f'__version__ = "{version}"\n')
+    return version
+
+
+setup(packages=find_packages(), distclass=BinaryDistribution, version=write_version_file(VERSION))
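With this change, setup() writes bitsandbytes/_version.py at build time and reuses the same string as the wheel version. A sketch of how the generated module could be consumed afterwards (assumes the package built by this setup.py is installed; whether bitsandbytes/__init__.py re-exports __version__ is not shown in this diff):

    # Reads the file produced by write_version_file("1.0.0", "bitsandbytes/_version.py") above.
    from bitsandbytes._version import __version__

    print(__version__)  # "1.0.0" for this commit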
