Skip to content

Commit 622c811

Browse files
committed
Added checks for Lazy mode
1 parent 6f406dd commit 622c811

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import subprocess
22
from typing import Optional
33
import warnings
4+
import os
45

56
import torch
67

@@ -55,7 +56,7 @@ def _ipex_xpu_version_prereq(major, minor):
5556

5657
def _maybe_torch_compile(func):
5758
# torch.compile requires g++ and pytorch >= 2.0
58-
if gxx_available and _torch_version_prereq(2, 0):
59+
if gxx_available and _torch_version_prereq(2, 0) and os.getenv('PT_HPU_LAZY_MODE',1)==0:
5960
options = {}
6061
# fx_graph_cache requires pytorch >= 2.2
6162
if _torch_version_prereq(2, 2):
@@ -277,7 +278,7 @@ def mm_dequant_impl(
277278
}
278279

279280

280-
# @_maybe_torch_compile
281+
@_maybe_torch_compile
281282
def quantize_4bit_impl(
282283
A: Tensor,
283284
absmax: Tensor = None,
@@ -342,7 +343,7 @@ def quantize_4bit_impl(
342343
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
343344
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
344345
# map [-1, 1] to nf4/fp4
345-
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=scaled_A.device)
346+
out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device)
346347
if quant_type == "nf4":
347348
for i in range(len(NF4_QUANT_TABLE)):
348349
out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i
@@ -372,8 +373,7 @@ def quantize_4bit_impl(
372373

373374
return out.unsqueeze(0), state
374375

375-
376-
#@_maybe_torch_compile
376+
@_maybe_torch_compile
377377
def dequantize_4bit_impl(
378378
A: Tensor,
379379
quant_state=None,

0 commit comments

Comments
 (0)