Skip to content

Commit e30021e

Browse files
committed
fix MTE out of range error
Signed-off-by: SlightwindSec <[email protected]>
1 parent 607ecf9 commit e30021e

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

CMakeLists.txt

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -308,11 +308,7 @@ endif()
308308

309309
set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
310310
add_library(bitsandbytes SHARED ${SRC_FILES})
311-
if(BUILD_NPU)
312-
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
313-
else()
314-
target_compile_features(bitsandbytes PUBLIC cxx_std_14)
315-
endif()
311+
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
316312
target_include_directories(bitsandbytes PUBLIC csrc include)
317313

318314

bitsandbytes/backends/npu/ops.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -99,21 +99,25 @@ def _dequantize_4bit_impl(
9999
dtype in [torch.bfloat16, torch.float16, torch.float32],
100100
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
101101
)
102+
if out.dtype == torch.bfloat16:
103+
# bf16: bf16 -> fp32 -> op -> fp32 -> bf16
104+
absmax = absmax.to(torch.float32)
105+
out_fp32 = torch.empty(out.shape, dtype=torch.float32, device=out.device)
106+
else:
107+
out_fp32 = out
108+
102109
args = (
103110
get_ptr(A),
104111
get_ptr(absmax),
105-
get_ptr(out),
112+
get_ptr(out_fp32),
106113
ct.c_int(blocksize),
107114
ct.c_int(out.numel()),
108115
torch.npu.current_stream(),
109116
)
110117

111118
if out.dtype == torch.bfloat16:
112-
# bf16: bf16 -> fp32 -> op -> fp32 -> bf16
113-
absmax = absmax.to(torch.float32)
114-
out = out.to(torch.float32)
115119
lib.cdequantize_blockwise_fp32_nf4(*args)
116-
out = out.to(torch.bfloat16)
120+
out.copy_(out_fp32.to(torch.bfloat16))
117121
elif out.dtype == torch.float16:
118122
lib.cdequantize_blockwise_fp16_nf4(*args)
119123
elif out.dtype == torch.float32:

0 commit comments

Comments
 (0)