Skip to content

Commit b09241c

Browse files
author
bghira
committed
mps: add nf4 dequantize/quantize kernel
1 parent fd9934c commit b09241c

File tree

9 files changed

+1057
-140
lines changed

9 files changed

+1057
-140
lines changed

CMakeLists.txt

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,19 @@ elseif(BUILD_MPS)
226226
string(APPEND BNB_OUTPUT_NAME "_mps")
227227
add_compile_definitions(BUILD_MPS)
228228
file(MAKE_DIRECTORY "build")
229-
add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib"
230-
COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES}
231-
COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib"
229+
# Build rule for the Metal shader library.
#
# Entries of METAL_FILES are treated as paths relative to PROJECT_SOURCE_DIR
# and are made absolute, so the compile command does not depend on the
# directory the build tool happens to run in.
set(METAL_AIR "${CMAKE_BINARY_DIR}/bitsandbytes.air")
set(METAL_LIB "${PROJECT_SOURCE_DIR}/bitsandbytes/bitsandbytes.metallib")
set(METAL_SOURCES "")
foreach(METAL_FILE ${METAL_FILES})
  list(APPEND METAL_SOURCES "${PROJECT_SOURCE_DIR}/${METAL_FILE}")
endforeach()

# Step 1 compiles the sources to an intermediate .air file, step 2 packs that
# file into the .metallib named by METAL_LIB.
add_custom_command(
  OUTPUT "${METAL_LIB}"
  COMMAND xcrun metal -c ${METAL_SOURCES} -o "${METAL_AIR}"
  COMMAND xcrun metallib "${METAL_AIR}" -o "${METAL_LIB}"
  # Depend on exactly the absolute paths handed to the compiler above, so the
  # rule re-runs whenever one of the compiled files changes.
  DEPENDS ${METAL_SOURCES}
  # The .air file is produced as a side effect; declaring it lets the
  # generator track and clean it.
  BYPRODUCTS "${METAL_AIR}"
  COMMENT "Compiling Metal kernels"
  VERBATIM)
235-
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
241+
add_custom_target(metallib DEPENDS "${METAL_LIB}")
236242
elseif(BUILD_XPU)
237243
list(APPEND SRC_FILES ${XPU_FILES})
238244
string(APPEND BNB_OUTPUT_NAME "_xpu")
@@ -257,10 +263,57 @@ if(MSVC)
257263
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast")
258264
endif()
259265

266+
# Locate Python (interpreter + development files) and the Torch CMake package.
find_package(Python3 COMPONENTS Interpreter Development)
message(STATUS "Python3 found: ${Python3_FOUND}")

# When the caller did not supply Torch_DIR, ask the torch package installed for
# the discovered interpreter where its CMake configuration directory lives.
if(NOT Torch_DIR)
  if(Python3_Interpreter_FOUND)
    execute_process(
      # as_posix() keeps the printed path free of backslashes on Windows.
      COMMAND "${Python3_EXECUTABLE}" -c "import torch, pathlib; print((pathlib.Path(torch.__file__).resolve().parent / 'share/cmake/Torch').as_posix())"
      RESULT_VARIABLE BNB_TORCH_DIR_RESULT
      OUTPUT_VARIABLE BNB_TORCH_DIR_OUTPUT
      ERROR_VARIABLE BNB_TORCH_DIR_ERROR
      OUTPUT_STRIP_TRAILING_WHITESPACE
    )
    # Only trust the output when the interpreter exited cleanly; a failed
    # import must not leave Torch_DIR set to an empty or partial value.
    if(BNB_TORCH_DIR_RESULT EQUAL 0 AND NOT "${BNB_TORCH_DIR_OUTPUT}" STREQUAL "")
      set(Torch_DIR "${BNB_TORCH_DIR_OUTPUT}")
    else()
      message(WARNING
        "Could not query the torch installation for its CMake directory "
        "(result: ${BNB_TORCH_DIR_RESULT}): ${BNB_TORCH_DIR_ERROR}")
    endif()
  else()
    message(WARNING
      "No Python3 interpreter was found, so Torch_DIR cannot be auto-detected; "
      "pass -DTorch_DIR=<path> explicitly.")
  endif()
endif()
message(STATUS "Torch_DIR=${Torch_DIR}")
find_package(Torch REQUIRED CONFIG)
278+
260279
# Define the shared library target and its public usage requirements.
set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
add_library(bitsandbytes SHARED ${SRC_FILES})
target_compile_features(bitsandbytes PUBLIC cxx_std_17)
target_include_directories(bitsandbytes PUBLIC csrc include)

# Wire in the Python and Torch headers/libraries when Python was discovered.
if(Python3_FOUND)
  message(STATUS "Python include dirs: ${Python3_INCLUDE_DIRS}")
  target_include_directories(bitsandbytes PRIVATE ${Python3_INCLUDE_DIRS})

  # Also add the include directory reported by the interpreter's sysconfig.
  execute_process(
    COMMAND "${Python3_EXECUTABLE}" -c "import sysconfig; print(sysconfig.get_paths()['include'])"
    OUTPUT_VARIABLE PYTHON_SYSTEM_INCLUDE
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  if(PYTHON_SYSTEM_INCLUDE)
    target_include_directories(bitsandbytes PRIVATE "${PYTHON_SYSTEM_INCLUDE}")
  endif()

  # Header directories reported by torch.utils.cpp_extension. The script joins
  # them with ';', so the output is already a CMake list and needs no further
  # splitting. (Replacing a literal "\n" here would damage Windows paths such
  # as C:\Users\nick\...) The result is kept in a project-prefixed variable
  # instead of overwriting TORCH_INCLUDE_DIRS.
  execute_process(
    COMMAND "${Python3_EXECUTABLE}" -c "import torch\nfrom torch.utils.cpp_extension import include_paths\nprint(';'.join(include_paths()))"
    RESULT_VARIABLE BNB_TORCH_INCLUDE_RESULT
    OUTPUT_VARIABLE BNB_TORCH_INCLUDE_DIRS
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_QUIET
  )
  if(BNB_TORCH_INCLUDE_RESULT EQUAL 0 AND BNB_TORCH_INCLUDE_DIRS)
    target_include_directories(bitsandbytes PRIVATE ${BNB_TORCH_INCLUDE_DIRS})
  endif()

  # Library directories reported by torch.utils.cpp_extension, again printed
  # as a ';'-separated list.
  execute_process(
    COMMAND "${Python3_EXECUTABLE}" -c "import torch\nfrom torch.utils.cpp_extension import library_paths\nprint(';'.join(library_paths()))"
    RESULT_VARIABLE BNB_TORCH_LIBRARY_RESULT
    OUTPUT_VARIABLE TORCH_LIBRARY_DIRS
    OUTPUT_STRIP_TRAILING_WHITESPACE
    ERROR_QUIET
  )
  # Link against torch only when the query succeeded and returned directories.
  if(BNB_TORCH_LIBRARY_RESULT EQUAL 0 AND TORCH_LIBRARY_DIRS)
    target_link_directories(bitsandbytes PRIVATE ${TORCH_LIBRARY_DIRS})
    target_link_libraries(bitsandbytes PRIVATE torch torch_cpu torch_python c10)
  endif()

  target_link_libraries(bitsandbytes PRIVATE ${Python3_LIBRARIES})
endif()
264317

265318

266319
if(BUILD_CUDA)
@@ -308,7 +361,8 @@ if(BUILD_HIP)
308361
endif()
309362
if(BUILD_MPS)
310363
add_dependencies(bitsandbytes metallib)
311-
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
364+
target_compile_options(bitsandbytes PRIVATE "-fno-objc-arc")
365+
target_link_libraries(bitsandbytes PRIVATE objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
312366
endif()
313367
if(BUILD_XPU)
314368
set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")

bitsandbytes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
if hasattr(torch, "xpu") and torch.xpu.is_available():
3939
from .backends.xpu import ops as xpu_ops
4040

41+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
42+
from .backends.mps import ops as mps_ops
43+
4144
if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
4245
# In case not automatically imported
4346
import habana_frameworks.torch

bitsandbytes/backends/default/ops.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,18 @@
22
from math import prod, sqrt
33
from typing import Optional
44

5+
import importlib.util
6+
57
import torch
68

9+
_HAS_TRITON = importlib.util.find_spec("triton") is not None
10+
11+
12+
def _maybe_compile(fn):
13+
if not _HAS_TRITON:
14+
return fn
15+
return torch.compile(fn)
16+
717
from ..._ops import register_kernel
818
from ..utils import CODE
919

@@ -321,7 +331,7 @@ def _(
321331
}
322332

323333

324-
@torch.compile
334+
@_maybe_compile
325335
def _optimizer_precondition_32bit(
326336
g: torch.Tensor,
327337
p: torch.Tensor,
@@ -382,7 +392,7 @@ def _optimizer_precondition_32bit(
382392
unorm_vec.add_(total_norm)
383393

384394

385-
@torch.compile
395+
@_maybe_compile
386396
def _optimizer_update_32bit(
387397
g: torch.Tensor,
388398
p: torch.Tensor,

bitsandbytes/backends/mps/ops.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from __future__ import annotations
2+
3+
import ctypes as ct
4+
from typing import Sequence, Tuple
5+
6+
import torch
7+
8+
from ..._ops import register_kernel
9+
from ...cextension import lib
10+
# Block sizes accepted by the NF4 helpers in this module.
_ALLOWED_BLOCKS = (64, 128, 256, 512, 1024, 2048, 4096)
# Floating-point dtypes that have a dedicated native entry point.
_SUPPORTED_DTYPES = (torch.float16, torch.float32)

# All four native NF4 entry points share one ctypes signature: three Python
# objects followed by a 32-bit integer, with no return value.
for _symbol in (
    "cquantize_blockwise_fp16_nf4_tensor",
    "cquantize_blockwise_fp32_nf4_tensor",
    "cdequantize_blockwise_fp16_nf4_tensor",
    "cdequantize_blockwise_fp32_nf4_tensor",
):
    _entry = getattr(lib, _symbol)
    _entry.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
    _entry.restype = None
# Do not leave the loop helpers behind as module attributes.
del _symbol, _entry
22+
23+
24+
def _quantize_nf4(
    A: torch.Tensor, blocksize: int, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``A`` block-wise via the native NF4 entry points.

    Args:
        A: Input tensor; a contiguous copy is made when needed. Only
            ``torch.float16`` and ``torch.float32`` have a native entry point.
        blocksize: Elements per block; must be one of ``_ALLOWED_BLOCKS``.
        quant_storage: Storage dtype of the packed output; must be ``torch.uint8``.

    Returns:
        ``(out, absmax)``. ``out`` has shape ``((n + 1) // 2, 1)`` and dtype
        ``quant_storage``; ``absmax`` is a float32 vector with one entry per
        block (``ceil(n / blocksize)``), where ``n = A.numel()``.

    Raises:
        RuntimeError: via ``torch._check`` when any of the checks above fails.
    """
    torch._check(blocksize in _ALLOWED_BLOCKS)
    torch._check(quant_storage == torch.uint8, lambda: "Only uint8 storage is supported for NF4 on MPS.")

    src = A.contiguous()
    numel = src.numel()
    # Ceiling division: a trailing partial block still gets its own absmax slot.
    num_blocks = (numel + blocksize - 1) // blocksize

    # Both buffers are allocated uninitialised and passed to the native call.
    absmax = torch.empty((num_blocks,), device=src.device, dtype=torch.float32)
    out = torch.empty(((numel + 1) // 2, 1), device=src.device, dtype=quant_storage)

    # Pick the entry point matching the input precision.
    if src.dtype == torch.float16:
        native_fn = lib.cquantize_blockwise_fp16_nf4_tensor
    elif src.dtype == torch.float32:
        native_fn = lib.cquantize_blockwise_fp32_nf4_tensor
    else:
        native_fn = None
    torch._check(
        native_fn is not None,
        lambda: f"NF4 quantization on MPS supports {list(_SUPPORTED_DTYPES)}, got {A.dtype}",
    )

    native_fn(ct.py_object(src), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
    return out, absmax
45+
46+
47+
def _dequantize_nf4(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """Dequantize NF4 data into ``out`` via the native entry points.

    Args:
        A: Packed input tensor; a contiguous copy is made when needed.
        absmax: Per-block scaling tensor; a contiguous copy is made when needed.
        blocksize: Elements per block; must be one of ``_ALLOWED_BLOCKS``.
        dtype: Selects the native entry point (``torch.float16`` or
            ``torch.float32``).
        out: Destination tensor, passed to the native call; must already be
            contiguous.

    Raises:
        RuntimeError: via ``torch._check`` when ``blocksize`` is not allowed,
            ``out`` is not contiguous, or ``dtype`` has no native entry point.
    """
    torch._check(blocksize in _ALLOWED_BLOCKS)

    packed = A.contiguous()
    scales = absmax.contiguous()
    # ``out`` is handed to the native call as-is, so it is validated, not copied.
    torch._check(out.is_contiguous(), lambda: "Output tensor must be contiguous for NF4 dequantization on MPS.")

    # Pick the entry point matching the requested output precision.
    if dtype == torch.float16:
        native_fn = lib.cdequantize_blockwise_fp16_nf4_tensor
    elif dtype == torch.float32:
        native_fn = lib.cdequantize_blockwise_fp32_nf4_tensor
    else:
        native_fn = None
    torch._check(
        native_fn is not None,
        lambda: f"NF4 dequantization on MPS supports {list(_SUPPORTED_DTYPES)}, got {dtype}",
    )

    native_fn(ct.py_object(packed), ct.py_object(scales), ct.py_object(out), ct.c_int32(blocksize))
66+
67+
68+
@register_kernel("bitsandbytes::quantize_4bit", "mps")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """MPS kernel for ``bitsandbytes::quantize_4bit``.

    Uses the native NF4 path when ``quant_type`` is ``"nf4"`` and ``A`` has a
    dtype listed in ``_SUPPORTED_DTYPES``; every other request is forwarded to
    ``torch.ops.bitsandbytes.quantize_4bit.default``.
    """
    if quant_type == "nf4" and A.dtype in _SUPPORTED_DTYPES:
        return _quantize_nf4(A, blocksize, quant_storage)
    # NOTE(review): the call below goes back through the op's dispatcher with
    # the same tensor; if that resolves to this kernel again it would recurse.
    # Confirm it reaches the generic implementation.
    return torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, quant_storage)
75+
76+
77+
@register_kernel("bitsandbytes::dequantize_4bit", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """MPS kernel for ``bitsandbytes::dequantize_4bit``.

    Uses the native NF4 path when ``quant_type`` is ``"nf4"`` and ``dtype`` is
    listed in ``_SUPPORTED_DTYPES``; the result is a newly allocated tensor of
    the given ``shape`` and ``dtype`` on ``A``'s device. Every other request is
    forwarded to ``torch.ops.bitsandbytes.dequantize_4bit.default``.
    """
    if quant_type == "nf4" and dtype in _SUPPORTED_DTYPES:
        result = torch.empty(shape, dtype=dtype, device=A.device)
        _dequantize_nf4(A, absmax, blocksize, dtype, result)
        return result
    # NOTE(review): the call below goes back through the op's dispatcher with
    # the same tensor; if that resolves to this kernel again it would recurse.
    # Confirm it reaches the generic implementation.
    return torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)
91+
92+
93+
@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """MPS kernel for ``bitsandbytes::dequantize_4bit.out``.

    Uses the native NF4 path when ``quant_type`` is ``"nf4"`` and ``dtype`` is
    listed in ``_SUPPORTED_DTYPES``, after checking that ``out`` matches the
    requested ``shape`` and ``dtype``. Every other request is forwarded to
    ``torch.ops.bitsandbytes.dequantize_4bit.out.default``.
    """
    if quant_type == "nf4" and dtype in _SUPPORTED_DTYPES:
        torch._check(out.shape == tuple(shape), lambda: f"Expected out.shape == {tuple(shape)}, got {out.shape}")
        torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
        _dequantize_nf4(A, absmax, blocksize, dtype, out)
    else:
        # NOTE(review): this call goes back through the op's dispatcher with
        # the same tensors; if that resolves to this kernel again it would
        # recurse. Also confirm that `.out.default` is a valid attribute chain
        # for this op.
        torch.ops.bitsandbytes.dequantize_4bit.out.default(
            A,
            absmax,
            blocksize,
            quant_type,
            shape,
            dtype,
            out,
        )

bitsandbytes/cextension.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,9 @@ def get_native_library() -> BNBNativeLibrary:
283283

284284
binary_path = cuda_binary_path
285285

286-
if torch._C._has_xpu:
286+
if BNB_BACKEND == "MPS":
287+
binary_path = PACKAGE_DIR / f"libbitsandbytes_mps{DYNAMIC_LIBRARY_SUFFIX}"
288+
elif torch._C._has_xpu:
287289
binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}"
288290

289291
logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
@@ -306,6 +308,8 @@ def get_native_library() -> BNBNativeLibrary:
306308
BNB_BACKEND = "ROCm"
307309
elif torch.cuda.is_available():
308310
BNB_BACKEND = "CUDA"
311+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
312+
BNB_BACKEND = "MPS"
309313
elif torch._C._has_xpu:
310314
BNB_BACKEND = "XPU"
311315

0 commit comments

Comments
 (0)