Skip to content

Commit b6e447b

Browse files
ji-huazhongSlightwindSecGinray
authored andcommitted
Add Ascend NPU support for nf4 quant (bitsandbytes-foundation#1422)
* Add npu support for nf4 quant Co-authored-by: Slightwind <[email protected]> Co-authored-by: Ginray <[email protected]> * code format * update * pass lint check and fix typos * add npu to supported devices --------- Co-authored-by: Slightwind <[email protected]> Co-authored-by: Ginray <[email protected]>
1 parent 03bdf88 commit b6e447b

File tree

14 files changed

+581
-29
lines changed

14 files changed

+581
-29
lines changed

CMakeLists.txt

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# For GCC: `cmake -B build . && cmake --build build`
44
# For MSVC: `cmake -B build . && cmake --build build --config Release`
55
# You can also use the following options and variables
6-
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip` or `mps` to select the backend
6+
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, `hip`, `mps` or `npu` to select the backend
77
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
88
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
99
# is whatever CMake finds on your path.
@@ -29,11 +29,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
2929
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
3030
set(MPS_FILES csrc/mps_ops.mm)
3131
set(METAL_FILES csrc/mps_kernels.metal)
32+
set(NPU_FILES csrc/npu_ops.cpp)
3233
# C++ sources are always included
3334
list(APPEND SRC_FILES ${CPP_FILES})
3435

35-
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
36-
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
36+
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, npu)")
37+
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps npu)
3738
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)
3839

3940
if(APPLE)
@@ -69,6 +70,11 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
6970
set(BUILD_CUDA OFF)
7071
set(BUILD_HIP OFF)
7172
set(BUILD_MPS ON)
73+
elseif(${COMPUTE_BACKEND} STREQUAL "npu")
74+
set(BUILD_CUDA OFF)
75+
set(BUILD_HIP OFF)
76+
set(BUILD_MPS OFF)
77+
set(BUILD_NPU ON)
7278
else()
7379
set(BUILD_CUDA OFF)
7480
set(BUILD_HIP OFF)
@@ -232,6 +238,33 @@ elseif(BUILD_MPS)
232238
COMMENT "Compiling Metal kernels"
233239
VERBATIM)
234240
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
241+
elseif(BUILD_NPU)
242+
list(APPEND SRC_FILES ${NPU_FILES})
243+
244+
set(SOC_VERSION "Ascend910B4" CACHE STRING "system on chip type")
245+
set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH} CACHE
246+
STRING "ASCEND CAN package installation directory"
247+
)
248+
249+
# ${KERNEL_FILES} are used to compile library, push files written by ascendc in ${KERNEL_FILES}.
250+
# ref to cmake/npu.cmake ascendc_library, cmake/cpu.cmake add_library
251+
# file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/npu_kernels.cpp)
252+
file(GLOB KERNEL_FILES csrc/npu_kernels.cpp)
253+
254+
if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
255+
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
256+
elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
257+
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
258+
else()
259+
message(FATAL_ERROR "ascendc_kernel_cmake does not exist ,please check whether the can package is installed")
260+
endif()
261+
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
262+
263+
# ascendc_library use to add kernel file to generate ascendc library
264+
ascendc_library(ascendc_kernels_npu STATIC ${KERNEL_FILES})
265+
266+
string(APPEND BNB_OUTPUT_NAME "_npu")
267+
add_compile_definitions(BUILD_NPU)
235268
else()
236269
string(APPEND BNB_OUTPUT_NAME "_cpu")
237270
set(GPU_SOURCES)
@@ -249,7 +282,11 @@ endif()
249282

250283
set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
251284
add_library(bitsandbytes SHARED ${SRC_FILES})
252-
target_compile_features(bitsandbytes PUBLIC cxx_std_14)
285+
if(BUILD_NPU)
286+
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
287+
else()
288+
target_compile_features(bitsandbytes PUBLIC cxx_std_14)
289+
endif()
253290
target_include_directories(bitsandbytes PUBLIC csrc include)
254291

255292

@@ -306,6 +343,10 @@ if(BUILD_MPS)
306343
add_dependencies(bitsandbytes metallib)
307344
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
308345
endif()
346+
if(BUILD_NPU)
347+
target_compile_options(bitsandbytes PRIVATE -O2 -std=c++17)
348+
target_link_libraries(bitsandbytes PRIVATE $<BUILD_INTERFACE:host_intf_pub> ascendc_kernels_npu)
349+
endif()
309350

310351
if(WIN32)
311352
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")

_typos.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
[default]
44
extend-ignore-re = [
55
"@Ther-nul", # valid Github user
6+
"CANN", # CANN (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU
67
]
78

89
[default.extend-identifiers]
910

1011
[type.py.extend-words]
1112
"BA" = "BA" # used as a commented-out variable in tests
13+
"cann" = "cann" # cann (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for Ascend NPU
14+
1215

1316
[type.cuda.extend-words]
1417
"subtile" = "subtile"

bitsandbytes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
features = {"multi_backend"}
2626
supported_torch_devices = {
2727
"cuda", # includes ROCm
28+
"npu", # Ascend NPU
2829
"xpu", # Intel GPU
2930
"cpu",
3031
"hpu",

bitsandbytes/autograd/_functions.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,12 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
519519

520520
# 1. Dequantize
521521
# 2. MatmulnN
522-
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
522+
if A.device.type == "npu":
523+
output = torch.matmul(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t())
524+
if bias is not None:
525+
output += bias
526+
else:
527+
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
523528

524529
# 3. Save state
525530
ctx.state = quant_state
@@ -550,7 +555,10 @@ def backward(ctx, grad_output):
550555
# not supported by PyTorch. TODO: create work-around
551556
# if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
552557
if req_gradA:
553-
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
558+
if grad_output.device.type == "npu":
559+
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype))
560+
else:
561+
grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())
554562

555563
return grad_A, grad_B, None, grad_bias, None
556564

@@ -586,7 +594,7 @@ def matmul_4bit(
586594
return out
587595
else:
588596
return MatMul4Bit.apply(A, B, out, bias, quant_state)
589-
elif A.numel() == A.shape[-1] and A.requires_grad == False:
597+
elif A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "npu":
590598
if A.shape[-1] % quant_state.blocksize != 0:
591599
warn(
592600
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",

bitsandbytes/backends/cpu_xpu_common.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
gxx_available = False
2626
try:
27-
subprocess.run(["g++", "--version"])
27+
subprocess.run(["g++", "--version"], capture_output=True) # hide terminal output
2828
gxx_available = True
2929
except BaseException:
3030
warnings.warn("g++ not found, torch.compile disabled for CPU/XPU.")
@@ -445,22 +445,13 @@ def dequantize_4bit_impl(
445445
quant_state.ipex = False
446446

447447
# Map nf4 to [-1, 1]
448-
<<<<<<< HEAD
449-
out_uint8 = torch.empty(A.size(0) * 2, dtype=torch.uint8, device=A.device)
450-
out_uint8[::2] = A.bitwise_and(0xF)
451-
out_uint8[1::2] = A.bitwise_right_shift(4)
452-
out_dq = torch.empty(out_uint8.shape, dtype=quant_state.code.dtype, device= quant_state.code.device)
453-
for i in range(len(quant_state.code)):
454-
out_dq[out_uint8 == i] = quant_state.code[i]
455-
=======
456448
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
457449
n = out_dq.numel()
458450
out_dq[::2] = A & 0xF
459451
out_dq[1::2] = A >> 4
460452
# quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue
461453
quant_state.code = quant_state.code.to(quant_state.dtype)
462454
out_dq = quant_state.code[out_dq]
463-
>>>>>>> b2ac423 (Enable XPU and optimize cpu/xpu op (#1418))
464455

465456
# Apply scales
466457
if out_dq.numel() != n:

bitsandbytes/backends/npu.py

Lines changed: 142 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,32 @@
1+
import ctypes as ct
12
from typing import Literal, Optional, Tuple, Union
23

34
import torch
45

5-
from bitsandbytes.utils import QuantState
6-
7-
from .base import Backend
8-
96
try:
107
# to support Ascend NPU backend
118
import torch_npu # noqa: F401
129
except ImportError:
1310
pass
1411

12+
from bitsandbytes.cextension import lib
13+
from bitsandbytes.functional import (
14+
get_4bit_type,
15+
get_ptr,
16+
)
17+
from bitsandbytes.utils import QuantState
18+
19+
from .base import Backend
20+
21+
22+
def assert_on_npu(tensors):
23+
if not all(t.device.type == "npu" for t in tensors if t is not None):
24+
raise TypeError(
25+
"All input tensors to be on NPU, but found some tensors not be on NPU:\n"
26+
f"{[(t.shape, t.device) if isinstance(t, torch.Tensor) else None for t in tensors]}"
27+
)
28+
return True
29+
1530

1631
class NPUBackend(Backend):
1732
def double_quant(
@@ -75,23 +90,140 @@ def quantize_4bit(
7590
A: torch.Tensor,
7691
absmax: Optional[torch.Tensor] = None,
7792
out: Optional[torch.Tensor] = None,
78-
blocksize=64,
93+
blocksize: Optional[int] = None,
7994
compress_statistics=False,
80-
quant_type: Literal["fp4", "nf4"] = "fp4",
95+
quant_type: Literal["fp4", "nf4"] = "nf4",
8196
quant_storage=torch.uint8,
8297
) -> Tuple[torch.Tensor, QuantState]:
83-
raise NotImplementedError
98+
if quant_type not in ["nf4"]:
99+
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
100+
if compress_statistics:
101+
raise NotImplementedError("compress_statistics is not implemented.")
102+
if blocksize is None:
103+
blocksize = 128
104+
105+
prev_device = torch.npu.current_device()
106+
torch.npu.set_device(A.device)
107+
if A.dtype in [torch.float32, torch.float16, torch.bfloat16]:
108+
data = [
109+
-1.0,
110+
-0.6961928009986877,
111+
-0.5250730514526367,
112+
-0.39491748809814453,
113+
-0.28444138169288635,
114+
-0.18477343022823334,
115+
-0.09105003625154495,
116+
0.0,
117+
0.07958029955625534,
118+
0.16093020141124725,
119+
0.24611230194568634,
120+
0.33791524171829224,
121+
0.44070982933044434,
122+
0.5626170039176941,
123+
0.7229568362236023,
124+
1.0,
125+
]
126+
data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1)
127+
absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
128+
a = A.view(-1, blocksize) / absmax.float()
129+
diff = torch.abs(a.unsqueeze(-1) - data)
130+
out = (torch.argmin(diff, dim=-1) + 8) % 16
131+
out = out.reshape(-1, 2)
132+
out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
133+
else:
134+
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
135+
assert_on_npu([A, absmax, out])
136+
torch.npu.set_device(prev_device)
137+
138+
code = get_4bit_type(quant_type, device=A.device)
139+
state = QuantState(
140+
absmax=absmax,
141+
shape=A.shape,
142+
dtype=A.dtype,
143+
blocksize=blocksize,
144+
code=code,
145+
quant_type=quant_type,
146+
)
147+
148+
return out, state
84149

85150
def dequantize_4bit(
86151
self,
87152
A: torch.Tensor,
88153
quant_state: Optional[QuantState] = None,
89154
absmax: Optional[torch.Tensor] = None,
90155
out: Optional[torch.Tensor] = None,
91-
blocksize: int = 64,
92-
quant_type: Literal["fp4", "nf4"] = "fp4",
156+
blocksize: Optional[int] = None,
157+
quant_type: Literal["fp4", "nf4"] = "nf4",
93158
) -> torch.Tensor:
94-
raise NotImplementedError
159+
if blocksize is None:
160+
blocksize = 128
161+
supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64]
162+
if blocksize not in supported_blocksizes:
163+
raise ValueError(
164+
f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}"
165+
)
166+
167+
if quant_state is None:
168+
assert absmax is not None and out is not None
169+
quant_state = QuantState(
170+
absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type
171+
)
172+
else:
173+
absmax = quant_state.absmax
174+
175+
if out is None:
176+
out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
177+
178+
n = out.numel()
179+
180+
prev_device = torch.npu.current_device()
181+
torch.npu.set_device(A.device)
182+
assert_on_npu([A, absmax, out])
183+
184+
if quant_state.quant_type not in ["nf4"]:
185+
raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
186+
187+
if out.dtype == torch.float32:
188+
lib.cdequantize_blockwise_fp32_nf4(
189+
get_ptr(A),
190+
get_ptr(absmax),
191+
get_ptr(out),
192+
ct.c_int(quant_state.blocksize),
193+
ct.c_int(n),
194+
torch.npu.current_stream(),
195+
)
196+
elif out.dtype == torch.float16:
197+
lib.cdequantize_blockwise_fp16_nf4(
198+
get_ptr(A),
199+
get_ptr(absmax),
200+
get_ptr(out),
201+
ct.c_int(quant_state.blocksize),
202+
ct.c_int(n),
203+
torch.npu.current_stream(),
204+
)
205+
elif out.dtype == torch.bfloat16:
206+
# bf16: bf16 -> fp32 -> op -> fp32 -> bf16
207+
absmax = absmax.to(torch.float32)
208+
out = out.to(torch.float32)
209+
lib.cdequantize_blockwise_fp32_nf4(
210+
get_ptr(A),
211+
get_ptr(absmax),
212+
get_ptr(out),
213+
ct.c_int(quant_state.blocksize),
214+
ct.c_int(n),
215+
torch.npu.current_stream(),
216+
)
217+
out = out.to(torch.bfloat16)
218+
else:
219+
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
220+
torch.npu.set_device(prev_device)
221+
is_transposed = True if A.shape[0] == 1 else False
222+
223+
if is_transposed:
224+
return out.t()
225+
else:
226+
return out
95227

96228
def gemv_4bit(
97229
self,

bitsandbytes/cextension.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
2727
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch
28+
from bitsandbytes.npu_specs import get_npu_specs
2829

2930
logger = logging.getLogger(__name__)
3031

@@ -100,6 +101,10 @@ def get_native_library() -> BNBNativeLibrary:
100101
binary_path = cuda_binary_path
101102
else:
102103
logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path)
104+
npu_specs = get_npu_specs()
105+
if npu_specs:
106+
binary_path = PACKAGE_DIR / f"libbitsandbytes_npu{DYNAMIC_LIBRARY_SUFFIX}"
107+
103108
logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
104109
dll = ct.cdll.LoadLibrary(str(binary_path))
105110

0 commit comments

Comments
 (0)