Skip to content

Commit 7e6f865

Browse files
authored
fix device check (#1453)
Signed-off-by: jiqing-feng <[email protected]>
1 parent 9948333 commit 7e6f865

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def set_ipex_linear(self, x: torch.Tensor):
481481
and not self.training
482482
and x.requires_grad == False
483483
):
484-
enable_ipex_fusion(self)
484+
enable_ipex_fusion(self, x)
485485

486486
def forward(self, x: torch.Tensor):
487487
# Check if ipex fusion can be used

bitsandbytes/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,15 @@ def unpack_tensor_to_dict(tensor_data):
200200
return unpacked_dict
201201

202202

203-
def enable_ipex_fusion(linear):
203+
def enable_ipex_fusion(linear, x):
204204
from bitsandbytes.backends.cpu_xpu_common import (
205205
_ipex_cpu_version_prereq,
206206
_ipex_xpu_version_prereq,
207-
ipex_cpu_only,
207+
ipex_cpu,
208208
ipex_xpu,
209209
)
210210

211-
if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5):
211+
if x.device.type == "cpu" and ipex_cpu and _ipex_cpu_version_prereq(2, 5):
212212
quant_state = linear.weight.quant_state
213213
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
214214
linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
@@ -221,7 +221,7 @@ def enable_ipex_fusion(linear):
221221
quant_state.blocksize,
222222
2,
223223
)
224-
elif ipex_xpu and _ipex_xpu_version_prereq(2, 5):
224+
elif x.device.type == "xpu" and ipex_xpu and _ipex_xpu_version_prereq(2, 5):
225225
quant_state = linear.weight.quant_state
226226
new_weight = linear.weight.data.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
227227

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def get_latest_semver_tag():
2727
tags = subprocess.check_output(["git", "tag"], text=True).splitlines()
2828
semver_tags = [tag for tag in tags if tag.count(".") == 2 and all(part.isdigit() for part in tag.split("."))]
2929
if not semver_tags:
30-
print("No valid semantic version tags found, use 0.0.1 defaultly")
31-
semver_tags = ["0.0.1"]
30+
print("No valid semantic version tags found, use 1.0.0 defaultly")
31+
semver_tags = ["1.0.0"]
3232
return sorted(semver_tags, key=lambda s: list(map(int, s.split("."))))[-1]
3333

3434

0 commit comments

Comments
 (0)