diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9c133e09f..08292e6dc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -226,13 +226,19 @@ elseif(BUILD_MPS)
     string(APPEND BNB_OUTPUT_NAME "_mps")
     add_compile_definitions(BUILD_MPS)
     file(MAKE_DIRECTORY "build")
-    add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib"
-        COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES}
-        COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib"
+    set(METAL_AIR "${CMAKE_BINARY_DIR}/bitsandbytes.air")
+    set(METAL_LIB "${PROJECT_SOURCE_DIR}/bitsandbytes/bitsandbytes.metallib")
+    set(METAL_SOURCES "")
+    foreach(METAL_FILE ${METAL_FILES})
+        list(APPEND METAL_SOURCES "${PROJECT_SOURCE_DIR}/${METAL_FILE}")
+    endforeach()
+    add_custom_command(OUTPUT "${METAL_LIB}"
+        COMMAND xcrun metal -c ${METAL_SOURCES} -o "${METAL_AIR}"
+        COMMAND xcrun metallib "${METAL_AIR}" -o "${METAL_LIB}"
         DEPENDS "${METAL_FILES}"
         COMMENT "Compiling Metal kernels"
         VERBATIM)
-    add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
+    add_custom_target(metallib DEPENDS "${METAL_LIB}")
 elseif(BUILD_XPU)
     list(APPEND SRC_FILES ${XPU_FILES})
     string(APPEND BNB_OUTPUT_NAME "_xpu")
@@ -257,10 +263,57 @@ if(MSVC)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast")
 endif()
 
+find_package(Python3 COMPONENTS Interpreter Development)
+message(STATUS "Python3 found: ${Python3_FOUND}")
+
+if(NOT Torch_DIR)
+    execute_process(
+        COMMAND ${Python3_EXECUTABLE} -c "import torch, pathlib; print(pathlib.Path(torch.__file__).resolve().parent / 'share/cmake/Torch')"
+        OUTPUT_VARIABLE Torch_DIR
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+endif()
+message(STATUS "Torch_DIR=${Torch_DIR}")
+find_package(Torch REQUIRED CONFIG)
+
 set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
 add_library(bitsandbytes SHARED ${SRC_FILES})
 target_compile_features(bitsandbytes PUBLIC cxx_std_17)
 target_include_directories(bitsandbytes PUBLIC csrc include)
+if(Python3_FOUND)
+    message(STATUS "Python include dirs: ${Python3_INCLUDE_DIRS}")
+    target_include_directories(bitsandbytes PRIVATE ${Python3_INCLUDE_DIRS})
+    execute_process(
+        COMMAND ${Python3_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_paths()['include'])"
+        OUTPUT_VARIABLE PYTHON_SYSTEM_INCLUDE
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+    if(PYTHON_SYSTEM_INCLUDE)
+        target_include_directories(bitsandbytes PRIVATE ${PYTHON_SYSTEM_INCLUDE})
+    endif()
+    execute_process(
+        COMMAND ${Python3_EXECUTABLE} -c "import torch\nfrom torch.utils.cpp_extension import include_paths\nprint(';'.join(include_paths()))"
+        OUTPUT_VARIABLE TORCH_INCLUDE_DIRS
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+        ERROR_QUIET
+    )
+    if(TORCH_INCLUDE_DIRS)
+        string(REPLACE "\\n" ";" TORCH_INCLUDE_DIRS "${TORCH_INCLUDE_DIRS}")
+        target_include_directories(bitsandbytes PRIVATE ${TORCH_INCLUDE_DIRS})
+    endif()
+    execute_process(
+        COMMAND ${Python3_EXECUTABLE} -c "import torch\nfrom torch.utils.cpp_extension import library_paths\nprint(';'.join(library_paths()))"
+        OUTPUT_VARIABLE TORCH_LIBRARY_DIRS
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+        ERROR_QUIET
+    )
+    if(TORCH_LIBRARY_DIRS)
+        string(REPLACE "\\n" ";" TORCH_LIBRARY_DIRS "${TORCH_LIBRARY_DIRS}")
+        target_link_directories(bitsandbytes PRIVATE ${TORCH_LIBRARY_DIRS})
+        target_link_libraries(bitsandbytes PRIVATE torch torch_cpu torch_python c10)
+    endif()
+    target_link_libraries(bitsandbytes PRIVATE ${Python3_LIBRARIES})
+endif()
 
 if(BUILD_CUDA)
@@ -308,7 +361,8 @@ if(BUILD_HIP)
 endif()
 if(BUILD_MPS)
     add_dependencies(bitsandbytes metallib)
-    target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
+    target_compile_options(bitsandbytes PRIVATE "-fno-objc-arc")
+    target_link_libraries(bitsandbytes PRIVATE objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
 endif()
 if(BUILD_XPU)
     set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 8bea82fb3..4d7c94abb 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -38,6 +38,9 @@ if hasattr(torch, "xpu") and torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops
 
+if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+    from .backends.mps import ops as mps_ops
+
 if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
     # In case not automatically imported
     import habana_frameworks.torch
diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index 067347d47..965206a9c 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -2,8 +2,18 @@
 from math import prod, sqrt
 from typing import Optional
 
+import importlib.util
+
 import torch
+
+_HAS_TRITON = importlib.util.find_spec("triton") is not None
+
+
+def _maybe_compile(fn):
+    if not _HAS_TRITON:
+        return fn
+    return torch.compile(fn)
 
 from ..._ops import register_kernel
 from ..utils import CODE
 
@@ -321,7 +331,7 @@ def _(
     }
 
 
-@torch.compile
+@_maybe_compile
 def _optimizer_precondition_32bit(
     g: torch.Tensor,
     p: torch.Tensor,
@@ -382,7 +392,7 @@
     unorm_vec.add_(total_norm)
 
 
-@torch.compile
+@_maybe_compile
 def _optimizer_update_32bit(
     g: torch.Tensor,
     p: torch.Tensor,
diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py
new file mode 100644
index 000000000..28ed29fcb
--- /dev/null
+++ b/bitsandbytes/backends/mps/ops.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+import ctypes as ct
+from typing import Sequence, Tuple
+
+import torch
+
+from ..._ops import register_kernel
+from ...cextension import lib
+_ALLOWED_BLOCKS = (64, 128, 256, 512, 1024, 2048, 4096)
+_SUPPORTED_DTYPES = (torch.float16, torch.float32)
+
+
+lib.cquantize_blockwise_fp16_nf4_tensor.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
+lib.cquantize_blockwise_fp16_nf4_tensor.restype = None
+lib.cquantize_blockwise_fp32_nf4_tensor.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
+lib.cquantize_blockwise_fp32_nf4_tensor.restype = None
+lib.cdequantize_blockwise_fp16_nf4_tensor.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
+lib.cdequantize_blockwise_fp16_nf4_tensor.restype = None
+lib.cdequantize_blockwise_fp32_nf4_tensor.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
+lib.cdequantize_blockwise_fp32_nf4_tensor.restype = None
+
+
+def _quantize_nf4(
+    A: torch.Tensor, blocksize: int, quant_storage: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    torch._check(blocksize in _ALLOWED_BLOCKS)
+    torch._check(quant_storage == torch.uint8, lambda: "Only uint8 storage is supported for NF4 on MPS.")
+
+    A = A.contiguous()
+    n = A.numel()
+    blocks = -(n // -blocksize)
+
+    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.empty(((n + 1) // 2, 1), device=A.device, dtype=quant_storage)
+
+    if A.dtype == torch.float16:
+        lib.cquantize_blockwise_fp16_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
+    elif A.dtype == torch.float32:
+        lib.cquantize_blockwise_fp32_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
+    else:
+        torch._check(False, lambda: f"NF4 quantization on MPS supports {list(_SUPPORTED_DTYPES)}, got {A.dtype}")
+
+    return out, absmax
+
+
+def _dequantize_nf4(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> None:
+    torch._check(blocksize in _ALLOWED_BLOCKS)
+
+    A = A.contiguous()
+    absmax = absmax.contiguous()
+    torch._check(out.is_contiguous(), lambda: "Output tensor must be contiguous for NF4 dequantization on MPS.")
+
+    if dtype == torch.float16:
+        lib.cdequantize_blockwise_fp16_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
+    elif dtype == torch.float32:
+        lib.cdequantize_blockwise_fp32_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
+    else:
+        torch._check(False, lambda: f"NF4 dequantization on MPS supports {list(_SUPPORTED_DTYPES)}, got {dtype}")
+
+
+@register_kernel("bitsandbytes::quantize_4bit", "mps")
+def _(
+    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if quant_type != "nf4" or A.dtype not in _SUPPORTED_DTYPES:
+        return torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, quant_storage)
+    return _quantize_nf4(A, blocksize, quant_storage)
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "mps")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    if quant_type != "nf4" or dtype not in _SUPPORTED_DTYPES:
+        return torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)
+    out = torch.empty(shape, dtype=dtype, device=A.device)
+    _dequantize_nf4(A, absmax, blocksize, dtype, out)
+    return out
+
+
+@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> None:
+    if quant_type != "nf4" or dtype not in _SUPPORTED_DTYPES:
+        torch.ops.bitsandbytes.dequantize_4bit.out.default(
+            A,
+            absmax,
+            blocksize,
+            quant_type,
+            shape,
+            dtype,
+            out,
+        )
+        return
+
+    torch._check(out.shape == tuple(shape), lambda: f"Expected out.shape == {tuple(shape)}, got {out.shape}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+    _dequantize_nf4(A, absmax, blocksize, dtype, out)
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 2eb584a66..3387a17e8 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -283,7 +283,9 @@ def get_native_library() -> BNBNativeLibrary:
         binary_path = cuda_binary_path
 
-    if torch._C._has_xpu:
+    if BNB_BACKEND == "MPS":
+        binary_path = PACKAGE_DIR / f"libbitsandbytes_mps{DYNAMIC_LIBRARY_SUFFIX}"
+    elif torch._C._has_xpu:
         binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}"
 
     logger.debug(f"Loading bitsandbytes native library 
from: {binary_path}") @@ -306,6 +308,8 @@ def get_native_library() -> BNBNativeLibrary: BNB_BACKEND = "ROCm" elif torch.cuda.is_available(): BNB_BACKEND = "CUDA" +elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + BNB_BACKEND = "MPS" elif torch._C._has_xpu: BNB_BACKEND = "XPU" diff --git a/csrc/mps_kernels.metal b/csrc/mps_kernels.metal index 63b3bf78c..991cc84be 100644 --- a/csrc/mps_kernels.metal +++ b/csrc/mps_kernels.metal @@ -1,117 +1,291 @@ +#include #include +#include using namespace metal; -#define HLF_MAX 65504 -#define TH 1024 -#define NUM 4 -#define NUM_BLOCK 4096 - -template -static unsigned char quantize_scalar( - float rand, - device float* code, - float x) -{ - int pivot = 127; - int upper_pivot = 255; - int lower_pivot = 0; - - float lower = -1.0f; - float upper = 1.0f; - - float val = code[pivot]; - // i>>=1 = {32, 16, 8, 4, 2, 1} - for(int i = 64; i > 0; i>>=1) - { - if(x > val) - { - lower_pivot = pivot; - lower = val; - pivot+=i; +constant float nf4_dequant_lut[16] = { + -1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f}; + +inline uchar quantize_nf4_scalar(float x) { + if (x > 0.03979014977812767f) { + if (x > 0.3893125355243683f) { + if (x > 0.6427869200706482f) { + if (x > 0.8614784181118011f) { + return 0b1111; + } else { + return 0b1110; + } + } else { + if (x > 0.5016634166240692f) { + return 0b1101; + } else { + return 0b1100; + } + } + } else { + if (x > 0.2035212516784668f) { + if (x > 0.2920137718319893f) { + return 0b1011; + } else { + return 0b1010; + } + } else { + if (x > 0.1202552504837513f) { + return 0b1001; + } else { + return 0b1000; + } + } } - else - { - upper_pivot = pivot; - upper = val; - pivot-=i; + } else { + if (x > -0.33967943489551544f) { + if (x > -0.13791173323988914f) { + if (x > -0.045525018125772476f) { + return 0b0111; + } else { + return 0b0110; + } + } else { + if (x > -0.23460740596055984f) { + return 0b0101; + } else { + return 0b0100; + } + } + } else { + if (x > -0.6106329262256622f) { + if (x > -0.4599952697753906f) { + return 0b0011; + } else { + return 0b0010; + } + } else { + if (x > -0.8480964004993439f) { + return 0b0001; + } else { + return 0b0000; + } + } } - val = code[pivot]; } +} - if(upper_pivot == 255) - upper = code[upper_pivot]; - if(lower_pivot == 0) - lower = code[lower_pivot]; - - if(!STOCHASTIC) - { - if(x > val) - { - float midpoint = (upper+val)*0.5f; - if(x > midpoint) - { - return upper_pivot; - } - else - return pivot; - } - else - { - float midpoint = (lower+val)*0.5f; - if(x < midpoint) - return lower_pivot; - else - return pivot; - } - } - else - { - if(x > val) - { - float dist_to_upper = fabs(upper-x); - float dist_full = upper-val; - if(rand >= dist_to_upper/dist_full) return upper_pivot; - else return pivot; - } - else - { - float dist_to_lower = fabs(lower-x); - float dist_full = val-lower; - if(rand >= dist_to_lower/dist_full) return lower_pivot; - else return pivot; - } - } +inline float dequantize_nf4_scalar(uchar code) { + return nf4_dequant_lut[code & 0x0F]; } -kernel void quantize(device float* code [[buffer(0)]], - device float* A [[buffer(1)]], - device uchar* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint id [[thread_position_in_grid]]) { - const uint n_full = (NUM_BLOCK 
* (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); - uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK; - const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK); +template +inline float load_value(const device T* src, uint index) { + return static_cast(src[index]); +} - float vals[NUM]; - uchar qvals[NUM]; +template <> +inline float load_value(const device half* src, uint index) { + return static_cast(src[index]); +} - for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) { - valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; +template +inline void store_value(device T* dst, uint index, float value) { + dst[index] = static_cast(value); +} - threadgroup_barrier(mem_flags::mem_threadgroup); +template <> +inline void store_value(device half* dst, uint index, float value) { + dst[index] = static_cast(value); +} + +struct BlockParams { + uint n; + uint blocksize; + uint threads_per_group; +}; - for (uint j = 0; j < valid_items; j++) { - vals[j] = A[i + j]; +template +inline void quantize_nf4_impl( + const device T* input, + device float* absmax, + device uchar* output, + constant BlockParams& params, + threadgroup float* shared_vals, + threadgroup float& shared_scale, + threadgroup float& shared_absmax, + uint tid, + uint threads_per_group, + uint block_idx) { + + const uint block_start = block_idx * params.blocksize; + if (block_start >= params.n) { + return; } - for (uint j = 0; j < valid_items; j++) { - qvals[j] = quantize_scalar(0.0f, code, vals[j]); + const uint block_end = min(block_start + params.blocksize, params.n); + const uint block_length = block_end - block_start; + + float local_max = 0.0f; + for (uint idx = tid; idx < block_length; idx += threads_per_group) { + const uint global_idx = block_start + idx; + const float value = fabs(load_value(input, global_idx)); + local_max = fmax(local_max, value); + } + + shared_vals[tid] = local_max; + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint stride = threads_per_group >> 1; stride > 0; stride >>= 1) { + if (tid < stride) { + shared_vals[tid] = fmax(shared_vals[tid], shared_vals[tid + stride]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); } + if (tid == 0) { + shared_absmax = fmax(shared_vals[0], 0.0f); + shared_scale = shared_absmax > 0.0f ? 
1.0f / shared_absmax : 0.0f; + absmax[block_idx] = shared_absmax; + } threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint j = 0; j < valid_items; j++) { - out[i + j] = qvals[j]; + const float scale = shared_scale; + const uint pair_stride = threads_per_group * 2; + for (uint local_idx = tid * 2; local_idx < block_length; local_idx += pair_stride) { + const uint global_idx = block_start + local_idx; + const float v0 = load_value(input, global_idx) * scale; + float v1 = 0.0f; + if (local_idx + 1 < block_length) { + v1 = load_value(input, global_idx + 1) * scale; + } + uchar packed = quantize_nf4_scalar(v0) << 4; + packed |= quantize_nf4_scalar(v1); + const uint pair_index = global_idx >> 1; + output[pair_index] = packed; + } +} + +template +inline void dequantize_nf4_impl( + const device uchar* input, + const device float* absmax, + device T* output, + constant BlockParams& params, + uint tid, + uint threads_per_group, + uint block_idx) { + + const uint block_start = block_idx * params.blocksize; + if (block_start >= params.n) { + return; } - } + + const uint block_end = min(block_start + params.blocksize, params.n); + const uint block_length = block_end - block_start; + const float block_absmax = absmax[block_idx]; + + const uint total_pairs = (params.n + 1) >> 1; + const uint pair_start = block_start >> 1; + + for (uint local_pair = tid; local_pair < ((block_length + 1) >> 1); local_pair += threads_per_group) { + const uint pair_index = pair_start + local_pair; + if (pair_index >= total_pairs) { + continue; + } + + const uchar packed = input[pair_index]; + const float v0 = dequantize_nf4_scalar(packed >> 4) * block_absmax; + const float v1 = dequantize_nf4_scalar(packed & 0x0F) * block_absmax; + + const uint elem0 = block_start + local_pair * 2; + const uint elem1 = elem0 + 1; + + if (elem0 < params.n) { + store_value(output, elem0, v0); + } + if (elem1 < params.n) { + store_value(output, elem1, v1); + } + } +} + +kernel void quantize_nf4_f16( + const device half* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* output [[buffer(2)]], + constant BlockParams& params [[buffer(3)]], + uint3 tg_pos [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint3 tg_size [[threads_per_threadgroup]]) { + threadgroup float shared_vals[64]; + threadgroup float shared_scale; + threadgroup float shared_absmax; + quantize_nf4_impl( + input, + absmax, + output, + params, + shared_vals, + shared_scale, + shared_absmax, + tid, + tg_size.x, + tg_pos.x); +} + +kernel void quantize_nf4_f32( + const device float* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* output [[buffer(2)]], + constant BlockParams& params [[buffer(3)]], + uint3 tg_pos [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint3 tg_size [[threads_per_threadgroup]]) { + threadgroup float shared_vals[64]; + threadgroup float shared_scale; + threadgroup float shared_absmax; + quantize_nf4_impl( + input, + absmax, + output, + params, + shared_vals, + shared_scale, + shared_absmax, + tid, + tg_size.x, + tg_pos.x); +} + +kernel void dequantize_nf4_f16( + const device uchar* input [[buffer(0)]], + const device float* absmax [[buffer(1)]], + device half* output [[buffer(2)]], + constant BlockParams& params [[buffer(3)]], + uint3 tg_pos [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint3 tg_size [[threads_per_threadgroup]]) { + dequantize_nf4_impl(input, absmax, output, params, tid, tg_size.x, tg_pos.x); 
+} + +kernel void dequantize_nf4_f32( + const device uchar* input [[buffer(0)]], + const device float* absmax [[buffer(1)]], + device float* output [[buffer(2)]], + constant BlockParams& params [[buffer(3)]], + uint3 tg_pos [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint3 tg_size [[threads_per_threadgroup]]) { + dequantize_nf4_impl(input, absmax, output, params, tid, tg_size.x, tg_pos.x); } diff --git a/csrc/mps_ops.mm b/csrc/mps_ops.mm index 85ed1b1e4..be1ef28d8 100644 --- a/csrc/mps_ops.mm +++ b/csrc/mps_ops.mm @@ -1,9 +1,35 @@ +#import +#import #import -#define HLF_MAX 65504 -#define TH 1024 -#define NUM 4 -#define NUM_BLOCK 4096 +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +struct BlockParams { + uint32_t n; + uint32_t blocksize; + uint32_t threads_per_group; +}; + +struct BufferBinding { + id buffer = nil; + size_t offset = 0; + size_t length = 0; + const void* host_src = nullptr; + void* host_dst = nullptr; + bool owns_buffer = false; +}; static inline MPSGraph* get_graph() { static MPSGraph* cur = nil; @@ -14,49 +40,572 @@ } static inline id get_device() { - NSError* error = nil; - static id device = nil; - if (!device) { - device = MTLCreateSystemDefaultDevice(); + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + if (!stream) { + stream = at::mps::getDefaultMPSStream(); } - if (!device) { - NSLog(@"Failed to get MPS device"); + if (stream) { + return (id)stream->device(); + } + + static id fallback = nil; + static dispatch_once_t once; + dispatch_once(&once, ^{ + fallback = MTLCreateSystemDefaultDevice(); + }); + if (!fallback) { + NSLog(@"bitsandbytes: failed to acquire MPS device"); abort(); } - return device; + return fallback; +} + +static id get_library(); + +static NSString* get_metallib_path() { + static NSString* metallib_path = nil; + static dispatch_once_t once_token; + dispatch_once(&once_token, ^{ + Dl_info info; + if (dladdr(reinterpret_cast(&get_library), &info) && info.dli_fname) { + NSString* dylib_path = [NSString stringWithUTF8String:info.dli_fname]; + NSString* dylib_dir = [dylib_path stringByDeletingLastPathComponent]; + NSString* candidate = [dylib_dir stringByAppendingPathComponent:@"bitsandbytes.metallib"]; + if ([[NSFileManager defaultManager] fileExistsAtPath:candidate]) { + metallib_path = [candidate retain]; + return; + } + } + + PyGILState_STATE gil = PyGILState_Ensure(); + PyObject* module = PyImport_ImportModule("bitsandbytes"); + if (!module) { + PyErr_Clear(); + PyGILState_Release(gil); + return; + } + + PyObject* file_attr = PyObject_GetAttrString(module, "__file__"); + if (!file_attr) { + PyErr_Clear(); + Py_DECREF(module); + PyGILState_Release(gil); + return; + } + + const char* module_path_cstr = PyUnicode_AsUTF8(file_attr); + if (module_path_cstr) { + NSString* module_path = [NSString stringWithUTF8String:module_path_cstr]; + NSString* module_dir = [module_path stringByDeletingLastPathComponent]; + NSString* candidate = [module_dir stringByAppendingPathComponent:@"bitsandbytes.metallib"]; + + if ([[NSFileManager defaultManager] fileExistsAtPath:candidate]) { + metallib_path = [candidate retain]; + } + } else { + PyErr_Clear(); + } + + Py_DECREF(file_attr); + Py_DECREF(module); + PyGILState_Release(gil); + }); + return metallib_path; } static inline id get_library() { - NSError* error = nil; static id library = nil; - if (!library) { - library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] 
error:&error]; + static dispatch_once_t once; + dispatch_once(&once, ^{ + NSError* error = nil; + id device = get_device(); + NSString* metallib_path = get_metallib_path(); + NSURL* url = metallib_path ? [NSURL fileURLWithPath:metallib_path] + : [NSURL fileURLWithPath:@"bitsandbytes.metallib"]; + library = [device newLibraryWithURL:url error:&error]; + if (!library) { + NSLog(@"bitsandbytes: failed to load bitsandbytes.metallib (%@)", error); + abort(); + } + }); + return library; +} + +static id get_pipeline(NSString* function_name) { + static NSMutableDictionary>* pipelines = nil; + static dispatch_once_t once_token; + dispatch_once(&once_token, ^{ + pipelines = [[NSMutableDictionary alloc] init]; + }); + + id pipeline = pipelines[function_name]; + if (pipeline) { + return pipeline; + } + + NSError* error = nil; + id function = [get_library() newFunctionWithName:function_name]; + if (!function) { + NSLog(@"bitsandbytes: Metal function %@ not found", function_name); + abort(); } - if (!library) { - NSLog(@"Failed to load bitsandbytes.metallib"); + + pipeline = [get_device() newComputePipelineStateWithFunction:function error:&error]; + [function release]; + if (!pipeline) { + NSLog(@"bitsandbytes: failed to create pipeline for %@ (%@)", function_name, error); abort(); } - return library; + + pipelines[function_name] = pipeline; + return pipeline; } -/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) -{ - id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 -dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out; -}*/ - -// MPSGraph function for quantize -extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) { - id device = get_device(); - id library = get_library(); - static id kernel = nil; - if (!kernel) { - kernel = [library newFunctionWithName:@"quantize"]; - if (!kernel) { - NSLog(@"Failed to load bitsandbytes.metallib"); - abort(); +static NSUInteger preferred_threads(id pipeline) { + NSUInteger max_threads = pipeline.maxTotalThreadsPerThreadgroup; + if (max_threads == 0) { + return 64; + } + return std::min(64, max_threads); +} + +static bool dispatch_blockwise_kernel( + NSString* function_name, + BufferBinding& input, + BufferBinding& absmax, + BufferBinding& output, + uint32_t n, + uint32_t blocksize) { + + if (n == 0 || blocksize == 0) { + return true; + } + + id pipeline = get_pipeline(function_name); + const NSUInteger threads_per_group = preferred_threads(pipeline); + const uint32_t num_blocks = (blocksize > 0) ? 
(n + blocksize - 1) / blocksize : 0; + if (num_blocks == 0) { + return true; + } + + at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + if (!stream) { + stream = at::mps::getDefaultMPSStream(); + } + if (!stream) { + PyErr_SetString(PyExc_RuntimeError, "bitsandbytes: failed to acquire current MPS stream"); + return false; + } + + id device = (id)stream->device(); + if (!device) { + PyErr_SetString(PyExc_RuntimeError, "bitsandbytes: failed to acquire MPS device"); + return false; + } + + __block BufferBinding input_binding = input; + __block BufferBinding absmax_binding = absmax; + __block BufferBinding output_binding = output; + + __block bool success = true; + at::native::mps::dispatch_sync_with_rethrow(stream->queue(), ^(){ + @autoreleasepool { + auto prepare_buffer = ^(BufferBinding& binding) { + if (binding.length == 0) { + return; + } + if (!binding.buffer) { + binding.buffer = [device newBufferWithLength:binding.length options:MTLResourceStorageModeShared]; + binding.owns_buffer = true; + binding.offset = 0; + } + if (binding.host_src) { + std::memcpy((uint8_t*)binding.buffer.contents + binding.offset, binding.host_src, binding.length); + } + }; + + prepare_buffer(input_binding); + prepare_buffer(absmax_binding); + prepare_buffer(output_binding); + + id encoder = stream->commandEncoder(); + if (!encoder) { + PyErr_SetString(PyExc_RuntimeError, "bitsandbytes: failed to obtain command encoder"); + success = false; + return; + } + + [encoder setComputePipelineState:pipeline]; + + if (input_binding.buffer) { + [encoder setBuffer:input_binding.buffer offset:input_binding.offset atIndex:0]; + [encoder useResource:input_binding.buffer usage:MTLResourceUsageRead]; + } + if (absmax_binding.buffer) { + [encoder setBuffer:absmax_binding.buffer offset:absmax_binding.offset atIndex:1]; + [encoder useResource:absmax_binding.buffer usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + if (output_binding.buffer) { + [encoder setBuffer:output_binding.buffer offset:output_binding.offset atIndex:2]; + [encoder useResource:output_binding.buffer usage:MTLResourceUsageRead | MTLResourceUsageWrite]; + } + + BlockParams params = {n, blocksize, static_cast(threads_per_group)}; + [encoder setBytes:¶ms length:sizeof(BlockParams) atIndex:3]; + + MTLSize threads = MTLSizeMake(threads_per_group, 1, 1); + MTLSize threadgroups = MTLSizeMake(num_blocks, 1, 1); + [encoder dispatchThreadgroups:threadgroups threadsPerThreadgroup:threads]; } + }); + if (!success) { + return false; + } + + const bool needs_host_read = + (absmax_binding.host_dst && absmax_binding.length) || + (output_binding.host_dst && output_binding.length); + + stream->synchronize(needs_host_read ? 
at::mps::SyncType::COMMIT_AND_WAIT + : at::mps::SyncType::COMMIT); + + auto copy_back = [](BufferBinding& binding) { + if (binding.buffer && binding.host_dst && binding.length) { + std::memcpy(binding.host_dst, (uint8_t*)binding.buffer.contents + binding.offset, binding.length); + } + if (binding.owns_buffer && binding.buffer) { + [binding.buffer release]; + binding.buffer = nil; + } + }; + + copy_back(input_binding); + copy_back(absmax_binding); + copy_back(output_binding); + return true; +} + +static bool tensor_from_pyobject(PyObject* obj, const char* name, at::Tensor& tensor) { + if (!obj) { + PyErr_Format(PyExc_TypeError, "%s must not be null", name); + return false; + } + if (!THPVariable_Check(obj)) { + PyErr_Format(PyExc_TypeError, "%s must be a torch.Tensor", name); + return false; + } + tensor = THPVariable_Unpack(reinterpret_cast(obj)); + if (tensor.device().type() != c10::DeviceType::MPS) { + PyErr_Format(PyExc_RuntimeError, "%s must be an MPS tensor", name); + return false; + } + if (!tensor.is_contiguous()) { + PyErr_Format(PyExc_RuntimeError, "%s must be contiguous for MPS kernels", name); + return false; } - NSLog(@"Not implemented"); - return nil; + return true; } + +static bool binding_from_tensor(const at::Tensor& tensor, BufferBinding& binding) { + binding.buffer = at::native::mps::getMTLBufferStorage(tensor); + if (binding.buffer == nil) { + PyErr_SetString(PyExc_RuntimeError, "bitsandbytes: tensor does not have an associated MTLBuffer"); + return false; + } + binding.offset = static_cast(tensor.storage_offset()) * tensor.element_size(); + binding.length = static_cast(tensor.numel()) * tensor.element_size(); + binding.owns_buffer = false; + return true; +} + +static BufferBinding binding_from_host(const void* ptr, size_t length, bool copy_to_device, bool retrieve_from_device) { + BufferBinding binding; + binding.length = length; + if (copy_to_device) { + binding.host_src = ptr; + } + if (retrieve_from_device) { + binding.host_dst = const_cast(ptr); + } + return binding; +} + +} // namespace + +extern "C" { + +// Pointer-based entry points (used for CPU fallback / legacy paths) +void cquantize_blockwise_fp16_nf4( + float* /*code*/, + void* A, + float* absmax, + unsigned char* out, + int blocksize, + int n) { + if (!A || !absmax || !out || n <= 0 || blocksize <= 0) { + return; + } + const size_t absmax_blocks = static_cast((n + blocksize - 1) / blocksize); + BufferBinding input_binding = binding_from_host(A, static_cast(n) * sizeof(uint16_t), true, false); + BufferBinding absmax_binding = binding_from_host(absmax, absmax_blocks * sizeof(float), false, true); + BufferBinding output_binding = binding_from_host(out, static_cast((n + 1) / 2), false, true); + if (!dispatch_blockwise_kernel( + @"quantize_nf4_f16", + input_binding, + absmax_binding, + output_binding, + static_cast(n), + static_cast(blocksize))) { + return; + } +} + +void cquantize_blockwise_fp32_nf4( + float* /*code*/, + float* A, + float* absmax, + unsigned char* out, + int blocksize, + int n) { + if (!A || !absmax || !out || n <= 0 || blocksize <= 0) { + return; + } + const size_t absmax_blocks = static_cast((n + blocksize - 1) / blocksize); + BufferBinding input_binding = binding_from_host(A, static_cast(n) * sizeof(float), true, false); + BufferBinding absmax_binding = binding_from_host(absmax, absmax_blocks * sizeof(float), false, true); + BufferBinding output_binding = binding_from_host(out, static_cast((n + 1) / 2), false, true); + if (!dispatch_blockwise_kernel( + @"quantize_nf4_f32", + input_binding, + 
absmax_binding, + output_binding, + static_cast(n), + static_cast(blocksize))) { + return; + } +} + +void cdequantize_blockwise_fp16_nf4( + float* /*code*/, + unsigned char* A, + float* absmax, + void* out, + int blocksize, + int n) { + if (!A || !absmax || !out || n <= 0 || blocksize <= 0) { + return; + } + const size_t absmax_blocks = static_cast((n + blocksize - 1) / blocksize); + BufferBinding input_binding = binding_from_host(A, static_cast((n + 1) / 2), true, false); + BufferBinding absmax_binding = binding_from_host(absmax, absmax_blocks * sizeof(float), true, false); + BufferBinding output_binding = binding_from_host(out, static_cast(n) * sizeof(uint16_t), false, true); + if (!dispatch_blockwise_kernel( + @"dequantize_nf4_f16", + input_binding, + absmax_binding, + output_binding, + static_cast(n), + static_cast(blocksize))) { + return; + } +} + +void cdequantize_blockwise_fp32_nf4( + float* /*code*/, + unsigned char* A, + float* absmax, + float* out, + int blocksize, + int n) { + if (!A || !absmax || !out || n <= 0 || blocksize <= 0) { + return; + } + const size_t absmax_blocks = static_cast((n + blocksize - 1) / blocksize); + BufferBinding input_binding = binding_from_host(A, static_cast((n + 1) / 2), true, false); + BufferBinding absmax_binding = binding_from_host(absmax, absmax_blocks * sizeof(float), true, false); + BufferBinding output_binding = binding_from_host(out, static_cast(n) * sizeof(float), false, true); + if (!dispatch_blockwise_kernel( + @"dequantize_nf4_f32", + input_binding, + absmax_binding, + output_binding, + static_cast(n), + static_cast(blocksize))) { + return; + } +} + +// Tensor-aware entry points (used from Python to avoid extra copies) +void cquantize_blockwise_fp16_nf4_tensor(PyObject* A_obj, PyObject* absmax_obj, PyObject* out_obj, int blocksize) { + at::Tensor A; + at::Tensor absmax; + at::Tensor out; + if (!tensor_from_pyobject(A_obj, "A", A) || + !tensor_from_pyobject(absmax_obj, "absmax", absmax) || + !tensor_from_pyobject(out_obj, "out", out)) { + return; + } + + if (A.scalar_type() != at::kHalf) { + PyErr_SetString(PyExc_TypeError, "A must be float16 for NF4 quantization"); + return; + } + if (absmax.scalar_type() != at::kFloat) { + PyErr_SetString(PyExc_TypeError, "absmax must be float32"); + return; + } + if (out.scalar_type() != at::kByte) { + PyErr_SetString(PyExc_TypeError, "out must be uint8"); + return; + } + + BufferBinding input_binding; + BufferBinding absmax_binding; + BufferBinding output_binding; + if (!binding_from_tensor(A, input_binding) || + !binding_from_tensor(absmax, absmax_binding) || + !binding_from_tensor(out, output_binding)) { + return; + } + + if (!dispatch_blockwise_kernel( + @"quantize_nf4_f16", + input_binding, + absmax_binding, + output_binding, + static_cast(A.numel()), + static_cast(blocksize))) { + return; + } +} + +void cquantize_blockwise_fp32_nf4_tensor(PyObject* A_obj, PyObject* absmax_obj, PyObject* out_obj, int blocksize) { + at::Tensor A; + at::Tensor absmax; + at::Tensor out; + if (!tensor_from_pyobject(A_obj, "A", A) || + !tensor_from_pyobject(absmax_obj, "absmax", absmax) || + !tensor_from_pyobject(out_obj, "out", out)) { + return; + } + + if (A.scalar_type() != at::kFloat) { + PyErr_SetString(PyExc_TypeError, "A must be float32 for NF4 quantization"); + return; + } + if (absmax.scalar_type() != at::kFloat) { + PyErr_SetString(PyExc_TypeError, "absmax must be float32"); + return; + } + if (out.scalar_type() != at::kByte) { + PyErr_SetString(PyExc_TypeError, "out must be uint8"); + return; + } + + 
BufferBinding input_binding; + BufferBinding absmax_binding; + BufferBinding output_binding; + if (!binding_from_tensor(A, input_binding) || + !binding_from_tensor(absmax, absmax_binding) || + !binding_from_tensor(out, output_binding)) { + return; + } + + if (!dispatch_blockwise_kernel( + @"quantize_nf4_f32", + input_binding, + absmax_binding, + output_binding, + static_cast(A.numel()), + static_cast(blocksize))) { + return; + } +} + +void cdequantize_blockwise_fp16_nf4_tensor(PyObject* A_obj, PyObject* absmax_obj, PyObject* out_obj, int blocksize) { + at::Tensor A; + at::Tensor absmax; + at::Tensor out; + if (!tensor_from_pyobject(A_obj, "A", A) || + !tensor_from_pyobject(absmax_obj, "absmax", absmax) || + !tensor_from_pyobject(out_obj, "out", out)) { + return; + } + + if (A.scalar_type() != at::kByte) { + PyErr_SetString(PyExc_TypeError, "A must be uint8 for NF4 dequantization"); + return; + } + if (absmax.scalar_type() != at::kFloat) { + PyErr_SetString(PyExc_TypeError, "absmax must be float32"); + return; + } + if (out.scalar_type() != at::kHalf) { + PyErr_SetString(PyExc_TypeError, "out must be float16"); + return; + } + + BufferBinding input_binding; + BufferBinding absmax_binding; + BufferBinding output_binding; + if (!binding_from_tensor(A, input_binding) || + !binding_from_tensor(absmax, absmax_binding) || + !binding_from_tensor(out, output_binding)) { + return; + } + + if (!dispatch_blockwise_kernel( + @"dequantize_nf4_f16", + input_binding, + absmax_binding, + output_binding, + static_cast(out.numel()), + static_cast(blocksize))) { + return; + } +} + +void cdequantize_blockwise_fp32_nf4_tensor(PyObject* A_obj, PyObject* absmax_obj, PyObject* out_obj, int blocksize) { + at::Tensor A; + at::Tensor absmax; + at::Tensor out; + if (!tensor_from_pyobject(A_obj, "A", A) || + !tensor_from_pyobject(absmax_obj, "absmax", absmax) || + !tensor_from_pyobject(out_obj, "out", out)) { + return; + } + + if (A.scalar_type() != at::kByte) { + PyErr_SetString(PyExc_TypeError, "A must be uint8 for NF4 dequantization"); + return; + } + if (absmax.scalar_type() != at::kFloat) { + PyErr_SetString(PyExc_TypeError, "absmax must be float32"); + return; + } + if (out.scalar_type() != at::kFloat) { + PyErr_SetString(PyExc_TypeError, "out must be float32"); + return; + } + + BufferBinding input_binding; + BufferBinding absmax_binding; + BufferBinding output_binding; + if (!binding_from_tensor(A, input_binding) || + !binding_from_tensor(absmax, absmax_binding) || + !binding_from_tensor(out, output_binding)) { + return; + } + + if (!dispatch_blockwise_kernel( + @"dequantize_nf4_f32", + input_binding, + absmax_binding, + output_binding, + static_cast(out.numel()), + static_cast(blocksize))) { + return; + } +} + +} // extern "C" diff --git a/pyproject.toml b/pyproject.toml index 748b77d90..f55b8ea31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ test = [ ] [tool.setuptools] -package-data = { "*" = ["libbitsandbytes*.*", "py.typed"] } +package-data = { "*" = ["libbitsandbytes*.*", "*.metallib", "py.typed"] } [tool.setuptools.packages.find] include = ["bitsandbytes*"] diff --git a/scripts/mps_nf4_test.py b/scripts/mps_nf4_test.py new file mode 100644 index 000000000..795c59bcb --- /dev/null +++ b/scripts/mps_nf4_test.py @@ -0,0 +1,6 @@ +import torch +import bitsandbytes as bnb +A = torch.randn(128, device='mps', dtype=torch.float16) +q, state = bnb.functional.quantize_4bit(A, quant_type='nf4') +A2 = bnb.functional.dequantize_4bit(q, quant_state=state) +print('diff', float((A2 - 
A).abs().mean().cpu()), flush=True)
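
For a slightly broader check than `scripts/mps_nf4_test.py`, the sketch below exercises the same public `bnb.functional.quantize_4bit` / `dequantize_4bit` calls used in that script across both dtypes and a few of the block sizes the MPS backend accepts (`_SUPPORTED_DTYPES` and `_ALLOWED_BLOCKS` in `bitsandbytes/backends/mps/ops.py`). It is not part of the patch; the tensor size and the chosen block sizes are illustrative.

```python
# Illustrative NF4 round-trip check on MPS (not part of this diff).
# Assumes an MPS-enabled PyTorch and a bitsandbytes build with the backend above.
import torch

import bitsandbytes as bnb

assert torch.backends.mps.is_available(), "requires an Apple Silicon / MPS-enabled PyTorch"

for dtype in (torch.float16, torch.float32):   # dtypes handled by the MPS NF4 kernels
    for blocksize in (64, 256, 1024):          # subset of the allowed block sizes
        A = torch.randn(4096, device="mps", dtype=dtype)
        q, state = bnb.functional.quantize_4bit(A, blocksize=blocksize, quant_type="nf4")
        A2 = bnb.functional.dequantize_4bit(q, quant_state=state)
        # NF4 is 4-bit blockwise quantization, so a nonzero reconstruction error is expected.
        err = (A2 - A).abs().mean().item()
        print(f"{dtype} blocksize={blocksize}: packed {tuple(q.shape)}, mean abs err {err:.4f}")
```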