Skip to content

Commit 314f724

Browse files
committed
fix format
1 parent 1e27a22 commit 314f724

File tree

6 files changed

+49
-38
lines changed

6 files changed

+49
-38
lines changed

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ def mm_dequant_impl(
234234
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
235235

236236
if compute_dtype not in [torch.float32, torch.bfloat16]:
237-
warnings.warn(f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead")
237+
warnings.warn(
238+
f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead"
239+
)
238240
compute_dtype = torch.bfloat16
239241
A_reshaped = A.reshape(out_shape).to(compute_dtype)
240242
row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype)
@@ -439,9 +441,7 @@ def dequantize_4bit_impl(
439441
raise NotImplementedError("bnb_4bit_use_double_quant is not supported yet for CPU/XPU")
440442

441443
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False):
442-
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(
443-
A, "nf4", quant_state.shape, 2
444-
)
444+
A = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
445445
quant_state.ipex = False
446446

447447
# Map nf4 to [-1, 1]
@@ -466,9 +466,9 @@ def dequantize_4bit_impl(
466466
if out is None:
467467
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
468468
out_reshaped = out.reshape(-1)
469-
out_reshaped[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(
470-
-1
471-
)
469+
out_reshaped[: n - rem] = (
470+
out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)
471+
).reshape(-1)
472472
out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
473473
else:
474474
out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype)
@@ -513,9 +513,20 @@ def gemm_4bit_impl(
513513
GEMM output tensor.
514514
"""
515515
if getattr(state, "ipex", False):
516-
output = torch.ops.torch_ipex.woq_linear(A, B, "nf4", state.shape,
517-
state.new_scales, state.new_zeros, None, None, state.blocksize,
518-
ipex_cpu.quantization.WoqLowpMode.BF16, 1, state.compensation)
516+
output = torch.ops.torch_ipex.woq_linear(
517+
A,
518+
B,
519+
"nf4",
520+
state.shape,
521+
state.new_scales,
522+
state.new_zeros,
523+
None,
524+
None,
525+
state.blocksize,
526+
ipex_cpu.quantization.WoqLowpMode.BF16,
527+
1,
528+
state.compensation,
529+
)
519530
else:
520531
dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
521532
output = torch.matmul(A, dqB.to(A.dtype))

bitsandbytes/backends/xpu.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
)
1616

1717
Tensor = torch.Tensor
18+
19+
1820
def assert_on_xpu(tensors):
1921
on_xpu = True
2022
for t in tensors:
@@ -124,7 +126,6 @@ def extract_outliers(
124126
output = A[:, idx].contiguous()
125127
return output
126128

127-
128129
def quantize_4bit(
129130
self,
130131
A: torch.Tensor,
@@ -155,7 +156,7 @@ def dequantize_4bit(
155156
blocksize = 64
156157
assert_on_xpu([A, absmax, out])
157158
if quant_type == "nf4":
158-
output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None,blocksize).t()
159+
output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
159160
else:
160161
output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)
161162

bitsandbytes/functional.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,7 +1006,6 @@ def dequantize_fp4(
10061006
out: Optional[torch.Tensor] = None,
10071007
blocksize: Optional[int] = None,
10081008
) -> Tensor:
1009-
10101009
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
10111010

10121011

@@ -1017,7 +1016,6 @@ def dequantize_nf4(
10171016
out: Optional[torch.Tensor] = None,
10181017
blocksize: Optional[int] = None,
10191018
) -> Tensor:
1020-
10211019
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
10221020

10231021

bitsandbytes/nn/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from ..backends import backends
56
from .modules import (
67
Embedding,
78
Int8Params,
@@ -14,7 +15,7 @@
1415
StableEmbedding,
1516
SwitchBackLinearBnb,
1617
)
17-
from ..backends import backends
18+
1819
# CPU and XPU backend do not need triton, and XPU so not support triton for now.
1920
if "xpu" not in backends.keys() or ("cpu" in backends.keys() and len(backends.keys()) == 1):
2021
from .triton_based_modules import (

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -449,10 +449,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
449449
save weight and bias,
450450
then fill state_dict with components of quant_state
451451
"""
452-
if (
453-
getattr(self.weight, "quant_state", None) is not None
454-
and getattr(self.weight.quant_state, "ipex", False)
455-
):
452+
if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False):
456453
if self.weight.device.type == "cpu":
457454
original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
458455
self.weight, "nf4", self.weight.quant_state.shape, 2

bitsandbytes/utils.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -201,35 +201,38 @@ def unpack_tensor_to_dict(tensor_data):
201201

202202

203203
def enable_ipex_fusion(linear):
204-
from bitsandbytes.backends.cpu_xpu_common import _ipex_cpu_version_prereq, _ipex_xpu_version_prereq
205-
from bitsandbytes.backends.cpu_xpu_common import ipex_cpu_only, ipex_xpu
204+
from bitsandbytes.backends.cpu_xpu_common import (
205+
_ipex_cpu_version_prereq,
206+
_ipex_xpu_version_prereq,
207+
ipex_cpu_only,
208+
ipex_xpu,
209+
)
206210

207211
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5):
208212
quant_state = linear.weight.quant_state
209-
new_weight, new_scales, new_zeros, _, compensation = \
210-
torch.ops.ipex_prepack.woq_linear_pack_weight(
211-
linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
212-
"nf4",
213-
quant_state.shape, # weight shape
214-
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
215-
None, # zero_points
216-
None, # bias
217-
None, # batch_size
218-
quant_state.blocksize,
219-
2,
220-
)
213+
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
214+
linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
215+
"nf4",
216+
quant_state.shape, # weight shape
217+
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
218+
None, # zero_points
219+
None, # bias
220+
None, # batch_size
221+
quant_state.blocksize,
222+
2,
223+
)
221224
elif ipex_xpu and _ipex_xpu_version_prereq(2, 5):
222225
quant_state = linear.weight.quant_state
223226
new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
224-
227+
225228
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
226229
new_zeros = None
227230
compensation = None
228231
linear.weight.data = new_weight.data
229-
setattr(linear.weight.quant_state, "ipex", True)
230-
setattr(linear.weight.quant_state, "new_scales", new_scales)
231-
setattr(linear.weight.quant_state, "new_zeros", new_zeros)
232-
setattr(linear.weight.quant_state, "compensation", compensation)
232+
linear.weight.quant_state.ipex = True
233+
linear.weight.quant_state.new_scales = new_scales
234+
linear.weight.quant_state.new_zeros = new_zeros
235+
linear.weight.quant_state.compensation = compensation
233236

234237

235238
class QuantState:

0 commit comments

Comments
 (0)