57 changes: 55 additions & 2 deletions CMakeLists.txt
@@ -29,11 +29,12 @@ set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
set(NPU_FILES csrc/npu_ops.cpp)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})

set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu)
set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu, npu)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu npu)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
@@ -51,20 +52,23 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
set(BUILD_CUDA ON)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_NPU OFF)
elseif(${COMPUTE_BACKEND} STREQUAL "hip")
if(APPLE)
message(FATAL_ERROR "HIP is not supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_HIP ON)
set(BUILD_MPS OFF)
set(BUILD_NPU OFF)
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS ON)
set(BUILD_NPU OFF)
elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
if(APPLE)
message(FATAL_ERROR "XPU is not supported on macOS" )
@@ -73,11 +77,22 @@ elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU ON)
set(BUILD_NPU OFF)
elseif(${COMPUTE_BACKEND} STREQUAL "npu")
if(APPLE)
message(FATAL_ERROR "NPU is not supported on macOS" )
endif()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU OFF)
set(BUILD_NPU ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
set(BUILD_XPU OFF)
set(BUILD_NPU OFF)
set(BUILD_CPU ON)
endif()

@@ -250,6 +265,40 @@ elseif(BUILD_XPU)
if(WIN32)
set(CMAKE_CXX_COMPILER icx)
endif()
elseif(BUILD_NPU)
list(APPEND SRC_FILES ${NPU_FILES})
execute_process(
COMMAND bash -c "npu-smi info|awk -F' ' 'NF > 0 && NR==7 {print $3}'"
OUTPUT_VARIABLE npu_info
RESULT_VARIABLE npu_result
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if("${npu_info}" STREQUAL "" OR ${npu_result})
message(FATAL_ERROR "Auto-detech ascend soc type failed, please specify manually or check ascend device working normally.")
endif()

set(SOC_VERSION "Ascend${npu_info}" CACHE STRING "system on chip type")
set(ASCEND_CANN_PACKAGE_PATH $ENV{ASCEND_HOME_PATH} CACHE STRING "ASCEND CANN package installation directory")

# ${KERNEL_FILES} lists the sources compiled into the kernel library; add files written in Ascend C to ${KERNEL_FILES}.
# See ascendc_library in cmake/npu.cmake and add_library in cmake/cpu.cmake for reference.
# file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/npu_kernels.cpp)
file(GLOB KERNEL_FILES csrc/npu_kernels.cpp)

if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/tools/tikcpp/ascendc_kernel_cmake)
else()
message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the can package is installed")
endif()
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)

# ascendc_library compiles the kernel files above into the Ascend C kernels library
ascendc_library(ascendc_kernels_npu STATIC ${KERNEL_FILES})

string(APPEND BNB_OUTPUT_NAME "_npu")
add_compile_definitions(BUILD_NPU)
else()
string(APPEND BNB_OUTPUT_NAME "_cpu")
set(GPU_SOURCES)
@@ -355,6 +404,10 @@ if(BUILD_XPU)
target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS})

endif()
if(BUILD_NPU)
target_compile_options(bitsandbytes PRIVATE -O2 -std=c++17)
target_link_libraries(bitsandbytes PRIVATE $<BUILD_INTERFACE:host_intf_pub> ascendc_kernels_npu)
endif()

if(WIN32)
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
8 changes: 8 additions & 0 deletions README.md
@@ -113,6 +113,14 @@ bitsandbytes has the following minimum requirements for all platforms:
<td>✅</td>
<td>✅</td>
</tr>
<tr>
<td></td>
<td>🟧 Ascend NPU <br><code>npu</code></td>
<td>Atlas 800T A2+</td>
<td>❌</td>
<td>✅</td>
<td>❌</td>
</tr>
<tr>
<td colspan="6">🪟 <strong>Windows 11 / Windows Server 2019+</strong></td>
</tr>
3 changes: 3 additions & 0 deletions bitsandbytes/__init__.py
@@ -45,6 +45,9 @@
if hasattr(torch, "hpu") and torch.hpu.is_available():
from .backends.hpu import ops as hpu_ops

if importlib.util.find_spec("torch") and importlib.util.find_spec("torch_npu"):
from .backends.npu import ops as npu_ops


def _import_backends():
"""
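For context on the guarded import above: once `torch_npu` is importable, loading `bitsandbytes` registers the NF4 kernels from `backends/npu/ops.py` for the `npu` device, and they should be reachable through the existing `bitsandbytes.functional` API. A minimal usage sketch, assuming an Ascend device is visible and an npu build of the native library is installed (device string, shapes, and blocksize are illustrative):

```python
import torch
import torch_npu  # noqa: F401  -- registers the "npu" device type with PyTorch
import bitsandbytes.functional as F

# Quantize a half-precision weight to NF4 on the NPU and reconstruct it.
W = torch.randn(4096, 4096, dtype=torch.float16, device="npu")
packed, quant_state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")
W_hat = F.dequantize_4bit(packed, quant_state)

print(packed.dtype, packed.shape)       # uint8 storage, two NF4 codes per byte
print((W - W_hat).abs().mean().item())  # quantization error should be small
```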
2 changes: 1 addition & 1 deletion bitsandbytes/autograd/_functions.py
@@ -378,7 +378,7 @@ def matmul_4bit(
if A.device.type == "cpu":
quant_state.dtype = A.dtype

if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu" and A.device.type != "npu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
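For context on the condition being extended here (illustrative, not part of the patch): `A.numel() == A.shape[-1]` holds only for a single-row activation, i.e. the batch-size-1 inference case served by the fused 4-bit gemv shortcut. Listing `npu` alongside `hpu` keeps Ascend inputs on the general dequantize-then-matmul path:

```python
import torch

A = torch.randn(1, 1, 4096)        # one token, batch of 1
print(A.numel() == A.shape[-1])    # True  -> eligible for the fused gemv shortcut
B = torch.randn(8, 4096)           # batched input
print(B.numel() == B.shape[-1])    # False -> always takes the general matmul path
```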
Empty file.
124 changes: 124 additions & 0 deletions bitsandbytes/backends/npu/ops.py
@@ -0,0 +1,124 @@
import ctypes as ct
from collections.abc import Sequence

import torch

from bitsandbytes.functional import get_ptr

from ..._ops import register_kernel
from ...cextension import lib
from ..utils import _NF4_QUANT_TABLE


@register_kernel("bitsandbytes::quantize_4bit", "npu")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on NPU, got {quant_type}")
    n = A.numel()

    global _NF4_QUANT_TABLE
    if _NF4_QUANT_TABLE.device != A.device:
        _NF4_QUANT_TABLE = _NF4_QUANT_TABLE.to(A.device)

    # TODO: Support when weight matrix is not divisible by blocksize
    torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

    # Process tensor in chunks to avoid high memory usage from large intermediate tensors
    # (e.g., during broadcasting with FP32 quant table)
    chunks_absmax = []
    chunks_out = []
    total_blocks = A.numel() // blocksize
    chunks = 8 if A.numel() > 1024 * 1024 else 1
    chunksize = (total_blocks + chunks - 1) // chunks

    for i in range(chunks):
        start = i * chunksize * blocksize
        end = min((i + 1) * chunksize * blocksize, A.numel())
        chunk_data = A.view(-1)[start:end].view(-1, blocksize)

        absmax = chunk_data.abs().max(dim=1, keepdim=True).values
        chunks_absmax.append(absmax)

        a = chunk_data / absmax.float()
        diff = torch.abs(a.unsqueeze(-1) - _NF4_QUANT_TABLE)
        out = (torch.argmin(diff, dim=-1) + 8) % 16

        out = out.reshape(-1, 2)
        # Pack 4-bit values in NPU-compatible order (low nibble first) to match NPU-specific unpacking logic;
        # differs from CUDA's packing
        out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
        chunks_out.append(out)

    absmax = torch.cat(chunks_absmax, dim=0)
    packed = torch.cat(chunks_out, dim=0).reshape(-1, 1)
    return packed, absmax


@register_kernel("bitsandbytes::dequantize_4bit", "npu")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    out = torch.empty(shape, dtype=dtype, device=A.device)
    _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out


@register_kernel("bitsandbytes::dequantize_4bit.out", "npu")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def _dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    torch._check(quant_type in ["nf4"])
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    if out.dtype == torch.bfloat16:
        # bf16: bf16 -> fp32 -> op -> fp32 -> bf16
        absmax = absmax.to(torch.float32)
        out_fp32 = torch.empty(out.shape, dtype=torch.float32, device=out.device)
    else:
        out_fp32 = out

    args = (
        get_ptr(A),
        get_ptr(absmax),
        get_ptr(out_fp32),
        ct.c_int(blocksize),
        ct.c_int(out.numel()),
        torch.npu.current_stream(),
    )

    if out.dtype == torch.bfloat16:
        lib.cdequantize_blockwise_fp32_nf4(*args)
        out.copy_(out_fp32.to(torch.bfloat16))
    elif out.dtype == torch.float16:
        lib.cdequantize_blockwise_fp16_nf4(*args)
    elif out.dtype == torch.float32:
        lib.cdequantize_blockwise_fp32_nf4(*args)
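To make the packing comment in `quantize_4bit` above concrete, here is a small sketch (not part of the file) of the low-nibble-first layout: the first NF4 code of each pair lands in the low 4 bits of the packed byte, so unpacking must read the low nibble first, unlike the CUDA layout the comment contrasts it with.

```python
first, second = 3, 12                # two adjacent NF4 code indices (0..15)
packed = first + second * 16         # first code -> low nibble, second -> high nibble
assert (packed & 0xF, packed >> 4) == (first, second)  # low-nibble-first unpacking
```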
13 changes: 12 additions & 1 deletion bitsandbytes/cextension.py
@@ -5,6 +5,7 @@
from pathlib import Path
import re
from typing import Optional
import importlib

import torch

@@ -48,6 +49,13 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
    return PACKAGE_DIR / library_name


def is_npu_available() -> bool:
    """Checks whether `torch_npu` is installed, and therefore whether an Ascend NPU backend may be available."""
    if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
        return False
    return True


class BNBNativeLibrary:
    _lib: ct.CDLL
    compiled_with_cuda = False
@@ -288,7 +296,8 @@ def get_native_library() -> BNBNativeLibrary:
raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}")

binary_path = cuda_binary_path

elif is_npu_available():
binary_path = PACKAGE_DIR / f"libbitsandbytes_npu{DYNAMIC_LIBRARY_SUFFIX}"
if torch._C._has_xpu:
binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}"

@@ -311,6 +320,8 @@
if torch.version.hip:
    HIP_ENVIRONMENT = True
    BNB_BACKEND = "ROCm"
elif is_npu_available():
    BNB_BACKEND = "NPU"
elif torch.cuda.is_available():
    BNB_BACKEND = "CUDA"
elif torch._C._has_xpu:
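A quick smoke test of the backend resolution above (a hedged sketch, assuming a build produced `libbitsandbytes_npu`, `torch_npu` is installed, and no other accelerator runtime is present; the comments state expectations, not guarantees):

```python
import bitsandbytes  # resolves and loads the native library at import time
from bitsandbytes.cextension import BNB_BACKEND, lib

print(BNB_BACKEND)  # expected: "NPU" when torch_npu is importable
print(lib)          # should wrap the libbitsandbytes_npu shared library
```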