From c9fec32d6ef2d53a0d02f5b0ca39d87e184e4447 Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Thu, 15 May 2025 22:50:34 +0000 Subject: [PATCH 01/98] Port ROCm changes from multi-backend-refactor branch --- CMakeLists.txt | 80 +- bitsandbytes/cextension.py | 82 +- bitsandbytes/diagnostics/cuda.py | 70 +- bitsandbytes/diagnostics/main.py | 26 +- csrc/common_hip.cuh | 7 + csrc/kernels.hip | 3253 ++++++++++++++++++++++++++++++ csrc/kernels_hip.cuh | 132 ++ csrc/ops.hip | 836 ++++++++ csrc/ops_hip.cuh | 195 ++ csrc/pythonInterface.cpp | 22 +- 10 files changed, 4654 insertions(+), 49 deletions(-) create mode 100644 csrc/common_hip.cuh create mode 100644 csrc/kernels.hip create mode 100644 csrc/kernels_hip.cuh create mode 100644 csrc/ops.hip create mode 100644 csrc/ops_hip.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b462c45d..8a7583279 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,13 +25,14 @@ endif() # Define included source files set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) +set(HIP_FILES csrc/ops.hip csrc/kernels.hip) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)") +set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -47,15 +48,25 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") message(FATAL_ERROR "CUDA is not supported on macOS" ) endif() set(BUILD_CUDA ON) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) +elseif(${COMPUTE_BACKEND} STREQUAL "hip") + if(APPLE) + message(FATAL_ERROR "HIP is not supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + 
set(BUILD_HIP ON) set(BUILD_MPS OFF) elseif(${COMPUTE_BACKEND} STREQUAL "mps") if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) endif() set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) set(BUILD_MPS ON) else() set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) set(BUILD_MPS OFF) endif() @@ -160,6 +171,36 @@ if(BUILD_CUDA) string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") add_compile_definitions(BUILD_CUDA) +elseif(BUILD_HIP) + enable_language(HIP) + message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}") + if(DEFINED BNB_ROCM_ARCH) + set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH}) + else() + if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100") + elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) + endif() + endif() + message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}") + + list(APPEND SRC_FILES ${HIP_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_rocm") + + # get hip version + execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION) + string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}") + string(REPLACE "." 
"" HIP_VERSION_SHORT "${HIP_VERSION}") + + string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}") + if(HIP_VERSION VERSION_LESS "6.1") + string(APPEND BNB_OUTPUT_NAME "_nohipblaslt") + endif() + add_compile_definitions(__HIP_PLATFORM_AMD__) + add_compile_definitions(__HIP_PLATFORM_HCC__) + add_compile_definitions(BUILD_HIP) elseif(BUILD_MPS) if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) @@ -208,6 +249,41 @@ if(BUILD_CUDA) CUDA_SEPARABLE_COMPILATION ON ) endif() +if(BUILD_HIP) + if(NOT DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH /opt/rocm) + else() + set(ROCM_PATH $ENV{ROCM_PATH}) + endif() + list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + macro(find_package_and_print_version PACKAGE_NAME) + find_package("${PACKAGE_NAME}" ${ARGN}) + message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") + endmacro() + find_package_and_print_version(hipblas REQUIRED) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(hipsparse REQUIRED) + + ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. 
adam8bit because of inaccuracies) + set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "") + + target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) + target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib) + target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse) + + target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP) + set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP) + set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) + + if(HIP_VERSION VERSION_LESS "6.1") + target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT) + else() + find_package(hipblaslt) + target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt) + endif() +endif() if(BUILD_MPS) add_dependencies(bitsandbytes metallib) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 3fb8db26f..c8b02fb22 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -22,11 +22,17 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: """ prefix = "rocm" if torch.version.hip else "cuda" - library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}" + blas_suffix = "_nohipblaslt" if torch.version.hip and cuda_specs.cuda_version_tuple < (6, 1) else "" + library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{blas_suffix}{DYNAMIC_LIBRARY_SUFFIX}" override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) + if torch.version.hip: + raise RuntimeError( + 
f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n" + f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" + ) logger.warning( f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" @@ -72,10 +78,11 @@ def __init__(self, lib: ct.CDLL): def get_available_cuda_binary_versions() -> list[str]: """Get formatted CUDA versions from existing library files using cuda_specs logic""" - lib_pattern = f"libbitsandbytes_cuda*{DYNAMIC_LIBRARY_SUFFIX}" + lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}" versions = [] for lib in Path(__file__).parent.glob(lib_pattern): - match = re.search(r"cuda(\d{3})", lib.name) + pattern = r"{}(\d+)".format(BNB_BACKEND.lower()) + match = re.search(pattern, lib.name) if match: ver_code = int(match.group(1)) major = ver_code // 10 @@ -86,8 +93,8 @@ def get_available_cuda_binary_versions() -> list[str]: def parse_cuda_version(version_str: str) -> str: """Convert raw version string (e.g. '118' from env var) to formatted version (e.g. 
'11.8')""" - if version_str.isdigit() and len(version_str) == 3: - return f"{version_str[:2]}.{version_str[2]}" + if version_str.isdigit(): + return f"{version_str[:-1]}.{version_str[-1]}" return version_str # fallback as safety net @@ -148,7 +155,7 @@ def _format_lib_error_message( """Format detailed error message for library loading failures""" analysis = "" no_cpu_lib_found = "libbitsandbytes_cpu.so: cannot open" in original_error - no_cuda_lib_found = "CUDA binary not found" in original_error + no_cuda_lib_found = f"{BNB_BACKEND} binary not found" in original_error if no_cpu_lib_found: analysis = "\n🚨 Failed to load CPU-only bitsandbytes library 🚨\n\n" @@ -157,9 +164,9 @@ def _format_lib_error_message( version_list_str = "\n - " + "\n - ".join(available_versions) if available_versions else "NONE" analysis = ( ( - f"\n🚨 CUDA VERSION MISMATCH 🚨\n" - f"Requested CUDA version: {requested_version}\n" - f"Detected PyTorch CUDA version: {user_cuda_version}\n" + f"\n🚨 {BNB_BACKEND} VERSION MISMATCH 🚨\n" + f"Requested {BNB_BACKEND} version: {requested_version}\n" + f"Detected PyTorch {BNB_BACKEND} version: {user_cuda_version}\n" f"Available pre-compiled versions: {version_list_str}\n\n" "This means:\n" "The version you're trying to use is NOT distributed with this package\n\n" @@ -174,42 +181,49 @@ def _format_lib_error_message( troubleshooting = ( ( - "This typically happens when:\n" - "1. bitsandbytes doesn't ship with a pre-compiled binary for your CUDA version\n" - "2. The library wasn't compiled properly during installation from source\n\n" + f"This typically happens when:\n" + f"1. bitsandbytes doesn't ship with a pre-compiled binary for your {BNB_BACKEND} version\n" + f"2. 
The library wasn't compiled properly during installation from source\n\n" ) if no_cuda_lib_found - else "This typically happens when you checked the code out from source and your torch installation doesn't detect CUDA on your machine.\n\n" + else f"This typically happens when you checked the code out from source and your torch installation doesn't detect {BNB_BACKEND} on your machine.\n\n" ) note = ( ( - "To make bitsandbytes work, the compiled library version MUST exactly match the linked CUDA version.\n" - "If your CUDA version doesn't have a pre-compiled binary, you MUST compile from source.\n\n" + f"To make bitsandbytes work, the compiled library version MUST exactly match the linked {BNB_BACKEND} version.\n" + f"If your {BNB_BACKEND} version doesn't have a pre-compiled binary, you MUST compile from source.\n\n" ) if no_cuda_lib_found else "" ) compile_instructions = ( + ( + "COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n" + ) if not no_cuda_lib_found + else ( "You have two options:\n" "1. COMPILE FROM SOURCE (required if no binary exists):\n" " https://huggingface.co/docs/bitsandbytes/main/en/installation#cuda-compile\n" "2. Use BNB_CUDA_VERSION to specify a DIFFERENT CUDA version from the detected one, which is installed on your machine and matching an available pre-compiled version listed above\n\n" + ) if not HIP_ENVIRONMENT + else + ( + "You can COMPILE FROM SOURCE as mentioned here:\n" + " https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n" ) - if no_cuda_lib_found - else "COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n" ) diagnostics = ( - "šŸ” Run this command for detailed diagnostics:\n" - "python -m bitsandbytes\n\n" - "If you've tried everything and still have issues:\n" - "1. Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n" - "2. Describe what you've tried in detail\n" - "3. 
Open an issue with this information:\n" - " https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n" + f"šŸ” Run this command for detailed diagnostics:\n" + f"python -m bitsandbytes\n\n" + f"If you've tried everything and still have issues:\n" + f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n" + f"2. Describe what you've tried in detail\n" + f"3. Open an issue with this information:\n" + f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n" ) return f"{analysis}{base_msg}{troubleshooting}{note}{compile_instructions}{original_error}\n{diagnostics}" @@ -224,18 +238,19 @@ def _format_dependency_error(self) -> str: ) return ( - f"\n🚨 CUDA SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n" - f"CUDA {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n" + f"\n🚨 {BNB_BACKEND} SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n" + f"{BNB_BACKEND} {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n" f"To fix this, make sure that:\n" - f"1. You have installed CUDA {cuda_major_version}.x toolkit on your system\n" - f"2. The CUDA runtime libraries are in your LD_LIBRARY_PATH\n\n" + f"1. You have installed {BNB_BACKEND} {cuda_major_version}.x toolkit on your system\n" + f"2. The {BNB_BACKEND} runtime libraries are in your LD_LIBRARY_PATH\n\n" f"You can add them with (and persist the change by adding the line to your .bashrc):\n" - f" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/cuda-{cuda_major_version}.x/lib64\n\n" + f" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/{BNB_BACKEND.lower()}-{cuda_major_version}.x/\ + {'lib64' if not HIP_ENVIRONMENT else 'lib'}\n\n" f"Original error: {self.error_msg}\n\n" f"šŸ” Run this command for detailed diagnostics:\n" f"python -m bitsandbytes\n\n" f"If you've tried everything and still have issues:\n" - f"1. 
Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n" + f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n" f"2. Describe what you've tried in detail\n" f"3. Open an issue with this information:\n" f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n" @@ -264,7 +279,7 @@ def get_native_library() -> BNBNativeLibrary: cuda_binary_path = get_cuda_bnb_library_path(cuda_specs) if not cuda_binary_path.exists(): - raise RuntimeError(f"Configured CUDA binary not found at {cuda_binary_path}") + raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}") binary_path = cuda_binary_path @@ -284,6 +299,11 @@ def get_native_library() -> BNBNativeLibrary: try: + if torch.version.hip: + HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" + else: + HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" + lib = get_native_library() except Exception as e: error_msg = str(e) diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index affcb0ae6..b9de27fd7 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -5,7 +5,7 @@ import torch -from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path from bitsandbytes.consts import NONPYTORCH_DOC_URL from bitsandbytes.cuda_specs import CUDASpecs from bitsandbytes.diagnostics.utils import print_dedented @@ -33,6 +33,8 @@ } CUDA_RUNTIME_LIB_PATTERNS = ( + "libamdhip64.so*", +) if HIP_ENVIRONMENT else ( "cudart64*.dll", # Windows "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. 
"nvcuda*.dll", # Windows @@ -57,7 +59,7 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path pass for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS: for pth in dir.glob(lib_pattern): - if pth.is_file(): + if pth.is_file() and not pth.is_symlink(): yield pth except (OSError, PermissionError): pass @@ -104,7 +106,7 @@ def find_cudart_libraries() -> Iterator[Path]: yield from find_cuda_libraries_in_path_list(value) -def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: +def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: print( f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", @@ -149,7 +151,37 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: # (2) Multiple CUDA versions installed -def print_cuda_runtime_diagnostics() -> None: +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") + + binary_path = get_cuda_bnb_library_path(cuda_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. + Maybe you need to compile it from source? If you compiled from source, check that ROCm version + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version + and rebuild bitsandbytes. + """, + ) + + hip_major, hip_minor = cuda_specs.cuda_version_tuple + if (hip_major, hip_minor) < (6, 1): + print_dedented( + """ + WARNING: bitsandbytes is fully supported only from ROCm 6.1. + """, + ) + + +def print_diagnostics(cuda_specs: CUDASpecs) -> None: + if HIP_ENVIRONMENT: + _print_hip_diagnostics(cuda_specs) + else: + _print_cuda_diagnostics(cuda_specs) + + +def _print_cuda_runtime_diagnostics() -> None: cudart_paths = list(find_cudart_libraries()) if not cudart_paths: print("CUDA SETUP: WARNING! 
CUDA runtime files not found in any environmental path.") @@ -174,3 +206,33 @@ def print_cuda_runtime_diagnostics() -> None: ) for pth in cudart_paths: print(f"* Found CUDA runtime at: {pth}") + + +def _print_hip_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("WARNING! ROCm runtime files not found in any environmental path.") + elif len(cudart_paths) > 1: + print_dedented( + f""" + Found duplicate ROCm runtime files (see below). + + We select the PyTorch default ROCm runtime, which is {torch.version.hip}, + but this might mismatch with the ROCm version that is needed for bitsandbytes. + + To resolve it, install PyTorch built for the ROCm version you want to use + + and set LD_LIBRARY_PATH to your ROCm install path, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib, + """, + ) + + for pth in cudart_paths: + print(f"* Found ROCm runtime at: {pth}") + + +def print_runtime_diagnostics() -> None: + if HIP_ENVIRONMENT: + _print_hip_runtime_diagnostics() + else: + _print_cuda_runtime_diagnostics() diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index b6236d668..8e2bc2a7b 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -3,11 +3,12 @@ import torch +from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT from bitsandbytes.consts import PACKAGE_GITHUB_URL from bitsandbytes.cuda_specs import get_cuda_specs from bitsandbytes.diagnostics.cuda import ( - print_cuda_diagnostics, - print_cuda_runtime_diagnostics, + print_diagnostics, + print_runtime_diagnostics, ) from bitsandbytes.diagnostics.utils import print_dedented, print_header @@ -34,19 +35,24 @@ def main(): print_header("OTHER") cuda_specs = get_cuda_specs() - print("CUDA specs:", cuda_specs) + if HIP_ENVIRONMENT: + rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," + rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" + 
print(f"{BNB_BACKEND} specs:{rocm_specs}") + else: + print(f"{BNB_BACKEND} specs:{cuda_specs}") if not torch.cuda.is_available(): - print("Torch says CUDA is not available. Possible reasons:") - print("1. CUDA driver not installed") - print("2. CUDA not installed") - print("3. You have multiple conflicting CUDA libraries") + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") + if not HIP_ENVIRONMENT: print(f"- {BNB_BACKEND} driver not installed") + print(f"- {BNB_BACKEND} not installed") + print(f"- You have multiple conflicting {BNB_BACKEND} libraries") if cuda_specs: - print_cuda_diagnostics(cuda_specs) - print_cuda_runtime_diagnostics() + print_diagnostics(cuda_specs) + print_runtime_diagnostics() print_header("") print_header("DEBUG INFO END") print_header("") - print("Checking that the library is importable and CUDA is callable...") + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") try: sanity_check() print("SUCCESS!") diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh new file mode 100644 index 000000000..e7fc4eb81 --- /dev/null +++ b/csrc/common_hip.cuh @@ -0,0 +1,7 @@ +#pragma once + +#define BNB_WARP_SIZE warpSize + +// These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs +#define BNB_MAX_THREADS_PER_SM 2048 +#define BNB_BF16_AVAILABLE true diff --git a/csrc/kernels.hip b/csrc/kernels.hip new file mode 100644 index 000000000..368788f39 --- /dev/null +++ b/csrc/kernels.hip @@ -0,0 +1,3253 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "kernels_hip.cuh" +#include "common_hip.cuh" +#include +#include +#include + +//#include + + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; + +// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +// Luckily we have atomicmax and atomicmin in ROCm + + +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 
0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + + +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 
0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; + else + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; + else + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; + else + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; +} +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template __device__ int sgn(T val) +{ + return (T(0) < val) - (val < T(0)); +} + +template +__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = smem_code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(upper_pivot == 255) + upper = smem_code[upper_pivot]; + if(lower_pivot == 0) + lower = smem_code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabsf(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return 
pivot; + } + else + { + float dist_to_lower = fabsf(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +template +__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = SIGNED ? -1.0f : 0.0f; + float upper = 1.0f; + float midpoint; + float val = quadrants[1]; + int local_pivot = 1; + int offset = 1; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + //val = i == 64 ? quadrants[2] : smem_code[pivot]; + local_pivot += offset; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + //val = i == 64 ? quadrants[0] : smem_code[pivot]; + local_pivot -= offset; + } + val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; + offset -= 1; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) +{ + const int tid = threadIdx.x + (blockDim.x*blockIdx.x); + const int numThreads = blockDim.x*gridDim.x; + + for(int i = tid; i < n; i+=numThreads) + { + int idx = (index1[i]*maxidx1) + index2[i]; + atomicAdd(&histogram[idx], src[i]); + } +} + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 + +template +__launch_bounds__(THREADS_ESTIMATE, 1) +__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) +{ + const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 
0 : BLOCK_ESTIMATE); + int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; + const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); + const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); + + T vals[NUM_ESTIMATE]; + + typedef hipcub::BlockRadixSort BlockRadixSort; + typedef hipcub::BlockLoad LoadFloat; + + __shared__ union { + typename LoadFloat::TempStorage loadf; + typename BlockRadixSort::TempStorage sort; + int smem_qidx[BLOCK_ESTIMATE]; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) + { + valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; + + // do not process half-blocks + if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = max_val; + + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = ((float)vals[j]) * reciprocal_num_blocks; + + + __syncthreads(); + // sort into striped pattern to mitigate bank conflicts + // striped pattern index for thread 0 [0, 1024, 2048, 3096] + // striped pattern index for thread 1 [1, 1025, 2049, 3097] + BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); + + __syncthreads(); + for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) + temp_storage.smem_qidx[j] = -1; + + __syncthreads(); + + if(threadIdx.x < 256) + { + float q_interval = (1.0f-(2.0f*offset))/255.0f; + int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); + temp_storage.smem_qidx[local_idx] = threadIdx.x; + } + + __syncthreads(); + + for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) + { + if(temp_storage.smem_qidx[i] != -1) + atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); + } + } +} + + +__launch_bounds__(TH, 4) +__global__ void kQuantize(float * code, float * __restrict__ const 
A, unsigned char *out, const int n) +{ + const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK; + const int base_idx = (blockIdx.x * NUM_BLOCK); + + float vals[NUM]; + unsigned char qvals[NUM]; + //const int lane_id = threadIdx.x % 2; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreChar; + + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ float smem_code[256]; + //__shared__ float smem_code[2][257]; + + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + //smem_code[0][threadIdx.x] = code[threadIdx.x]; + //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK) + { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + + __syncthreads(); + LoadFloat(loadf).Load(&(A[i]), vals, valid_items); + + + #pragma unroll 4 + for(int j = 0; j < NUM; j++) + qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); + + __syncthreads(); + StoreChar(storec).Store(&(out[i]), qvals, valid_items); + } +} + +template +//__launch_bounds__(TH, 4) +__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) +{ + const int n_full = gridDim.x * BLOCK_SIZE; + int valid_items = 0; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; + //float local_abs_max = -FLT_MAX; + float local_abs_max = 0.0f; + int local_rand_idx = 0; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockStore 0) ? 
NUM_PER_TH/2 : NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadFloat; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename BlockReduce::TempStorage reduce; + __shared__ float smem_code[256]; + __shared__ float smem_absmax_value[1]; + + if(DATA_TYPE == General8bit) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; + + for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_abs_max = -FLT_MAX; + + __syncthreads(); + LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + + // 1. compute local max + // 2. broadcast local max + // 3. normalize inputs and quantize + + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items); + + if(threadIdx.x == 0) { + smem_absmax_value[0] = 1.0f / local_abs_max; + absmax[i / BLOCK_SIZE] = local_abs_max; + } + __syncthreads(); + + local_abs_max = smem_absmax_value[0]; + + if(STOCHASTIC) + { + local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= 
dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + } + + __syncthreads(); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); + } +} + +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) +{ + + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + typedef hipcub::BlockLoad LoadChar; + typedef hipcub::BlockStore 0) ? 2 : 1), hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) + { + if (DATA_TYPE > 0) + { + valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i); + valid_items_store = min(TILE_SIZE * 2, n - i * 2); + } + else + { + valid_items_load = min(TILE_SIZE, n - i); + valid_items_store = valid_items_load; + } + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. 
+ local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); + + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); + + switch (DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } + + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); + } +} + +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) +{ + const unsigned int numThreads = blockDim.x * gridDim.x; + const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; + + __shared__ float smem_code[256]; + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + } + + __syncthreads(); + + for (int i = idx;i < n; i += numThreads) + { + out[i] = smem_code[A[i]]; + } +} + + + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 
0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f/(1.0f - powf(beta1, step)); + const float correction2 = 1.0f/(1.0f - powf(beta2, step)); + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case ADAM: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + + # pragma unroll NUM_VALS-1 + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + + + +#define NUM_PER_THREAD 4 + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit2State(T* g, 
T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + // AdEMAMix has an additional state buffer, which we packed + // into state1. We need thread-local storage here for these. + // TODO: Mark with [[maybe_unused]] after upgrade to min compiler. + float s3_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + // Load additional state1 data for AdEMAMix + // TODO: Make constexpr after updating min compiler + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + switch(OPTIMIZER) + { + case ADEMAMIX: + // m1 update: m1 = beta1 * m1 + (1-beta1) * g + s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]); + + // m2 update: m2 = m2 * beta3 + (1-beta3) * g + s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]); + + // nu update: nu = beta2 * nu + (1-beta2) * g^2 + s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]); + + p_vals[j] = (float)p_vals[j] - lr * ( + ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( + (sqrtf(s2_vals[j]) / correction2) + eps + ) + ); + + if (weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); + + break; + case ADAM: + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + break; + } + } + + __syncthreads(); + 
Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items); + } + } +} + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + } + } + + # pragma unroll + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) 
== 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + g_vals[j] = gnorm_scale*((float)g_vals[j]); + if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + 
((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); + break; + } + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} + + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? 
NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + if(threadIdx.x < 256) + { + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + __syncthreads(); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; + s1_vals[j] += (1.0f-beta1)*g_val; + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; + s2_vals[j] += (1.0f-beta2)*g_val*g_val; + local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); + } + + if(unorm != NULL) + { + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); + float correction2 = __fdividef(1.0f, 
1.0f - powf(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + local_unorm += update_val*update_val; + } + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + __syncthreads(); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, hipcub::Max(), valid_items); + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + } + + if(threadIdx.x == 0) + { + atomicMax(&new_max1[0], local_max_s1); + atomicMax(&new_max2[0], local_max_s2); + if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS2, 1) +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + //const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val2 = 1.0f/new_max2[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + 
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 512) + { + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + else + smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j]*max1[0]; + + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j]*max2[0]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned 
char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; + switch(OPTIMIZER) + { + case ADAGRAD: + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + if(unorm != NULL) + local_unorm += s1_vals[j]*s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } + } + +} + +template +__global__ void +__launch_bounds__(1024, 1) +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f/new_max1[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = 
max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case ADAGRAD: + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + + switch(OPTIMIZER){ + case ADAGRAD: + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + break; + case LION: + p_vals[j] = 
((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n) +{ + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadT; + + __shared__ typename BlockReduce::TempStorage reduce; + + __shared__ typename LoadT::TempStorage loadT; + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + local_sum = 0.0f; + + __syncthreads(); + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + + #pragma unroll NUM_VALS + for(int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j])*((float)vals[j]); + + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if(threadIdx.x == 0) + { + if(step == 1) + { + // initialize with the same norm for all positions + //#pragma unroll 10 + for(int j = 0; j < 100; j++) + atomicAdd(&gnorm_vec[j], local_sum); + } + else + atomicAdd(&gnorm_vec[step % 100], local_sum); + } + + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit2StateBlockwise( + T* p, + T* __restrict__ const g, + unsigned char* state1, + unsigned char* state2, + const float beta1, + const float beta2, + const float beta3, + const float alpha, + const float eps, + const int step, + const float lr, + float* __restrict__ const quantiles1, + float* __restrict__ const quantiles2, + float* absmax1, + float* absmax2, + float weight_decay, + const float gnorm_scale, + const bool skip_zeros, + const int n +) { + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + float s3_vals[N_PER_TH]; + + // 2-5% + const float correction1 = 1.0f - __powf(beta1, step); + const float correction2 = sqrtf(1.0f -__powf(beta2, step)); + const float step_size = __fdividef(-lr*correction2,correction1); + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float new_local_abs_max3 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + unsigned char c3s[N_PER_TH]; + + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + typedef hipcub::BlockLoad LoadT; + typedef 
hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + __shared__ float smem_quantiles2[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + typedef hipcub::BlockReduce BlockReduce2; + typedef hipcub::BlockReduce BlockReduce3; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ typename BlockReduce2::TempStorage reduce2; + __shared__ typename BlockReduce2::TempStorage reduce3; + __shared__ float smem_exchange1[1]; + __shared__ float smem_exchange2[1]; + __shared__ float smem_exchange3[1]; // [[maybe_unused]] + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + { + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; + } + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + { + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + + // AdEMAMix has an additional state packed into state1. 
+ if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128); + } + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + new_local_abs_max3 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + if (OPTIMIZER == ADEMAMIX) { + // The absmax for the third state is appended to absmax1 + s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i)/BLOCK_SIZE]; + s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val)); + } + } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + + if (OPTIMIZER == ADEMAMIX) { + s3_vals[j] = 0.0f; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); + } + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max()); + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, hipcub::Max()); + } + + if(threadIdx.x == 0) + { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + + if (OPTIMIZER == ADEMAMIX) { + smem_exchange3[0] = new_local_abs_max3; + } + } + + 
__syncthreads(); + + if(threadIdx.x == 0) + { + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + + if (OPTIMIZER == ADEMAMIX) { + absmax1[(n + i)/BLOCK_SIZE] = new_local_abs_max3; + } + } + else + { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = smem_exchange3[0]; + } + } + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + if (OPTIMIZER == ADEMAMIX) { + p_vals[j] = T((float)p_vals[j] - lr * ( + ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( + (sqrtf(s2_vals[j]) / correction2) + eps + ) + )); + } else { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + } + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + if (OPTIMIZER == ADEMAMIX) { + c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], 
__fdividef(s3_vals[j],new_local_abs_max3)); + + if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) { + c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1; + } + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items); + } + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ float smem_exchange1[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + # pragma unroll + 
for(unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + + new_local_abs_max1 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j]*beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val*g_val); + break; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 
+ new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + + if(threadIdx.x == 0) + smem_exchange1[0] = new_local_abs_max1; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + else + new_local_abs_max1 = smem_exchange1[0]; + + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + } + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + } +} + +// Inputs: +// A [rows, cols] +// Outputs: +// rowStats [rows] +// out [rows, cols] +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { + + // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in 
fp32. + // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE + using TReduction = T; +#else + using TReduction = float; +#endif + + using BlockReduceT = hipcub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + __shared__ TReduction smem_row_absmax; + + const int row_id = blockIdx.x; + const T* row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + TReduction row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const TReduction absval = fabsf(__ldcs(&(row_data[i]))); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } + + // Reduce thread-local absmax across the block. + const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = smem_row_absmax = row_absmax; + } + __syncthreads(); + + // Quantize row-wise. + const float scale = __fdividef(127.0f, smem_row_absmax); + for (int i = threadIdx.x; i < cols; i += THREADS) { + float val = row_data[i]; + + if constexpr (SPARSE_DECOMP) { + // For sparse decomposition, we do not want to quantize the outliers. + // Instead they're zeroed out. 
+ out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; + } else { + out[row_id * cols + i] = __float2int_rn(val * scale); + } + } +} + +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) { + using BlockReduceT = hipcub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + + const int row_id = blockIdx.x; + const T* __restrict__ row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + float row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const float absval = fabsf(row_data[i]); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } + + // Reduce thread-local absmax across the block. + // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY + const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. 
+ rowStats[row_id] = row_absmax; + } +} + +template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); +template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); + +template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); +template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); + + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, + float *__restrict__ const rowStats, + float *__restrict__ const colStats, + half *out, + half *__restrict__ const bias, + const int numRows, + const int numCols, + const int n +) { + const int n_out = numRows * numCols; + + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + + float local_rowStats[ITEMS_PER_THREAD]; + float local_colStats[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; + + typedef hipcub::BlockLoad LoadInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + + int row_idx, col_idx; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; + + local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; + local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + local_biasValue[j] = ((bias == nullptr) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); + } + + // Each block loads THREADS * ITEMS_PER_THREAD values from A + int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out + ? 
THREADS * ITEMS_PER_THREAD + : n_out - block_offset; + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); + + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { + local_output[j] = __float2half( + fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) + ); + } + + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; j++) { + int outIdx = block_offset + thread_offset + j; + if (outIdx < n_out) { + out[outIdx] = local_output[j]; + } + } +} + +#define DENORM 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +#define WARP_SIZE warpSize +template +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int warp_idx = threadIdx.x % WARP_SIZE; + const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 
8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*WARP_SIZE apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + + __syncthreads(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += blockDim.x*SPMM_ITEMS; 
+ local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; + } +} + +#define WARPS 3 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 
local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + 
local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 
1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + + +template __device__ void printnonzero(T *A, int num_values, const char * strval) +{ + for(int i = 0; i < num_values; i++) + if((float)A[i] != 0.0) + printf("%s %i %f\n", strval, i, (float)A[i]); +} + +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + //// element-wise kernel + //// 1. Load batch x k into registers + //// 2. Load k x k into registers + //// 3. dequantize and store in second pair of k x k + //// 4. matmul + //// 5. sum with cub + //// 6. store outputs + //// TC kernel + //// use k warps per thread block + //// 1. threadblock use read-only cache to read in register tile for A into shared memory + //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments + //// 3. each warp reads a segment of values 16x32 from B + //// 4. do dequantization from register of B into second pair of registers + //// 5. store (4) into fragment + //// 6. matmul aggregate into fragment C + //// 7. aggregate files of C into shared memory block C + //// 8. sum (7) + //// 9. 
write outputs to matmul output matrix +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + + T quant_map[16]; + + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) + smem_C[i] = 0.0f; + + __syncthreads(); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + 
//local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); + + //__syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + 
//local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + } + //printnonzero(local_B, 128, ""); + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 
1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major); + + //printnonzero(smem_C, 32, ""); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_C[warp_lane]; +#endif +} + +// No of 4bit values processed by each thread +#define num_values_4bit 32 +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + // per threadblock: + // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize]; + + const int warp_idx = threadIdx.x / warpSize; + const int warp_lane = threadIdx.x % warpSize; + const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx; + const int offset_B = ldb*row_B; + const int num_values_8bit = num_values_4bit/2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit/4]; + T local_A[num_values_4bit/4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + if (threadIdx.x < 16) + quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); + //for(int i = threadIdx.x; i < 16; i++) + //quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [M, K] + for(int 
inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit) + { + const int inner_idx_halved = inner_idx/2; + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. + const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); + + local_absmax = __ldg(&(absmax[absidx])); + + if(row_B < M) + { + if((inner_idx_halved + num_values_8bit) < (K/2)) + { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + if((inner_idx_halved) + j < (K/2)) + local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for(int i = 0; i < 4; i++) + { + #pragma unroll + for(int k = 0; k < num_values_8bit/4; k++) + { + #if BNB_BF16_AVAILABLE + local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; + local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; + #else + // bf16 multiplication not supported + local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); + local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); + #endif + } + + if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) + { + // this is also relatively important for performance + if(BITS==16) + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; + } + else + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; + reinterpret_cast(local_A)[1] =
reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; + } + + } + else + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + if(inner_idx + (i*num_values_4bit/4) + k < K) + local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; + else + local_A[k] = T(0.0f); + + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + { + #if BNB_BF16_AVAILABLE + local_C += (float)(local_A[k]*local_B[k]); + #else + // bf16 multiplication not supported + local_C += ((float)local_A[k]*(float)local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if(row_B < M && warp_lane == 0) + out[row_B] = T(local_C); + +} + + +template __global__ void kfunc(T *A, T *B, T value, long n) +{ + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B,
half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, 
half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); + +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half 
*out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); + +template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); +template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); + +template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n); +template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n); + +#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ + float* state1, 
float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, hip_bfloat16) + +#define MAKE_Optimizer32bit1State(oname, gtype) \ +template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_Optimizer32bit1State(MOMENTUM, half) +MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(MOMENTUM, hip_bfloat16) +MAKE_Optimizer32bit1State(RMSPROP, half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(RMSPROP, hip_bfloat16) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, hip_bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, half) +MAKE_Optimizer32bit1State(ADAGRAD, float) +MAKE_Optimizer32bit1State(ADAGRAD, hip_bfloat16) + +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ + float* state1, float* state2, float *unorm, \ + const float beta1, const float 
beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, half) +MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, hip_bfloat16) + +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void 
kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ + float *unorm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + const float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit1State(MOMENTUM, half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) +MAKE_PreconditionStatic8bit1State(ADAGRAD, half) +MAKE_PreconditionStatic8bit1State(ADAGRAD, float) + +#define MAKE_optimizerStatic8bit1State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* max1, float* 
new_max1, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit1State(MOMENTUM, half) +MAKE_optimizerStatic8bit1State(MOMENTUM, float) +MAKE_optimizerStatic8bit1State(RMSPROP, half) +MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) +MAKE_optimizerStatic8bit1State(LION, float) +MAKE_optimizerStatic8bit1State(ADAGRAD, half) +MAKE_optimizerStatic8bit1State(ADAGRAD, float) + +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ + float *unorm, \ + const float beta1, const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit2State(ADAM, half) +MAKE_PreconditionStatic8bit2State(ADAM, float) + +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit2State(ADAM, half) +MAKE_optimizerStatic8bit2State(ADAM, float) + +template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); +template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); + +#define 
MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, 
FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ 
void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ + const float beta1, const float beta2, const float beta3, const float alpha, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* absmax1, float* absmax2, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1) 
+MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, hip_bfloat16, 256, 1) + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, hip_bfloat16, 256, 1) + +template __device__ void printnonzero(float *A, int num_values, const char*strval); +template __device__ void printnonzero(half *A, int num_values, const char*strval); diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh new file mode 100644 index 000000000..2895012f8 --- /dev/null +++ b/csrc/kernels_hip.cuh @@ -0,0 +1,132 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. 
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#ifndef kernels +#define kernels + + +template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); + +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); + +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); + +template +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, + const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int 
step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, const float gnorm_scale, const int n); + + + +template +__global__ void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, const float gnorm_scale, const int n); + +template __global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, const float beta3, 
const float alpha, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); + +template __global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n); + + +template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); + + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); + + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, + half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); + +template __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols); +template __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char 
*B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kfunc(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/ops.hip b/csrc/ops.hip new file mode 100644 index 000000000..4d077d19a --- /dev/null +++ b/csrc/ops.hip @@ -0,0 +1,836 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#ifndef NO_HIPBLASLT +#include +#endif +#include +#include +#include +#include + +#define ERR_NOT_IMPLEMENTED 100 + +using namespace BinSearch; +using std::cout; +using std::endl; + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) +{ + int threads = 512; + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kHistogramScatterAdd2D), dim3(num_blocks), dim3(512), 0, 0, histogram, index1, index2, src, maxidx1, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void estimateQuantiles(T *A, float *code, float offset, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float))); + hipLaunchKernelGGL(( kEstimateQuantiles), dim3(num_blocks), dim3(512), 0, 0, A, code, offset, std::numeric_limits::max(), n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void quantize(float *code, float *A, unsigned char *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? 
num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kQuantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kDequantize), dim3(num_blocks), dim3(1024), 0, stream, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + + if(blocksize == 4096) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(1024), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(512), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 256) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 128) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); + //else if(blocksize == 64) + // hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); + + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, hipStream_t stream) +{ + int 
num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + + if(DATA_TYPE > 0) + hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize/2, n); + else + hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize, n); + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + + + + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, + const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + switch(OPTIMIZER) + { + case ADAM: + case ADEMAMIX: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + hipLaunchKernelGGL(( kOptimizer32bit2State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + + hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), 
dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case LION: + // in lion, the momentum update after the parameter update + hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + break; + } +} + +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; + + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); } + + switch(OPTIMIZER) + { + case ADAM: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit2State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit2State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case LION: + // in lion, the momentum update happens after the parameter update + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, 
p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + default: + break; + } +} + +#define BLOCKSIZE_2STATE 256 +#define NUM_2STATE 1 +#define BLOCKSIZE_1STATE 256 +#define NUM_1STATE 1 + +template void optimizerStatic8bitBlockwise( + T* p, + T* g, + unsigned char* state1, + unsigned char* state2, + float beta1, + float beta2, + float beta3, + float alpha, + float eps, + int step, + float lr, + float* quantiles1, + float* quantiles2, + float* absmax1, + float* absmax2, + float weight_decay, + const float gnorm_scale, + bool skip_zeros, + int n +) { + + int num_blocks = 0; + switch(OPTIMIZER) + { + case ADAM: + case ADEMAMIX: + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr, + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + } +} + + + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? 
num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPercentileClipping), dim3(num_blocks), dim3(512), 0, 0, g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + + status = hipblasGemmEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, + C, HIPBLAS_R_32I, ldc, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? 
HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, + C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + +#ifdef NO_HIPBLASLT +#else +template hipblasLtOrder_t get_order() +{ + switch(ORDER) + { + case ROW: + return HIPBLASLT_ORDER_ROW; + break; + case COL: + return HIPBLASLT_ORDER_COL; + break; + case COL32: + //return HIPBLASLT_ORDER_COL32; + return HIPBLASLT_ORDER_COL; + break; + case COL_TURING: + //return HIPBLASLT_ORDER_COL4_4R2_8C; + return HIPBLASLT_ORDER_COL; + break; + case COL_AMPERE: + //return HIPBLASLT_ORDER_COL32_2R_4R4; + return HIPBLASLT_ORDER_COL; + break; + default: + break; + } + + return HIPBLASLT_ORDER_ROW; +} + +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +#endif + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + default: + return dim1; + break; + /*case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + default: + return 0; + break; + */ + } +} + +static std::string hipError_to_string(const hipError_t ret) +{ + switch(ret) + { + case hipSuccess: + return "hipSuccess"; + case hipErrorInvalidContext: + return "hipErrorInvalidContext"; + case hipErrorInvalidKernelFile: + return "hipErrorInvalidKernelFile"; + case hipErrorMemoryAllocation: + return "hipErrorMemoryAllocation"; + case hipErrorInitializationError: + return 
"hipErrorInitializationError";
+        case hipErrorLaunchFailure:
+            return "hipErrorLaunchFailure";
+        case hipErrorLaunchOutOfResources:
+            return "hipErrorLaunchOutOfResources";
+        case hipErrorInvalidDevice:
+            return "hipErrorInvalidDevice";
+        case hipErrorInvalidValue:
+            return "hipErrorInvalidValue";
+        case hipErrorInvalidDevicePointer:
+            return "hipErrorInvalidDevicePointer";
+        case hipErrorInvalidMemcpyDirection:
+            return "hipErrorInvalidMemcpyDirection";
+        case hipErrorUnknown:
+            return "hipErrorUnknown";
+        case hipErrorInvalidResourceHandle:
+            return "hipErrorInvalidResourceHandle";
+        case hipErrorNotReady:
+            return "hipErrorNotReady";
+        case hipErrorNoDevice:
+            return "hipErrorNoDevice";
+        case hipErrorPeerAccessAlreadyEnabled:
+            return "hipErrorPeerAccessAlreadyEnabled";
+        case hipErrorPeerAccessNotEnabled:
+            return "hipErrorPeerAccessNotEnabled";
+        case hipErrorRuntimeMemory:
+            return "hipErrorRuntimeMemory";
+        case hipErrorRuntimeOther:
+            return "hipErrorRuntimeOther";
+        case hipErrorHostMemoryAlreadyRegistered:
+            return "hipErrorHostMemoryAlreadyRegistered";
+        case hipErrorHostMemoryNotRegistered:
+            return "hipErrorHostMemoryNotRegistered";
+        case hipErrorMapBufferObjectFailed:
+            return "hipErrorMapBufferObjectFailed";
+        case hipErrorTbd:
+            return "hipErrorTbd";
+        default:
+            throw std::runtime_error("unknown hipError");
+    }
+}
+
+// int8 x int8 matrix multiply through hipBLASLt.
+//
+// Template parameters (order follows the explicit instantiations
+// igemmlt<32, 0>, igemmlt<8, 0> and igemmlt<8, 1> further down this file):
+//   DTYPE_OUT  - 32 selects HIP_R_32I for both the output and the scale type;
+//                any other value selects HIP_R_8I output with a HIP_R_32F scale.
+//   SCALE_ROWS - on the non-32 path, non-zero passes `row_scale` as the alpha
+//                argument with the ALPHA_DEVICE_VECTOR_BETA_HOST pointer mode.
+//
+// Arguments:
+//   ltHandle        - hipBLASLt handle used for every call below.
+//   m, n, k         - layout extents: A is described as (m, k), B as (m, n)
+//                     and C as (k, n).
+//   A, B            - int8 inputs; C - output buffer, written as int32 or int8
+//                     depending on DTYPE_OUT.
+//   row_scale       - only read when SCALE_ROWS is non-zero.
+//   lda, ldb, ldc   - leading dimensions handed to the layout descriptors.
+//   stream          - stream the matmul is enqueued on.
+//
+// Returns ERR_NOT_IMPLEMENTED when built with NO_HIPBLASLT. Otherwise returns
+// the OR-accumulated results of checkHipblasStatus (defined outside this view;
+// presumably 0 on success), or 1 when the heuristic query yields no algorithm.
+template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
+    hipblasLtHandle_t ltHandle,
+    int m, int n, int k,
+    const int8_t *A,
+    const int8_t *B,
+    void *C,
+    float *row_scale,
+    int lda, int ldb, int ldc,
+    hipStream_t stream
+) {
+#ifdef NO_HIPBLASLT
+    return ERR_NOT_IMPLEMENTED;
+#else
+
+    // Calculate C = A^T @ B, in col-major layout.
+    //
+    // Use the IMMA kernels requires:
+    // * A must be transposed and B must be non-transposed.
+    // * Dimensions m and k must be multiples of 4.
+    // * All pointers must be 4-byte aligned; 16-byte alignment preferred.
+
+    int has_error = 0;
+    const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel
+
+    hipblasLtMatmulDesc_t matmulDesc;
+    hipblasLtMatrixLayout_t aDesc, bDesc, cDesc;
+    hipblasOperation_t opT = HIPBLAS_OP_T;
+
+    hipDataType outType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_8I;
+    hipDataType scaleType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_32F;
+
+    hipblasLtPointerMode_t pointerMode = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;
+
+    has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&aDesc, HIP_R_8I, m, k, lda));
+    has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&bDesc, HIP_R_8I, m, n, ldb));
+    has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc));
+
+    // Default layout order is col major
+
+    has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, scaleType));
+    has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));
+
+    if (DTYPE_OUT == 32) {
+
+        /* Algo and workspace TODO: need to rework to not be duplicated */
+        // Set User Preference attributes
+        // Failures here are folded into has_error like every other call in
+        // this function instead of being dropped.
+        hipblasLtMatmulPreference_t pref;
+        has_error |= checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref));
+        has_error |= checkHipblasStatus(
+            hipblasLtMatmulPreferenceSetAttribute(pref,
+                HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+                &max_workspace_size,
+                sizeof(max_workspace_size)));
+
+        const int request_solutions = 1;
+        hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
+        int returnedAlgoCount = 0;
+        has_error |= checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle,
+            matmulDesc,
+            aDesc,
+            bDesc,
+            cDesc,
+            cDesc,
+            pref,
+            request_solutions,
+            heuristicResult,
+            &returnedAlgoCount));
+
+        if (returnedAlgoCount == 0)
+        {
+            has_error = 1;
+            fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n");
+        } else {
+            int alpha = 1, beta = 0;
+            has_error |= checkHipblasStatus(hipblasLtMatmul(
+                ltHandle, matmulDesc,
+                &alpha, A, aDesc,
+                B, bDesc, &beta,
+                (int32_t*)C, cDesc,
+                (int32_t*)C, cDesc,
+                &heuristicResult[0].algo, NULL, 0, stream
+            ));
+        }
+
+        // The preference object is only needed for the heuristic query above;
+        // release it so it is not leaked on every call.
+        has_error |= checkHipblasStatus(hipblasLtMatmulPreferenceDestroy(pref));
+    } else {
+        // This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows.
+
+        if (!SCALE_ROWS) {
+            float alpha = 1.0f, beta = 0.0f;
+            has_error |= checkHipblasStatus(hipblasLtMatmul(
+                ltHandle, matmulDesc,
+                &alpha, A, aDesc,
+                B, bDesc, &beta,
+                (int8_t*)C, cDesc,
+                (int8_t*)C, cDesc,
+                NULL, NULL, 0, stream
+            ));
+        } else {
+            float beta = 0.0f;
+            // Size is taken from the variable actually passed by address.
+            has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(
+                matmulDesc,
+                HIPBLASLT_MATMUL_DESC_POINTER_MODE,
+                &pointerMode,
+                sizeof(pointerMode)
+            ));
+            has_error |= checkHipblasStatus(hipblasLtMatmul(
+                ltHandle, matmulDesc,
+                row_scale, A, aDesc,
+                B, bDesc, &beta,
+                (int8_t*)C, cDesc,
+                (int8_t*)C, cDesc,
+                NULL, NULL, 0, stream
+            ));
+        }
+    }
+
+    has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(cDesc));
+    has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(bDesc));
+    has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(aDesc));
+    has_error |= checkHipblasStatus(hipblasLtMatmulDescDestroy(matmulDesc));
+
+    if(has_error == 1)
+        printf("error detected");
+
+    return has_error;
+#endif // NO_HIPBLASLT
+}
+
+int fill_up_to_nearest_multiple(int value, int multiple)
+{
+    return value + (value % multiple == 0 ? 
0 : (multiple - (value % multiple))); +} + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, hipStream_t stream) +{ + const int threads = 512; + const int num_per_thread = 4; + const int num_per_block = threads * num_per_thread; + const int n = numRows*numCols; + const int num_blocks = (n + num_per_block - 1) / num_per_block; + + hipLaunchKernelGGL(( kdequant_mm_int32_fp16), dim3(num_blocks), dim3(threads), 0, stream, A, rowStats, colStats, out, bias, numRows, numCols, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) { if (threshold == 0.0) { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } else { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) { + if (threshold == 0.0) + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + else + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) +{ + +#ifdef NO_HIPBLASLT +#else + + hipsparseSpMatDescr_t descA; + hipsparseDnMatDescr_t descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + CHECK_HIPSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + A_rowidx, A_colidx, A_vals, + HIPSPARSE_INDEX_32I, + HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) ); + // Create dense matrix C + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + 
A_cols = B_cols; + B_cols = tmp; + } + + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // allocate an external buffer if needed + CHECK_HIPSPARSE( hipsparseSpMM_bufferSize( + handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); + CUDA_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) ); + + // execute SpMM + CHECK_HIPSPARSE( hipsparseSpMM(handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer)); + + // destroy matrix/vector descriptors + CHECK_HIPSPARSE( hipsparseDestroySpMat(descA) ); + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) ); + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) ); + CUDA_CHECK_RETURN( hipFree(dBuffer) ); +#endif +} + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ + + hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + if(bits == 32) + hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(32), 0, 0, m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(160), 0, 0, m, n, k, A, B, out, lda, ldb, ldc); +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float 
*absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + hipLaunchKernelGGL(( kgemm_4bit_inference), dim3(num_blocks), dim3(96), 0, 0, m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream) +{ + + //warpsize - 32 + int num_blocks = (m+3)/4; + //warpsize - 64 + if (warpSize == 64) { + num_blocks = (m+1)/2; + } + + hipLaunchKernelGGL(( kgemm_4bit_inference_naive), dim3(num_blocks), dim3(128), 0, stream, m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + hipLaunchKernelGGL(( kfunc), dim3(blocks), dim3(512), 0, 0, A, B, value, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); +template void gemm_4bit_inference_naive(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, 
hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); + +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template int igemmlt<32, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); +template int igemmlt<8, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); +template int igemmlt<8, 1>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); + +template void estimateQuantiles(half *A, float *code, float offset, int n); +template void estimateQuantiles(float *A, float *code, float offset, int n); + +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, 
const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void 
dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); + +#define MAKE_optimizer32bit(name, gtype) \ +template void optimizer32bit(gtype* g, gtype* p, \ + float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +MAKE_optimizer32bit(ADAM, half) +MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, hip_bfloat16) +MAKE_optimizer32bit(MOMENTUM, half) +MAKE_optimizer32bit(MOMENTUM, float) +MAKE_optimizer32bit(MOMENTUM, hip_bfloat16) +MAKE_optimizer32bit(RMSPROP, half) +MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(RMSPROP, hip_bfloat16) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, hip_bfloat16) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) +MAKE_optimizer32bit(ADAGRAD, hip_bfloat16) +MAKE_optimizer32bit(ADEMAMIX, half) +MAKE_optimizer32bit(ADEMAMIX, hip_bfloat16) +MAKE_optimizer32bit(ADEMAMIX, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ 
+template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, int n); \ + +MAKE_optimizerStatic8bit(ADAM, half) +MAKE_optimizerStatic8bit(ADAM, float) +MAKE_optimizerStatic8bit(MOMENTUM, half) +MAKE_optimizerStatic8bit(MOMENTUM, float) +MAKE_optimizerStatic8bit(RMSPROP, half) +MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) +MAKE_optimizerStatic8bit(ADAGRAD, half) +MAKE_optimizerStatic8bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ +template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + +MAKE_optimizerStatic8bitBlockwise(half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM); +MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, 
ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); + +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh new file mode 100644 index 000000000..bcfc73e99 --- /dev/null +++ b/csrc/ops_hip.cuh @@ -0,0 +1,195 @@ +// !!! This is a file automatically generated by hipify!!! +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + + +#ifndef ops_H +#define ops_H + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define CUDA_CHECK_RETURN(value) { \ + hipError_t _m_cudaStat = value; \ + if (_m_cudaStat != hipSuccess) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + + +#define CHECK_HIPSPARSE(value) { \ + hipsparseStatus_t _m_hipStat = value; \ + if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + hipsparseGetErrorString(_m_hipStat), __LINE__, __FILE__); \ + exit(1); \ + } } + + + +inline void checkHipStatus(hipError_t status) { + if (status != hipSuccess) { + printf("hip API failed with status %d: %s\n", status, hipGetErrorString(status)); + throw std::logic_error("hip API failed"); + } +} + +inline int checkHipblasStatus(hipblasStatus_t status) { + if (status != HIPBLAS_STATUS_SUCCESS) { + printf("hipBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API 
failed"); + return 1; + } + return 0; +} + +typedef enum Operations_t +{ + ksmul = 0, +} Operations_t; + +typedef enum Optimizer_t +{ + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, + ADEMAMIX = 6, +} Optimizer_t; + +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + +class Context +{ + public: + rocblas_handle m_handle; + + Context() + { + rocblas_handle handle; + rocblas_create_handle(&handle); + m_handle = handle; + } + +}; + +class ContextLt +{ + public: + hipblasLtHandle_t m_handle; + + ContextLt() + { + hipblasLtHandle_t handle; + hipblasLtCreate(&handle); + m_handle = handle; + } +}; + +class ContextHipsparse +{ + public: + hipsparseHandle_t m_handle; + + ContextHipsparse() + { + hipsparseHandle_t handle; + hipsparseCreate(&handle); + m_handle = handle; + } + +}; + + +template void estimateQuantiles(T *A, float *code, float offset, int n); + +void quantize(float *code, float *A, unsigned char *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, hipStream_t stream); + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); + +template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float 
beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n); + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + bool skip_zeros, int n); + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); + +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); + +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, hipStream_t stream); +void getRowStats(half * A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); + +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); 
+ +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); + +template void func(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 56bec82e8..a48514542 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -6,11 +6,29 @@ #if BUILD_CUDA #include #endif +#if BUILD_HIP +#include +#endif #if BUILD_MPS // #include #endif #include +// Compatibility between HIP/CUDA APIs +#if BUILD_HIP +#define cudaStream_t hipStream_t +#define __nv_bfloat16 hip_bfloat16 +#define cublasLtHandle_t hipblasLtHandle_t +#define ContextCusparse ContextHipsparse +#define cusparseHandle_t hipsparseHandle_t +#define cudaMallocManaged hipMallocManaged +#define cudaMemAttachHost hipMemAttachHost +#define cudaPeekAtLastError hipPeekAtLastError +#define cudaDeviceGetAttribute hipDeviceGetAttribute +#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess +#define cudaMemPrefetchAsync hipMemPrefetchAsync +#endif + // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. // We use macro functions to expand all the different optimizers. 
Looks ugly, and is ugly, but its better than to // maintain all that boilerplate @@ -18,7 +36,7 @@ // UNMANGLED CALLS //=================================================================================== -#if BUILD_CUDA +#if BUILD_CUDA || BUILD_HIP void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } @@ -168,7 +186,7 @@ void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_r extern "C" { -#if BUILD_CUDA +#if BUILD_CUDA || BUILD_HIP void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } From d729c188496ce5947f159693fbbb3e2dd281d87e Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 20 May 2025 21:14:15 +0530 Subject: [PATCH 02/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 363 ++++++++++++++++++------------ 1 file changed, 223 insertions(+), 140 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index efdef2871..fd63c888d 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,14 +1,15 @@ from collections.abc import Sequence import ctypes as ct from math import prod -from typing import Optional +from typing import Optional import torch from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import lib +from ...cextension import lib, HIP_ENVIRONMENT + @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -84,7 +85,6 @@ def 
_int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor return out - @register_kernel("bitsandbytes::int8_mm_dequant", "cuda") def _( A: torch.Tensor, @@ -164,7 +164,7 @@ def _(A: torch.Tensor, threshold=0.0): out_row[:, outlier_cols] = 0 return out_row, row_stats, outlier_cols - + @register_kernel("bitsandbytes::int8_double_quant", "cuda") def _( @@ -210,35 +210,67 @@ def _get_col_absmax( @register_kernel("bitsandbytes::quantize_blockwise", "cuda") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - + + device = A.device + device_type = device.type + + if device_type == 'cuda': + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + elif device_type == 'hip' and HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + if device_type == 'cuda' or 
(device_type == 'hip' and HIP_ENVIRONMENT): + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + elif device_type == 'cpu': + cpu_kernel_func = getattr(lib, 'cquantize_blockwise_cpu_fp32', None) + if cpu_kernel_func: + A_cpu = A.to(torch.float32) if A.dtype != torch.float32 else A + code_cpu = code.to('cpu') + absmax_cpu = torch.empty(absmax.shape, device='cpu', dtype=torch.float32) + out_cpu = torch.empty(out.shape, device='cpu', dtype=torch.uint8) + + cpu_kernel_func( + get_ptr(code_cpu), + get_ptr(A_cpu), + get_ptr(absmax_cpu), + get_ptr(out_cpu), + ct.c_longlong(blocksize), + ct.c_longlong(A_cpu.numel()) + ) + + out.copy_(out_cpu) + absmax.copy_(absmax_cpu) + else: + raise NotImplementedError("CPU blockwise quantization requires C extension support") + else: + raise NotImplementedError(f"Blockwise quantization not implemented for {device_type}") + + return out, absmax + @register_kernel("bitsandbytes::dequantize_blockwise", "cuda") def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: @@ -252,7 +284,7 @@ def _( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, - blocksize: int, + blocksize: int, dtype: torch.dtype, out: torch.Tensor, ) -> None: @@ -264,76 +296,116 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - 
torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) + device = A.device + device_type = device.type + + if device_type == 'cuda': + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + elif device_type == 'hip' and HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + elif device_type == 'cpu': + cpu_kernel_func = getattr(lib, 'cdequantize_blockwise_cpu_fp32', None) + if cpu_kernel_func: + code_cpu = code.to('cpu') + A_cpu = A.to('cpu') + absmax_cpu = absmax.to('cpu') + out_cpu = torch.empty(out.shape, dtype=torch.float32, device='cpu') + + cpu_kernel_func( + get_ptr(code_cpu), + get_ptr(A_cpu), + get_ptr(absmax_cpu), + get_ptr(out_cpu), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()) + ) + + 
out.copy_(out_cpu.to(dtype)) + else: + raise NotImplementedError("CPU blockwise dequantization requires C extension support") + else: + raise NotImplementedError(f"Blockwise dequantization not implemented for {device_type}") @register_kernel("bitsandbytes::quantize_4bit", "cuda") def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax + device = A.device + device_type = device.type + + if device_type == 'cuda': + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + elif device_type == 'hip' and HIP_ENVIRONMENT: # same guard as the other blockwise kernels in this file; NOTE(review): ROCm builds of PyTorch may report device.type as 'cuda' - confirm this branch is reachable + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got 
{A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + else: + raise NotImplementedError(f"4-bit quantization not implemented for {device_type}") + + return out, absmax @register_kernel("bitsandbytes::dequantize_4bit", "cuda") def _( @@ -347,6 +419,7 @@ def _( out = torch.empty(shape, dtype=dtype, device=A.device) _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) return out + @register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") @@ -359,52 +432,62 @@ def _( dtype: torch.dtype, out: torch.Tensor, ) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, 
torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + device = A.device + device_type = device.type + + if device_type == 'cuda': + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + elif device_type == 'hip' and HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if 
quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + else: + raise NotImplementedError(f"4-bit dequantization not implemented for {device_type}") @register_kernel("bitsandbytes::gemv_4bit", "cuda") @@ -457,7 +540,7 @@ def _gemv_4bit_impl( B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") m = ct.c_int32(shapeB[0]) From 6459c2bd6e4eb68fbe36d3deb200ac3492f96c1a Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 20 May 2025 21:15:00 +0530 Subject: [PATCH 03/98] Update functional.py --- bitsandbytes/functional.py | 391 ++++++++++++++++++++++--------------- 1 file changed, 238 insertions(+), 153 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b0092ffd1..7730f7182 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib +from .cextension import lib, HIP_ENVIRONMENT name2qmap = {} @@ -719,152 +719,222 @@ def __eq__(self, other): ) -def quantize_blockwise( - A: torch.Tensor, - code: Optional[torch.Tensor] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=4096, - nested=False, -) -> tuple[torch.Tensor, QuantState]: - """Quantize a tensor in blocks of values. - - The input tensor is quantized by dividing it into blocks of `blocksize` values. 
- The the absolute maximum value within these blocks is calculated for scaling - the non-linear quantization. - - Args: - A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. - code (`torch.Tensor`, *optional*): - A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. - For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. - absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. - out (`torch.Tensor`, *optional*): A tensor to use to store the result. - blocksize (`int`, *optional*): - The size of the blocks. Defaults to 4096. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. - - Raises: - ValueError: Raised when the input data type is not supported. - - Returns: - `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. - - `torch.Tensor`: The quantized tensor. - - [`QuantState`]: The state object used to undo the quantization. 
- """ - - if code is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( - A, - code.to(A.device), - blocksize, - ) - - if nested: - offset = _absmax.mean() - _absmax -= offset - qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False) - quant_state = QuantState( - absmax=qabsmax, - code=code, - blocksize=blocksize, - dtype=A.dtype, - offset=offset, - state2=state2, - ) - else: - quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype) - - # TODO(matthewdouglas): Deprecate out kwarg - out = out.copy_(_out) if out is not None else _out - - # TODO(matthewdouglas): Deprecate absmax kwarg - if absmax is not None: - quant_state.absmax = absmax.copy_(quant_state.absmax) - - return out, quant_state - - -def dequantize_blockwise( - A: torch.Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - code: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 4096, - nested=False, -) -> torch.Tensor: - """Dequantize a tensor in blocks of values. - - The input tensor is dequantized by dividing it into blocks of `blocksize` values. - The the absolute maximum value within these blocks is used for scaling - the non-linear dequantization. - - Args: - A (`torch.Tensor`): The quantized input tensor. - quant_state ([`QuantState`], *optional*): - The quantization state as returned by [`quantize_blockwise`]. - Required if `absmax` is not provided. - absmax (`torch.Tensor`, *optional*): - A tensor containing the scaling values. - Required if `quant_state` is not provided and ignored otherwise. - code (`torch.Tensor`, *optional*): - A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. 
- For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. - Ignored when `quant_state` is provided. - out (`torch.Tensor`, *optional*): A tensor to use to store the result. - blocksize (`int`, *optional*): - The size of the blocks. Defaults to 4096. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - Ignored when `quant_state` is provided. - - Raises: - ValueError: Raised when the input data type is not supported. - - Returns: - `torch.Tensor`: - The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. - """ - - assert quant_state is not None or absmax is not None - if code is None and quant_state is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - if quant_state is None: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - - absmax = quant_state.absmax - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - - if out is not None: - torch.ops.bitsandbytes.dequantize_blockwise.out( - A, - absmax, - code.to(A.device), - blocksize, - quant_state.dtype, - out=out, - ) - return out - - return torch.ops.bitsandbytes.dequantize_blockwise.default( - A, - absmax, - quant_state.code.to(A.device), - quant_state.blocksize, - quant_state.dtype, - ) +def quantize_blockwise( + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, +) -> tuple[torch.Tensor, QuantState]: + """Quantize a tensor in blocks of values. + + The input tensor is quantized by dividing it into blocks of `blocksize` values. 
+ The the absolute maximum value within these blocks is calculated for scaling + the non-linear quantization. + + Args: + A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. + - `torch.Tensor`: The quantized tensor. + - [`QuantState`]: The state object used to undo the quantization. 
+ """ + + if code is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + if absmax is None: + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) + + device_type = A.device.type + + if device_type == "cpu": + code = code.cpu() + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + elif device_type in ["cuda", "hip"]: + if not HIP_ENVIRONMENT: + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + else: + assert blocksize in [4096, 2048, 1024, 512, 256, 128] + + code = code.to(A.device) + + is_on_gpu([A, out, absmax]) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + else: + raise RuntimeError(f"Device type {device_type} not supported for quantization") + + if nested: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) + quant_state = QuantState( + absmax=qabsmax, + code=code, + blocksize=blocksize, + dtype=A.dtype, + offset=offset, + state2=state2, + ) + else: + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) + + return out, quant_state + + +def dequantize_blockwise( + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: 
Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, +) -> torch.Tensor: + """Dequantize a tensor in blocks of values. + + The input tensor is dequantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is used for scaling + the non-linear dequantization. + + Args: + A (`torch.Tensor`): The quantized input tensor. + quant_state ([`QuantState`], *optional*): + The quantization state as returned by [`quantize_blockwise`]. + Required if `absmax` is not provided. + absmax (`torch.Tensor`, *optional*): + A tensor containing the scaling values. + Required if `quant_state` is not provided and ignored otherwise. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + Ignored when `quant_state` is provided. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + Ignored when `quant_state` is provided. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `torch.Tensor`: + The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. 
+ """ + + assert quant_state is not None or absmax is not None + if code is None and quant_state is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + if quant_state is None: + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + + absmax = quant_state.absmax + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + if out is None: + out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) + + device_type = A.device.type + + if device_type == "cpu": + code = quant_state.code.cpu() + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(quant_state.absmax), + get_ptr(out), + ct.c_longlong(quant_state.blocksize), + ct.c_longlong(A.numel()), + ) + elif device_type in ["cuda", "hip"]: + code = quant_state.code.to(A.device) + supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] + if HIP_ENVIRONMENT: + supported_blocksizes = supported_blocksizes[:-1] + if quant_state.blocksize not in supported_blocksizes: + raise ValueError( + f"The blocksize of {quant_state.blocksize} is not supported. 
Supported values: {supported_blocksizes}", + ) + + is_on_gpu([A, absmax, out]) + + with _cuda_device_of(A): + args = ( + get_ptr(quant_state.code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif out.dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif out.dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + else: + raise RuntimeError(f"Device type {device_type} not supported for dequantization") + + return out def get_4bit_type(typename, device=None, blocksize=64): @@ -953,10 +1023,12 @@ def quantize_fp4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -964,10 +1036,12 @@ def quantize_nf4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -975,7 +1049,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, @@ -1003,6 +1077,9 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. 
""" + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1053,8 +1130,10 @@ def dequantize_fp4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1063,8 +1142,10 @@ def dequantize_nf4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1073,7 +1154,7 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, quant_type="fp4", ) -> torch.Tensor: """Dequantizes a packed 4-bit quantized tensor. @@ -1102,6 +1183,10 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. 
""" + + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + if quant_state is None: assert absmax is not None and out is not None From 09249c897e47708ea9d4e594b8deaea439d74ade Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 21 May 2025 20:12:20 +0530 Subject: [PATCH 04/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 106 +++++++++++++----------------- 1 file changed, 44 insertions(+), 62 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index fd63c888d..40f25a18f 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,3 +1,4 @@ + from collections.abc import Sequence import ctypes as ct from math import prod @@ -5,7 +6,7 @@ import torch -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr, is_on_gpu from ..._ops import register_kernel from ...cextension import lib, HIP_ENVIRONMENT @@ -43,7 +44,7 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor n = prod(shapeB[:-1]) lda = shapeA[-1] # Weights (outputs, inputs) ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) torch._check( lda == ldb, @@ -53,10 +54,18 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. # We'll fall back to a slower fp32 calculation in this circumstance. # Fortunately, this should not be very common. 
- if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + if out is not None: + result = out.copy_(result) + return result + + if out is None: + out = torch.empty(shapeC, device=A.device, dtype=dtype) + + is_on_gpu([A, B, out]) + with _cuda_device_of(A): ctx = CUBLAS_Context.get_instance().get_context(A.device) ptrA = get_ptr(A) @@ -71,8 +80,11 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor ldc = ct.c_int32(ldc) stream = _get_tensor_stream(A) - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - + if dtype == torch.int32: + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + else: + has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + if has_error: if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` @@ -111,6 +123,8 @@ def _( # Note: fused bias in the kernel is only supported for fp16 # TODO(matthewdouglas): Consider supporting bf16 fused bias ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + is_on_gpu([A, row_stats, col_stats, out, bias]) with _cuda_device_of(A): lib.cdequant_mm_int32_fp16( @@ -128,6 +142,8 @@ def _( def _(A: torch.Tensor, threshold=0.0): torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + is_on_gpu([A]) rows = prod(A.shape[:-1]) cols = A.shape[-1] @@ -216,7 +232,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor if device_type == 'cuda': torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' and HIP_ENVIRONMENT: + elif device_type == 'hip' or HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 
1024, 512, 256, 128]) torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") @@ -225,8 +241,10 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty_like(A, dtype=torch.uint8) - - if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + + is_on_gpu([A, out, absmax]) + + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( get_ptr(code), @@ -245,30 +263,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor lib.cquantize_blockwise_fp32(*args) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - elif device_type == 'cpu': - cpu_kernel_func = getattr(lib, 'cquantize_blockwise_cpu_fp32', None) - if cpu_kernel_func: - A_cpu = A.to(torch.float32) if A.dtype != torch.float32 else A - code_cpu = code.to('cpu') - absmax_cpu = torch.empty(absmax.shape, device='cpu', dtype=torch.float32) - out_cpu = torch.empty(out.shape, device='cpu', dtype=torch.uint8) - - cpu_kernel_func( - get_ptr(code_cpu), - get_ptr(A_cpu), - get_ptr(absmax_cpu), - get_ptr(out_cpu), - ct.c_longlong(blocksize), - ct.c_longlong(A_cpu.numel()) - ) - - out.copy_(out_cpu) - absmax.copy_(absmax_cpu) - else: - raise NotImplementedError("CPU blockwise quantization requires C extension support") - else: - raise NotImplementedError(f"Blockwise quantization not implemented for {device_type}") - + return out, absmax @@ -302,7 +297,7 @@ def _dequantize_blockwise_impl( if device_type == 'cuda': torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' and HIP_ENVIRONMENT: + elif device_type == 'hip' or HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got 
{A.dtype}") @@ -310,8 +305,10 @@ def _dequantize_blockwise_impl( dtype in [torch.float16, torch.bfloat16, torch.float32], lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", ) - - if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + + is_on_gpu([A, absmax, out]) + + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( get_ptr(code), @@ -328,29 +325,8 @@ def _dequantize_blockwise_impl( elif dtype == torch.bfloat16: lib.cdequantize_blockwise_bf16(*args) elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - elif device_type == 'cpu': - cpu_kernel_func = getattr(lib, 'cdequantize_blockwise_cpu_fp32', None) - if cpu_kernel_func: - code_cpu = code.to('cpu') - A_cpu = A.to('cpu') - absmax_cpu = absmax.to('cpu') - out_cpu = torch.empty(out.shape, dtype=torch.float32, device='cpu') - - cpu_kernel_func( - get_ptr(code_cpu), - get_ptr(A_cpu), - get_ptr(absmax_cpu), - get_ptr(out_cpu), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()) - ) - - out.copy_(out_cpu.to(dtype)) - else: - raise NotImplementedError("CPU blockwise dequantization requires C extension support") - else: - raise NotImplementedError(f"Blockwise dequantization not implemented for {device_type}") + lib.cdequantize_blockwise_fp32(*args) + @register_kernel("bitsandbytes::quantize_4bit", "cuda") def _( @@ -375,7 +351,9 @@ def _( blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - + + is_on_gpu([A, out, absmax]) + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -450,7 +428,7 @@ def _dequantize_4bit_impl( if device_type == 'cuda': torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' and HIP_ENVIRONMENT: + elif device_type == 'hip' or 
HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) torch._check(quant_type in ["fp4", "nf4"]) @@ -459,6 +437,8 @@ def _dequantize_4bit_impl( lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", ) + is_on_gpu([A, absmax, out]) + if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -550,6 +530,8 @@ def _gemv_4bit_impl( lda = m ldb = ct.c_int32((A.shape[-1] + 1) // 2) ldc = m + + is_on_gpu([B, A, out, absmax]) stream = _get_tensor_stream(A) From 4afa7741b3b7105ac6a42700dab1fd83b5050fc5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 21 May 2025 20:12:36 +0530 Subject: [PATCH 05/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 40f25a18f..ce5401c5f 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,4 +1,3 @@ - from collections.abc import Sequence import ctypes as ct from math import prod From 033d92cef2d41431fd4247c272c9429f7304bf40 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 21 May 2025 20:23:34 +0530 Subject: [PATCH 06/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index ce5401c5f..14f55847c 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -353,7 +353,7 @@ def _( is_on_gpu([A, out, absmax]) - if device_type == 'cuda' or (device_type == 'hip' and HIP_ENVIRONMENT): + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( None, @@ -438,7 +438,7 @@ def _dequantize_4bit_impl( is_on_gpu([A, absmax, out]) - if device_type == 'cuda' or 
(device_type == 'hip' and HIP_ENVIRONMENT): + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( None, From 4def9590abb8a3f0ef789fce0b1659af729643e4 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 22 May 2025 20:51:50 +0530 Subject: [PATCH 07/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 14f55847c..5b94c5349 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -11,7 +11,6 @@ from ...cextension import lib, HIP_ENVIRONMENT - @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") def _(A: torch.Tensor, B: torch.Tensor): out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) @@ -78,12 +77,9 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) stream = _get_tensor_stream(A) - - if dtype == torch.int32: - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - else: - has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + if has_error: if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` @@ -96,6 +92,7 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor return out + @register_kernel("bitsandbytes::int8_mm_dequant", "cuda") def _( A: torch.Tensor, @@ -384,6 +381,7 @@ def _( return out, absmax + @register_kernel("bitsandbytes::dequantize_4bit", "cuda") def _( A: torch.Tensor, @@ -398,7 +396,6 @@ def _( return out - @register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") def _( A: torch.Tensor, @@ -496,7 +493,6 @@ 
def _( torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - def _gemv_4bit_impl( A: torch.Tensor, B: torch.Tensor, From 0f318667aaf4de15cd29f8063dcaa4fd90d24783 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 22 May 2025 21:31:55 +0530 Subject: [PATCH 08/98] Update functional.py --- bitsandbytes/functional.py | 157 ++++++++++++------------------------- 1 file changed, 48 insertions(+), 109 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 7730f7182..3f0c1ff94 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -728,11 +728,9 @@ def quantize_blockwise( nested=False, ) -> tuple[torch.Tensor, QuantState]: """Quantize a tensor in blocks of values. - The input tensor is quantized by dividing it into blocks of `blocksize` values. The the absolute maximum value within these blocks is calculated for scaling the non-linear quantization. - Args: A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. code (`torch.Tensor`, *optional*): @@ -744,10 +742,8 @@ def quantize_blockwise( The size of the blocks. Defaults to 4096. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. - Raises: ValueError: Raised when the input data type is not supported. - Returns: `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. - `torch.Tensor`: The quantized tensor. 
@@ -759,61 +755,23 @@ def quantize_blockwise( name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - if absmax is None: - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - - if out is None: - out = torch.zeros_like(A, dtype=torch.uint8) - device_type = A.device.type - - if device_type == "cpu": - code = code.cpu() - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), - ) - elif device_type in ["cuda", "hip"]: + if device_type in ["cuda", "hip"]: if not HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] - - code = code.to(A.device) - - is_on_gpu([A, out, absmax]) + assert blocksize in [4096, 2048, 1024, 512, 256, 128] - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - else: - raise RuntimeError(f"Device type {device_type} not supported for quantization") + _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( + A, + code.to(A.device), + blocksize, + ) if nested: - offset = absmax.mean() - absmax -= offset - qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) + offset = _absmax.mean() + _absmax -= offset + qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False) quant_state = QuantState( absmax=qabsmax, code=code, @@ -823,11 +781,18 @@ def 
quantize_blockwise( state2=state2, ) else: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) + quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype) + + # TODO(matthewdouglas): Deprecate out kwarg + out = out.copy_(_out) if out is not None else _out + + # TODO(matthewdouglas): Deprecate absmax kwarg + if absmax is not None: + quant_state.absmax = absmax.copy_(quant_state.absmax) + + return out, quant_state + - return out, quant_state - - def dequantize_blockwise( A: torch.Tensor, quant_state: Optional[QuantState] = None, @@ -838,11 +803,9 @@ def dequantize_blockwise( nested=False, ) -> torch.Tensor: """Dequantize a tensor in blocks of values. - The input tensor is dequantized by dividing it into blocks of `blocksize` values. The the absolute maximum value within these blocks is used for scaling the non-linear dequantization. - Args: A (`torch.Tensor`): The quantized input tensor. quant_state ([`QuantState`], *optional*): @@ -860,10 +823,8 @@ def dequantize_blockwise( The size of the blocks. Defaults to 4096. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. Ignored when `quant_state` is provided. - Raises: ValueError: Raised when the input data type is not supported. - Returns: `torch.Tensor`: The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. @@ -878,6 +839,16 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + device_type = A.device.type + if device_type in ["cuda", "hip"]: + supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] + if HIP_ENVIRONMENT: + supported_blocksizes = supported_blocksizes[:-1] + if quant_state.blocksize not in supported_blocksizes: + raise ValueError( + f"The blocksize of {quant_state.blocksize} is not supported. 
Supported values: {supported_blocksizes}" + ) + absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) @@ -885,56 +856,24 @@ def dequantize_blockwise( if absmax.dtype != torch.float32: absmax = absmax.float() - if out is None: - out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - - device_type = A.device.type - - if device_type == "cpu": - code = quant_state.code.cpu() - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(quant_state.absmax), - get_ptr(out), - ct.c_longlong(quant_state.blocksize), - ct.c_longlong(A.numel()), + if out is not None: + torch.ops.bitsandbytes.dequantize_blockwise.out( + A, + absmax, + quant_state.code.to(A.device), + quant_state.blocksize, + quant_state.dtype, + out=out, ) - elif device_type in ["cuda", "hip"]: - code = quant_state.code.to(A.device) - supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] - if HIP_ENVIRONMENT: - supported_blocksizes = supported_blocksizes[:-1] - if quant_state.blocksize not in supported_blocksizes: - raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. 
Supported values: {supported_blocksizes}", - ) - - is_on_gpu([A, absmax, out]) - - with _cuda_device_of(A): - args = ( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") - else: - raise RuntimeError(f"Device type {device_type} not supported for dequantization") + return out - return out + return torch.ops.bitsandbytes.dequantize_blockwise.default( + A, + absmax, + quant_state.code.to(A.device), + quant_state.blocksize, + quant_state.dtype, + ) def get_4bit_type(typename, device=None, blocksize=64): From 190faed7e96b8b27e033fe3c6ee5e3a6d5a4772a Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 22 May 2025 23:35:15 +0530 Subject: [PATCH 09/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 5b94c5349..ff5e023cc 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -52,15 +52,9 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. # We'll fall back to a slower fp32 calculation in this circumstance. # Fortunately, this should not be very common. 
- - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - if out is not None: - result = out.copy_(result) - return result - - if out is None: - out = torch.empty(shapeC, device=A.device, dtype=dtype) + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) is_on_gpu([A, B, out]) From d7f413b9b367b9b26b87180095ebcc7a561fdc26 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 22 May 2025 23:52:39 +0530 Subject: [PATCH 10/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 30 +++++------------------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index ff5e023cc..b75f67d62 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -55,8 +55,6 @@ def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor if lda % 4 != 0: result = torch.matmul(B.float(), A.float().t()).to(torch.int32) return out.copy_(result) - - is_on_gpu([A, B, out]) with _cuda_device_of(A): ctx = CUBLAS_Context.get_instance().get_context(A.device) @@ -114,8 +112,6 @@ def _( # TODO(matthewdouglas): Consider supporting bf16 fused bias ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - is_on_gpu([A, row_stats, col_stats, out, bias]) - with _cuda_device_of(A): lib.cdequant_mm_int32_fp16( ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) @@ -133,8 +129,6 @@ def _(A: torch.Tensor, threshold=0.0): torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - is_on_gpu([A]) - rows = prod(A.shape[:-1]) cols = A.shape[-1] @@ -231,9 +225,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor blocks = -(n // -blocksize) absmax 
= torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty_like(A, dtype=torch.uint8) - - is_on_gpu([A, out, absmax]) - + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -295,9 +287,7 @@ def _dequantize_blockwise_impl( dtype in [torch.float16, torch.bfloat16, torch.float32], lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", ) - - is_on_gpu([A, absmax, out]) - + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -341,9 +331,7 @@ def _( blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - is_on_gpu([A, out, absmax]) - + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -370,8 +358,6 @@ def _( lib.cquantize_blockwise_fp32_fp4(*args) else: lib.cquantize_blockwise_fp32_nf4(*args) - else: - raise NotImplementedError(f"4-bit quantization not implemented for {device_type}") return out, absmax @@ -400,11 +386,11 @@ def _( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + def _dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, @@ -426,9 +412,7 @@ def _dequantize_4bit_impl( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", ) - - is_on_gpu([A, absmax, out]) - + if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): with _cuda_device_of(A): args = ( @@ -456,8 +440,6 @@ def _dequantize_4bit_impl( 
lib.cdequantize_blockwise_fp32_fp4(*args) else: lib.cdequantize_blockwise_fp32_nf4(*args) - else: - raise NotImplementedError(f"4-bit dequantization not implemented for {device_type}") @register_kernel("bitsandbytes::gemv_4bit", "cuda") @@ -520,8 +502,6 @@ def _gemv_4bit_impl( ldb = ct.c_int32((A.shape[-1] + 1) // 2) ldc = m - is_on_gpu([B, A, out, absmax]) - stream = _get_tensor_stream(A) with _cuda_device_of(A): From 3b6e68a001b0dce3b368129599335fcb569ac5cd Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 23 May 2025 00:05:43 +0530 Subject: [PATCH 11/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index b75f67d62..156125c9f 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -5,7 +5,7 @@ import torch -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr, is_on_gpu +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel from ...cextension import lib, HIP_ENVIRONMENT From 06740b1372a9c9751216b76dc4c8cc98514905dd Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 23 May 2025 01:53:30 +0530 Subject: [PATCH 12/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 989 +++++++++++++++--------------- 1 file changed, 486 insertions(+), 503 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 156125c9f..48dc75135 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,325 +1,312 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, 
_get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. 
- if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. 
- out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
- if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - device = A.device - device_type = device.type +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib, HIP_ENVIRONMENT + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: 
torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. 
+ if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) - if device_type == 'cuda': - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' or HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. 
+ out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
+ if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") n = A.numel() blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty_like(A, dtype=torch.uint8) - - if device_type == 'cuda' or (device_type 
== 'hip' or HIP_ENVIRONMENT): - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - - device = A.device - device_type = device.type - - if device_type == 'cuda': - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' or HIP_ENVIRONMENT: + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == 
torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + if HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") torch._check( dtype in [torch.float16, torch.bfloat16, torch.float32], lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", ) - - if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - 
lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - - device = A.device - device_type = device.type - - if device_type == 'cuda': - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' or HIP_ENVIRONMENT: + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + if HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( @@ -331,66 +318,65 @@ def _( blocks = -(n // -blocksize) absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - 
if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + 
+@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + def _dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, @@ -399,157 +385,154 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - device = A.device - device_type = device.type - - if device_type == 'cuda': - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - elif device_type == 'hip' or HIP_ENVIRONMENT: + if HIP_ENVIRONMENT: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: 
torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m - if device_type == 'cuda' or (device_type == 'hip' or HIP_ENVIRONMENT): - with _cuda_device_of(A): - args = ( - None, + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if 
A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, get_ptr(A), + get_ptr(B), get_ptr(absmax), + get_ptr(code), get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, 
lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) From 9fe67efada457a759d1d8193265243209e784e2c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA 
<118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 23 May 2025 02:11:31 +0530 Subject: [PATCH 13/98] Update functional.py --- bitsandbytes/functional.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 3f0c1ff94..237aa3e54 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -754,13 +754,11 @@ def quantize_blockwise( if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - - device_type = A.device.type - if device_type in ["cuda", "hip"]: - if not HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] + + if HIP_ENVIRONMENT: + assert blocksize in [4096, 2048, 1024, 512, 256, 128] + else: + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( A, @@ -839,15 +837,15 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - device_type = A.device.type - if device_type in ["cuda", "hip"]: + if HIP_ENVIRONMENT: + supported_blocksizes = [4096, 2048, 1024, 512, 256, 128] + else: supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] - if HIP_ENVIRONMENT: - supported_blocksizes = supported_blocksizes[:-1] - if quant_state.blocksize not in supported_blocksizes: - raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}" - ) + + if quant_state.blocksize not in supported_blocksizes: + raise ValueError( + f"The blocksize of {quant_state.blocksize} is not supported. 
Supported values: {supported_blocksizes}" + ) absmax = quant_state.absmax if quant_state.nested: From d97fdce654129ca156f0cb47555529d4f4941778 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 23 May 2025 02:18:37 +0530 Subject: [PATCH 14/98] Update functional.py --- bitsandbytes/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 237aa3e54..1cee234ea 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -858,8 +858,8 @@ def dequantize_blockwise( torch.ops.bitsandbytes.dequantize_blockwise.out( A, absmax, - quant_state.code.to(A.device), - quant_state.blocksize, + code.to(A.device), + blocksize, quant_state.dtype, out=out, ) From f1fbe92d2bc2eebc4629ee41a76b163772cd1874 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Sat, 24 May 2025 21:53:44 +0530 Subject: [PATCH 15/98] Update functional.py --- bitsandbytes/functional.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 1cee234ea..b51258420 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -960,12 +960,12 @@ def quantize_fp4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=None, + blocksize=64, compress_statistics=False, quant_storage=torch.uint8, ): - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -973,12 +973,12 @@ def quantize_nf4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=None, + blocksize=64, compress_statistics=False, quant_storage=torch.uint8, ): - if blocksize is None: - 
blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -986,7 +986,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=None, + blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, @@ -1014,8 +1014,8 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 input_shape = A.shape @@ -1067,10 +1067,10 @@ def dequantize_fp4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: Optional[int] = None, + blocksize: int = 64, ) -> torch.Tensor: - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1079,10 +1079,10 @@ def dequantize_nf4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: Optional[int] = None, + blocksize: int = 64, ) -> torch.Tensor: - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1091,7 +1091,7 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: Optional[int] = None, + blocksize: int = 64, quant_type="fp4", ) -> torch.Tensor: """Dequantizes a packed 4-bit quantized tensor. @@ -1121,8 +1121,8 @@ def dequantize_4bit( `torch.Tensor`: The dequantized tensor. 
""" - if blocksize is None: - blocksize = 64 if not HIP_ENVIRONMENT else 128 + if HIP_ENVIRONMENT: + blocksize = 128 if quant_state is None: assert absmax is not None and out is not None From 660c25448edcff9f0f56368cc9ef04e91045d52c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Sat, 24 May 2025 21:57:22 +0530 Subject: [PATCH 16/98] Update functional.py --- bitsandbytes/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b51258420..2ae977e7a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -986,7 +986,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, From c692f4bc8f604f50a8a4f4409d373ed70c630364 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 27 May 2025 21:45:04 +0530 Subject: [PATCH 17/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 48dc75135..14878123a 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -3,7 +3,7 @@ from math import prod from typing import Optional -import torch +import torch from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr From 46f9800d9e9a361ecabf1051f99776fbfc73589d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 27 May 2025 21:55:36 +0530 Subject: [PATCH 18/98] Update ops.py From 7823bac2c0c234c468392c219b29ed51dea8ca96 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:12:42 +0530 Subject: [PATCH 
19/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1059 ++++++++++++++--------------- 1 file changed, 521 insertions(+), 538 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 14878123a..efdef2871 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,538 +1,521 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) 
- ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. 
- out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
- if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - 
args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - 
_get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) 
- return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - 
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - 
lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] 
# Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. 
+ # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
+ if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + 
ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + 
elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + 
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == 
(*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + 
ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From d0ed1077d910acc4cd6f3ec4c57cf597931ff20c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:14:34 +0530 Subject: [PATCH 20/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1059 +++++++++++++++-------------- 1 file changed, 538 insertions(+), 521 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index efdef2871..14878123a 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,521 +1,538 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == 
shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? 
- raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = 
torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. - if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 
0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def 
_dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - 
lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - 
lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") 
- - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib, HIP_ENVIRONMENT + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim 
in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? 
+ raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = 
torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 
0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + 
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == 
torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + 
ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, 
torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From af3aaf6a5d5ee90d713fbba875ab3cbd5137c619 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:17:20 +0530 Subject: [PATCH 21/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 14878123a..aa7c82f09 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -536,3 +536,4 @@ def _gemv_4bit_impl( ct.c_int32(blocksize), stream, ) + From d1e34a5dfe80aa95c42de7187800468d7a9e1b8a Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:18:53 +0530 Subject: [PATCH 22/98] Update ops.py 
--- bitsandbytes/backends/cuda/ops.py | 1058 ++++++++++++++--------------- 1 file changed, 520 insertions(+), 538 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index aa7c82f09..efdef2871 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,539 +1,521 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # 
Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. 
- out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
- if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - 
args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - 
_get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) 
- return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - 
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - 
lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # 
Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. 
+ # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
+ if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + 
ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + 
elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + 
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == 
(*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + 
ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From b2b4df6d3046a166d6e177de2dbca26f1b0abcab Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:21:15 +0530 Subject: [PATCH 23/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1059 +++++++++++++++-------------- 1 file changed, 538 insertions(+), 521 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index efdef2871..14878123a 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,521 +1,538 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr - -from ..._ops import register_kernel -from ...cextension import lib - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == 
shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? 
- raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = 
torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. - if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 
0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def 
_dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - 
lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - 
lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") 
- - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib, HIP_ENVIRONMENT + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim 
in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? 
+ raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = 
torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. + if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 
0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + 
_dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == 
torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + 
ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, 
torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From 8863d0e3d55c73478926c9388080750be2e49690 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:22:01 +0530 Subject: [PATCH 24/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 1059 ++++++++++++++--------------- 1 file changed, 521 insertions(+), 538 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 14878123a..efdef2871 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -1,538 +1,521 @@ -from collections.abc import Sequence -import ctypes as ct -from math import prod -from typing import Optional - -import torch - -from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, 
get_ptr - -from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT - - -@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") -def _(A: torch.Tensor, B: torch.Tensor): - out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) - return _int8_linear_matmul_impl(A, B, out) - - -@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") -def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - _int8_linear_matmul_impl(A, B, out) - - -def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - torch._check(A.dtype == torch.int8, lambda: "B must be int8") - torch._check(B.dtype == torch.int8, lambda: "A must be int8") - torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") - torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") - torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") - torch._check(out.dtype == torch.int32) - - shapeC = (*shapeB[:-1], shapeA[0]) - torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - torch._check( - lda == ldb, - lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", - ) - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. 
- if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - return out.copy_(result) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error: - if has_error == 100: - # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - # TODO: Warn and implement a fallback to fp32 compute? - raise NotImplementedError("int8_linear_matmul not implemented!") - else: - raise RuntimeError( - f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" - ) - - return out - - -@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") -def _( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - dtype: Optional[torch.dtype] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") - torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") - torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") - - # Note: cuda kernel only currently supports fp16 output. - # We'll later cast to desired dtype if needed. 
- out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - # Note: fused bias in the kernel is only supported for fp16 - # TODO(matthewdouglas): Consider supporting bf16 fused bias - ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) - - # Add bias separately if not fused in kernel - if bias is not None and bias.dtype != torch.float16: - out.add_(bias) - - return out.to(dtype or torch.float16) - - -@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") -def _(A: torch.Tensor, threshold=0.0): - torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") - torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - else: - # Needed for torch.compile support. - outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
- if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@register_kernel("bitsandbytes::int8_double_quant", "cuda") -def _( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( - A, - threshold=threshold, - ) - - # PyTorch impl for colwise - col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols - - -def _get_col_absmax( - A: torch.Tensor, - threshold=0.0, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - torch._check(A.is_floating_point()) - - outlier_mask = None - - absA = A.abs().view(-1, A.shape[-1]) - - if threshold > 0.0: - # Filter outliers from stats when enabled - outlier_mask = absA >= threshold - absA.masked_fill_(outlier_mask, 0.0) - - # shape [cols]; unsqueeze(0) gives [1,cols] - col_stats = absA.amax(dim=0, keepdim=False).float() - - return col_stats, outlier_mask - - -@register_kernel("bitsandbytes::quantize_blockwise", "cuda") -def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty_like(A, dtype=torch.uint8) - - with _cuda_device_of(A): - 
args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") -def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = torch.empty_like(A, dtype=dtype) - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - return out - - -@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") - _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) - - -def _dequantize_blockwise_impl( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") - torch._check( - dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - 
_get_tensor_stream(A), - ) - - if dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - - -@register_kernel("bitsandbytes::quantize_4bit", "cuda") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) - out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - - return out, absmax - - -@register_kernel("bitsandbytes::dequantize_4bit", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) 
- return out - - -@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") - torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") - _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) - - -def _dequantize_4bit_impl( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - dtype: torch.dtype, - out: torch.Tensor, -) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - - torch._check(quant_type in ["fp4", "nf4"]) - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(out.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - - -@register_kernel("bitsandbytes::gemv_4bit", "cuda") -def _( - A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int -) -> torch.Tensor: - shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(shape, device=A.device, dtype=A.dtype) - 
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - return out - - -@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check( - out.shape == (*A.shape[:-1], shapeB[0]), - lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", - ) - torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") - _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) - - -def _gemv_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - out: torch.Tensor, -) -> None: - torch._check_is_size(blocksize) - torch._check( - A.numel() == A.size(-1), - lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", - ) - torch._check( - A.dtype in [torch.float16, torch.bfloat16, torch.float32], - lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", - ) - torch._check( - B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], - lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", - ) - torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") - torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") - - m = ct.c_int32(shapeB[0]) - n = ct.c_int32(1) - k = ct.c_int32(shapeB[1]) - - lda = m - ldb = ct.c_int32((A.shape[-1] + 1) // 2) - ldc = m - - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - 
lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(blocksize), - stream, - ) +from collections.abc import Sequence +import ctypes as ct +from math import prod +from typing import Optional + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] 
# Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? + raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n\t{shapeA=}, {shapeB=}, {shapeC=}\n\t{(lda, ldb, ldc)=}\n\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. 
+ # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype or torch.float16) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + else: + # Needed for torch.compile support. + outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
+ if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + 
ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + 
elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + 
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == 
(*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + 
ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) From d1a5e8dec4e212e5c722d884809d5645c4772a1b Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:35:33 +0530 Subject: [PATCH 25/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index efdef2871..fd7b7b9a2 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -8,7 +8,7 @@ from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import lib +from ...cextension import lib, HIP_ENVIRONMENT @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -210,7 +210,12 @@ def _get_col_absmax( @register_kernel("bitsandbytes::quantize_blockwise", "cuda") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") n = A.numel() @@ -264,7 +269,11 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") torch._check( dtype in [torch.float16, torch.bfloat16, 
torch.float32], @@ -294,7 +303,11 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ -372,7 +385,11 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], From 843ea338f968e06d586ac70c68e70b3a2c56c228 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:39:54 +0530 Subject: [PATCH 26/98] Update functional.py --- bitsandbytes/functional.py | 316 +++++++++++++++++-------------------- 1 file changed, 147 insertions(+), 169 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2ae977e7a..b0092ffd1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib, HIP_ENVIRONMENT +from .cextension import lib name2qmap = {} @@ -719,159 +719,152 @@ def __eq__(self, other): ) -def quantize_blockwise( - A: torch.Tensor, - code: Optional[torch.Tensor] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=4096, - nested=False, -) -> tuple[torch.Tensor, QuantState]: - """Quantize a tensor in blocks 
of values. - The input tensor is quantized by dividing it into blocks of `blocksize` values. - The the absolute maximum value within these blocks is calculated for scaling - the non-linear quantization. - Args: - A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. - code (`torch.Tensor`, *optional*): - A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. - For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. - absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. - out (`torch.Tensor`, *optional*): A tensor to use to store the result. - blocksize (`int`, *optional*): - The size of the blocks. Defaults to 4096. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. - Raises: - ValueError: Raised when the input data type is not supported. - Returns: - `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. - - `torch.Tensor`: The quantized tensor. - - [`QuantState`]: The state object used to undo the quantization. 
- """ - - if code is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - if HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] - else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - - _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( - A, - code.to(A.device), - blocksize, - ) - - if nested: - offset = _absmax.mean() - _absmax -= offset - qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False) - quant_state = QuantState( - absmax=qabsmax, - code=code, - blocksize=blocksize, - dtype=A.dtype, - offset=offset, - state2=state2, - ) - else: - quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype) - - # TODO(matthewdouglas): Deprecate out kwarg - out = out.copy_(_out) if out is not None else _out - - # TODO(matthewdouglas): Deprecate absmax kwarg - if absmax is not None: - quant_state.absmax = absmax.copy_(quant_state.absmax) - - return out, quant_state - - -def dequantize_blockwise( - A: torch.Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - code: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 4096, - nested=False, -) -> torch.Tensor: - """Dequantize a tensor in blocks of values. - The input tensor is dequantized by dividing it into blocks of `blocksize` values. - The the absolute maximum value within these blocks is used for scaling - the non-linear dequantization. - Args: - A (`torch.Tensor`): The quantized input tensor. - quant_state ([`QuantState`], *optional*): - The quantization state as returned by [`quantize_blockwise`]. - Required if `absmax` is not provided. - absmax (`torch.Tensor`, *optional*): - A tensor containing the scaling values. - Required if `quant_state` is not provided and ignored otherwise. 
- code (`torch.Tensor`, *optional*): - A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. - For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. - Ignored when `quant_state` is provided. - out (`torch.Tensor`, *optional*): A tensor to use to store the result. - blocksize (`int`, *optional*): - The size of the blocks. Defaults to 4096. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. - Ignored when `quant_state` is provided. - Raises: - ValueError: Raised when the input data type is not supported. - Returns: - `torch.Tensor`: - The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. - """ - - assert quant_state is not None or absmax is not None - if code is None and quant_state is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - - if quant_state is None: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - - if HIP_ENVIRONMENT: - supported_blocksizes = [4096, 2048, 1024, 512, 256, 128] - else: - supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] - - if quant_state.blocksize not in supported_blocksizes: - raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. 
Supported values: {supported_blocksizes}" - ) - - absmax = quant_state.absmax - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - - if out is not None: - torch.ops.bitsandbytes.dequantize_blockwise.out( - A, - absmax, - code.to(A.device), - blocksize, - quant_state.dtype, - out=out, - ) - return out - - return torch.ops.bitsandbytes.dequantize_blockwise.default( - A, - absmax, - quant_state.code.to(A.device), - quant_state.blocksize, - quant_state.dtype, - ) +def quantize_blockwise( + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, +) -> tuple[torch.Tensor, QuantState]: + """Quantize a tensor in blocks of values. + + The input tensor is quantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is calculated for scaling + the non-linear quantization. + + Args: + A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results. 
+ - `torch.Tensor`: The quantized tensor. + - [`QuantState`]: The state object used to undo the quantization. + """ + + if code is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( + A, + code.to(A.device), + blocksize, + ) + + if nested: + offset = _absmax.mean() + _absmax -= offset + qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False) + quant_state = QuantState( + absmax=qabsmax, + code=code, + blocksize=blocksize, + dtype=A.dtype, + offset=offset, + state2=state2, + ) + else: + quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype) + + # TODO(matthewdouglas): Deprecate out kwarg + out = out.copy_(_out) if out is not None else _out + + # TODO(matthewdouglas): Deprecate absmax kwarg + if absmax is not None: + quant_state.absmax = absmax.copy_(quant_state.absmax) + + return out, quant_state + + +def dequantize_blockwise( + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, + nested=False, +) -> torch.Tensor: + """Dequantize a tensor in blocks of values. + + The input tensor is dequantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is used for scaling + the non-linear dequantization. + + Args: + A (`torch.Tensor`): The quantized input tensor. + quant_state ([`QuantState`], *optional*): + The quantization state as returned by [`quantize_blockwise`]. + Required if `absmax` is not provided. + absmax (`torch.Tensor`, *optional*): + A tensor containing the scaling values. + Required if `quant_state` is not provided and ignored otherwise. + code (`torch.Tensor`, *optional*): + A mapping describing the low-bit data type. 
Defaults to a signed 8-bit dynamic type. + For more details, see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]. + Ignored when `quant_state` is provided. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 4096. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + Ignored when `quant_state` is provided. + + Raises: + ValueError: Raised when the input data type is not supported. + + Returns: + `torch.Tensor`: + The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`. + """ + + assert quant_state is not None or absmax is not None + if code is None and quant_state is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + if quant_state is None: + quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + + absmax = quant_state.absmax + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + if out is not None: + torch.ops.bitsandbytes.dequantize_blockwise.out( + A, + absmax, + code.to(A.device), + blocksize, + quant_state.dtype, + out=out, + ) + return out + + return torch.ops.bitsandbytes.dequantize_blockwise.default( + A, + absmax, + quant_state.code.to(A.device), + quant_state.blocksize, + quant_state.dtype, + ) def get_4bit_type(typename, device=None, blocksize=64): @@ -964,8 +957,6 @@ def quantize_fp4( compress_statistics=False, quant_storage=torch.uint8, ): - if HIP_ENVIRONMENT: - blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -977,8 +968,6 @@ def quantize_nf4( compress_statistics=False, quant_storage=torch.uint8, ): - if HIP_ENVIRONMENT: - blocksize = 128 return 
quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -1014,9 +1003,6 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ - if HIP_ENVIRONMENT: - blocksize = 128 - input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1069,8 +1055,6 @@ def dequantize_fp4( out: Optional[torch.Tensor] = None, blocksize: int = 64, ) -> torch.Tensor: - if HIP_ENVIRONMENT: - blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1081,8 +1065,6 @@ def dequantize_nf4( out: Optional[torch.Tensor] = None, blocksize: int = 64, ) -> torch.Tensor: - if HIP_ENVIRONMENT: - blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1120,10 +1102,6 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. """ - - if HIP_ENVIRONMENT: - blocksize = 128 - if quant_state is None: assert absmax is not None and out is not None From d6d2e5f32ffd30070c45f89704b8db20f600b577 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 12:57:37 +0530 Subject: [PATCH 27/98] Update functional.py --- bitsandbytes/functional.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b0092ffd1..959eeb33a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib +from .cextension import lib, HIP_ENVIRONMENT name2qmap = {} @@ -758,6 +758,11 @@ def quantize_blockwise( if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] + + if HIP_ENVIRONMENT: + assert blocksize in [4096, 2048, 1024, 512, 256, 128] + else: + 
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( A, @@ -839,6 +844,16 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) + + if HIP_ENVIRONMENT: + supported_blocksizes = [4096, 2048, 1024, 512, 256, 128] + else: + supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] + + if quant_state.blocksize not in supported_blocksizes: + raise ValueError( + f"The blocksize of {quant_state.blocksize} is not supported. Supported values: {supported_blocksizes}" + ) absmax = quant_state.absmax if quant_state.nested: @@ -957,6 +972,8 @@ def quantize_fp4( compress_statistics=False, quant_storage=torch.uint8, ): + if HIP_ENVIRONMENT: + blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -968,6 +985,8 @@ def quantize_nf4( compress_statistics=False, quant_storage=torch.uint8, ): + if HIP_ENVIRONMENT: + blocksize = 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -1003,6 +1022,9 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ + if HIP_ENVIRONMENT: + blocksize = 128 + input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1055,6 +1077,8 @@ def dequantize_fp4( out: Optional[torch.Tensor] = None, blocksize: int = 64, ) -> torch.Tensor: + if HIP_ENVIRONMENT: + blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1065,6 +1089,8 @@ def dequantize_nf4( out: Optional[torch.Tensor] = None, blocksize: int = 64, ) -> torch.Tensor: + if HIP_ENVIRONMENT: + blocksize = 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1102,6 +1128,10 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. 
""" + + if HIP_ENVIRONMENT: + blocksize = 128 + if quant_state is None: assert absmax is not None and out is not None From e3f9f21236ac76cac026eacf1da26f15e7a0ad1f Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 13:23:18 +0530 Subject: [PATCH 28/98] Update functional.py --- bitsandbytes/functional.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 959eeb33a..f4be0dc2f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1022,6 +1022,7 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ + if HIP_ENVIRONMENT: blocksize = 128 From bc0957daa57fc1364f914c2928bcfb730f97dc9d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 17:26:33 +0530 Subject: [PATCH 29/98] Update test_ops.py --- tests/test_ops.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index 4da1663f0..bb49c7dbb 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -5,6 +5,7 @@ import bitsandbytes from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter +from bitsandbytes.cextension import HIP_ENVIRONMENT class TestLLMInt8Ops: @@ -95,7 +96,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): if device == "cpu": if dtype != torch.float32: @@ -119,7 +120,7 @@ def 
test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -145,7 +146,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -169,7 +170,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": if quant_type != "nf4": @@ -206,7 +207,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, 
blocksi @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": pytest.xfail("CPU implementation is not available") From b8247ab109de936bcefb932b7d0ed996168f8445 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 17:34:22 +0530 Subject: [PATCH 30/98] Update test_functional.py --- tests/test_functional.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 96e77e4f4..3b9b53a24 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -8,6 +8,7 @@ import torch import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes import functional as F from tests.helpers import ( BOOLEAN_TUPLES, @@ -91,7 +92,7 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) - @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) + @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128] if HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128, 64] ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 @@ -147,7 +148,7 @@ def 
test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") @pytest.mark.parametrize("hidden", [128]) - @pytest.mark.parametrize("blocksize", [4096, 16384]) + @pytest.mark.parametrize("blocksize", [4096] if HIP_ENVIRONMENT else [4096, 16384]) def test_blockwise_cpu_large(self, hidden, blocksize): diffs = [] reldiffs = [] @@ -1105,7 +1106,7 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize("blocksize", [128, 256, 512, 1024, 2048, 4096] if HIP_ENVIRONMENT else [64, 128, 256, 512, 1024, 2048, 4096]) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From 531758a10835e68a10002eb825383a1a0608cb65 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 20:19:07 +0530 Subject: [PATCH 31/98] Update test_ops.py --- tests/test_ops.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index bb49c7dbb..a99d080b3 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -96,7 +96,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) 
def test_quantize_blockwise(self, device, dtype, blocksize): if device == "cpu": if dtype != torch.float32: @@ -120,7 +120,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -146,7 +146,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -170,7 +170,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not 
HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": if quant_type != "nf4": @@ -207,7 +207,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [128, 256, 512] if HIP_ENVIRONMENT else [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "cpu": pytest.xfail("CPU implementation is not available") From 6d7db8efa3a2d249434378ab09f3e9f5c0d72c26 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 20:29:23 +0530 Subject: [PATCH 32/98] Update test_functional.py --- tests/test_functional.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 3b9b53a24..4b62c2567 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -92,7 +92,7 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) - @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128] if HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128, 64] ) + @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128] ) @pytest.mark.parametrize("signed", 
TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 @@ -148,7 +148,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") @pytest.mark.parametrize("hidden", [128]) - @pytest.mark.parametrize("blocksize", [4096] if HIP_ENVIRONMENT else [4096, 16384]) + @pytest.mark.parametrize("blocksize", [4096, 16384] if not HIP_ENVIRONMENT else [4096]) def test_blockwise_cpu_large(self, hidden, blocksize): diffs = [] reldiffs = [] @@ -1106,7 +1106,7 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [128, 256, 512, 1024, 2048, 4096] if HIP_ENVIRONMENT else [64, 128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096]) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1205,7 +1205,10 @@ def test_bench_4bit_dequant(self, quant_type): # torch.matmul(b, a.t()) # torch.cuda.synchronize() # print((time.time()-t0)/iters*1e6) - + + @pytest.mark.skipif( + HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" + ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @@ -1369,6 +1372,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double 
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) + @pytest.mark.skipif( + HIP_ENVIRONMENT, reason="this test is not supported on ROCm with gfx90a architecture yet", + ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): if device == "cpu" and storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From 632e95b92d9feba37401ede69ad119017b50ae9d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 21:05:21 +0530 Subject: [PATCH 33/98] Update test_functional.py --- tests/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 4b62c2567..7ad604d9f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1373,7 +1373,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @pytest.mark.skipif( - HIP_ENVIRONMENT, reason="this test is not supported on ROCm with gfx90a architecture yet", + HIP_ENVIRONMENT, reason="this test is not supported on ROCm with gfx90a architecture yet" ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): if device == "cpu" and storage_type != "nf4": From 90d9af2c387f05bcf4dc8d409a0ac3e4ef0d8e95 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 28 May 2025 22:04:55 +0530 Subject: [PATCH 34/98] Update functional.py --- bitsandbytes/functional.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/functional.py 
b/bitsandbytes/functional.py index f4be0dc2f..2405a1985 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -968,12 +968,12 @@ def quantize_fp4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -981,12 +981,12 @@ def quantize_nf4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -994,7 +994,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, @@ -1023,8 +1023,8 @@ def quantize_4bit( - [`QuantState`]: The state object used to undo the quantization. 
""" - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 input_shape = A.shape @@ -1076,10 +1076,10 @@ def dequantize_fp4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -1088,10 +1088,10 @@ def dequantize_nf4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -1100,7 +1100,7 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, quant_type="fp4", ) -> torch.Tensor: """Dequantizes a packed 4-bit quantized tensor. @@ -1130,8 +1130,8 @@ def dequantize_4bit( `torch.Tensor`: The dequantized tensor. 
""" - if HIP_ENVIRONMENT: - blocksize = 128 + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 if quant_state is None: assert absmax is not None and out is not None From 80048d89f249509db4c1fb482ce7694fcca3fdcb Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 01:38:52 +0530 Subject: [PATCH 35/98] Update functional.py --- bitsandbytes/functional.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2405a1985..03f6c323d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -758,11 +758,6 @@ def quantize_blockwise( if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - - if HIP_ENVIRONMENT: - assert blocksize in [4096, 2048, 1024, 512, 256, 128] - else: - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default( A, @@ -844,16 +839,6 @@ def dequantize_blockwise( if quant_state is None: quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32) - - if HIP_ENVIRONMENT: - supported_blocksizes = [4096, 2048, 1024, 512, 256, 128] - else: - supported_blocksizes = [4096, 2048, 1024, 512, 256, 128, 64] - - if quant_state.blocksize not in supported_blocksizes: - raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. 
Supported values: {supported_blocksizes}" - ) absmax = quant_state.absmax if quant_state.nested: From e448ebbadf0313f429005001791c56d092992f01 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 02:40:56 +0530 Subject: [PATCH 36/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index fd7b7b9a2..f03d06599 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -303,11 +303,7 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ -385,11 +381,7 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], From 048faa8ce60088fedc05474157c6356b14c3ee80 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 02:41:52 +0530 Subject: [PATCH 37/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index f03d06599..29dddc96e 100644 
--- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -381,7 +381,7 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], From c45e9d18c9fa55135cdaea92b68a4e8660d80bf6 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 02:44:51 +0530 Subject: [PATCH 38/98] Update test_functional.py --- tests/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 7ad604d9f..07c0d4964 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -148,7 +148,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") @pytest.mark.parametrize("hidden", [128]) - @pytest.mark.parametrize("blocksize", [4096, 16384] if not HIP_ENVIRONMENT else [4096]) + @pytest.mark.parametrize("blocksize", [4096, 16384]) def test_blockwise_cpu_large(self, hidden, blocksize): diffs = [] reldiffs = [] From 47a491fb213b5286e0ed3cc9af773bf02f416f24 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 03:36:25 +0530 Subject: [PATCH 39/98] Update test_functional.py --- tests/test_functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 07c0d4964..2219efa2f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -8,7 +8,7 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH 
from bitsandbytes import functional as F from tests.helpers import ( BOOLEAN_TUPLES, @@ -1373,7 +1373,8 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @pytest.mark.skipif( - HIP_ENVIRONMENT, reason="this test is not supported on ROCm with gfx90a architecture yet" + HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", + reason="this test is not supported on ROCm with gfx90a architecture yet", ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): if device == "cpu" and storage_type != "nf4": From 86976bc22b04bc1415a13648582e453ce594700c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 03:38:53 +0530 Subject: [PATCH 40/98] Update cextension.py --- bitsandbytes/cextension.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index c8b02fb22..108aa0c9a 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -8,7 +8,7 @@ import torch from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR -from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple +from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch logger = logging.getLogger(__name__) @@ -298,6 +298,8 @@ def get_native_library() -> BNBNativeLibrary: return BNBNativeLibrary(dll) +ROCM_GPU_ARCH = get_rocm_gpu_arch() + try: if torch.version.hip: HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" From 98a142a7c7961fc58c0b90b388f080d56991b94c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 03:41:51 +0530 Subject: [PATCH 41/98] Update cuda_specs.py --- bitsandbytes/cuda_specs.py | 29 
++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 64903cd49..da34dd608 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,6 +1,9 @@ import dataclasses +import logging +import re +import subprocess from functools import lru_cache -from typing import Optional +from typing import Optional, List, Tuple import torch @@ -73,3 +76,27 @@ def get_cuda_specs() -> Optional[CUDASpecs]: ) except Exception: return None + + +def get_rocm_gpu_arch() -> str: + """Get ROCm GPU architecture.""" + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + if match: + return "gfx" + match.group(1) + else: + return "unknown" + else: + return "unknown" + except Exception as e: + logger.error(f"Could not detect ROCm GPU architecture: {e}") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm GPU architecture detection failed despite ROCm being available. 
+ """, + ) + return "unknown" From 888fe46fee6fe59f377e4c4a3f19468a06094b91 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 03:59:01 +0530 Subject: [PATCH 42/98] Update cuda_specs.py --- bitsandbytes/cuda_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index da34dd608..61d03083c 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -3,7 +3,7 @@ import re import subprocess from functools import lru_cache -from typing import Optional, List, Tuple +from typing import Optional import torch From c9c52b56c1145d9ecd6ccfc4833799eae3bb2ccd Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Thu, 29 May 2025 15:59:13 +0530 Subject: [PATCH 43/98] Update test_functional.py --- tests/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 2219efa2f..41ed7c984 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1141,7 +1141,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize")) + @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From fc29586e8951cbe41aa5693ba0cd3ae3d25b05db Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:23:38 +0530 Subject: [PATCH 44/98] Update test_linear4bit.py --- tests/test_linear4bit.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 67b61cb05..474a00a1b 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -7,6 +7,7 @@ import torch import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer storage = { @@ -16,7 +17,7 @@ "float32": torch.float32, } - +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) From 53b8b1c580093e39d43d0018fa47abee6966442c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:27:39 +0530 Subject: [PATCH 45/98] Update test_cuda_setup_evaluator.py --- tests/test_cuda_setup_evaluator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 79406472e..1b2ea85db 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,6 +1,6 @@ import pytest -from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path from bitsandbytes.cuda_specs import CUDASpecs @@ -12,12 +12,12 @@ def cuda120_spec() -> CUDASpecs: cuda_version_tuple=(12, 0), ) - +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" - +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def 
test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): monkeypatch.setenv("BNB_CUDA_VERSION", "110") assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" From fe1fe7ccd0ab1c2a41da85d865e467de691cefac Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:34:11 +0530 Subject: [PATCH 46/98] Update test_functional.py --- tests/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 41ed7c984..5f5ee488c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -796,7 +796,7 @@ def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): A[:, outlier_cols] = 0 torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) - +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") class TestSpMMFunctional: @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) From e198824c5c9e23bb15d6eb2aa07a04f09e95446f Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:36:53 +0530 Subject: [PATCH 47/98] Update modules.py --- bitsandbytes/nn/modules.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 937084cf1..6b6490265 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -212,7 +212,7 @@ def __new__( data: Optional[torch.Tensor] = None, requires_grad=False, # quantized weights should be frozen by default quant_state: Optional[QuantState] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, compress_statistics: bool = True, quant_type: str = "fp4", quant_storage: torch.dtype = torch.uint8, @@ -221,7 +221,10 @@ def __new__( ) -> "Params4bit": if data is 
None: data = torch.empty(0) - + + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize self.compress_statistics = compress_statistics From dd58310df17b69c63a9a06186e7f6bb24c98a199 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:37:28 +0530 Subject: [PATCH 48/98] Update modules.py --- bitsandbytes/nn/modules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 6b6490265..2383f2c10 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -11,6 +11,7 @@ import torch.nn.functional as F import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import ( From 931bd70d868df8a663d32c3d4b410f72a45c1c3b Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 17:50:14 +0530 Subject: [PATCH 49/98] Update ops.py --- bitsandbytes/backends/cuda/ops.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index 29dddc96e..fd7b7b9a2 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -303,7 +303,11 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ 
-381,7 +385,11 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], From 9e62d466d226a62bd61e73afd676a694e1d13eac Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 30 May 2025 18:56:05 +0530 Subject: [PATCH 50/98] Update test_linear4bit.py --- tests/test_linear4bit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 474a00a1b..c241a265d 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -184,7 +184,7 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): if device == "cpu": @@ -209,7 +209,7 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): if device == "cpu": 
@@ -241,7 +241,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): if device == "cpu": From 1f71562a9ba57dd209f844549ffc8ff98bebb06d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 19:05:12 +0530 Subject: [PATCH 51/98] Update ops.py --- bitsandbytes/backends/cpu/ops.py | 93 ++++++++++++++++++++++++++------ 1 file changed, 76 insertions(+), 17 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index d5ab9aa88..f58be5d2a 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -103,16 +103,39 @@ def _( n = A.numel() - # TODO: Support when weight matrix is not divisible by blocksize - torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}") - - # Divide into blocks and normalize - blocks = A.reshape(-1, blocksize) - absmax = blocks.abs().max(dim=1).values.float() - scaled = blocks / absmax.unsqueeze(-1) - - # Quantize with the lookup table - quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) + blocks = n // blocksize + rem = n % blocksize + has_rem = rem > 0 + if has_rem: + blocks += 1 + + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + A_reshaped = A.reshape(n) + + if n >= blocksize: + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[:blocks - has_rem] = 
torch.abs(A_com_reshaped).max(dim=1).values.float() + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].unsqueeze(-1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max().float() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + + # Quantize with the lookup table + quantized = torch.argmin(torch.abs(scaled_A.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) + else: + blocks = A.reshape(-1, blocksize) + absmax = blocks.abs().max(dim=1).values.float() + scaled_A = blocks / absmax.unsqueeze(-1) + + # Quantize with the lookup table + quantized = torch.argmin(torch.abs(scaled_A.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) + + if quantized.numel() % 2 == 1: + quantized = torch.cat([quantized, torch.zeros((1, 1), device=A.device, dtype=torch.uint8)]) # Pack two quantized values per byte packed = quantized[::2] << 4 | quantized[1::2] @@ -149,16 +172,52 @@ def _( upper = (A >> 4).to(torch.int64) lower = (A & 0x0F).to(torch.int64) - # Expand to blocks - blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) + # Calculate the total number of elements in the original tensor + n = 1 + for d in shape: + n *= d + + # Concatenate upper and lower nibbles + indices = torch.cat((upper, lower), dim=1).reshape(-1) + + if indices.numel() > n: + indices = indices[:n] + + blocks = n // blocksize + rem = n % blocksize + has_rem = rem > 0 + if has_rem: + blocks += 1 + + if has_rem: + out = torch.empty(shape, dtype=dtype, device=A.device) + out_reshaped = out.reshape(-1) + + padded_indices = torch.zeros(blocks * blocksize, dtype=indices.dtype, device=indices.device) + padded_indices[:n] = indices + blocks_data = padded_indices.reshape(-1, blocksize) + + # Dequantize full blocks + dequantized = _NF4_QUANT_TABLE[blocks_data] + + # Apply scales to full blocks + 
out_reshaped[:n - rem] = ( + dequantized[:blocks - 1].reshape(-1, blocksize) * absmax[:blocks - 1].view(-1, 1) + ).reshape(-1) + + # Apply scale to remainder block + out_reshaped[n - rem:] = dequantized[blocks - 1, :rem] * absmax[-1] + else: + # Expand to blocks + blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) - # Dequantize - blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] + # Dequantize + blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] - # Reshape to original shape - blocks = blocks.reshape(-1, *shape[1:]) + # Reshape to original shape + out = blocks.reshape(-1, *shape[1:]) - return blocks.to(dtype) + return out.to(dtype) @register_kernel("bitsandbytes::gemv_4bit", "cpu") From eac7632e28043caad307cf2b5e1ff61fc9cbfe12 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 20:46:28 +0530 Subject: [PATCH 52/98] Update ops.py --- bitsandbytes/backends/cpu/ops.py | 93 ++++++-------------------------- 1 file changed, 17 insertions(+), 76 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index f58be5d2a..d5ab9aa88 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -103,39 +103,16 @@ def _( n = A.numel() - blocks = n // blocksize - rem = n % blocksize - has_rem = rem > 0 - if has_rem: - blocks += 1 - - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - A_reshaped = A.reshape(n) - - if n >= blocksize: - A_com = A_reshaped[: n - rem] - A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[:blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=1).values.float() - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[:blocks - has_rem].unsqueeze(-1)), -1, 1) - scaled_A = scaled_A.reshape(-1) - - if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem :]).max().float() - scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) - scaled_A = torch.cat([scaled_A, 
scaled_A_rem], dim=0) - - # Quantize with the lookup table - quantized = torch.argmin(torch.abs(scaled_A.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) - else: - blocks = A.reshape(-1, blocksize) - absmax = blocks.abs().max(dim=1).values.float() - scaled_A = blocks / absmax.unsqueeze(-1) - - # Quantize with the lookup table - quantized = torch.argmin(torch.abs(scaled_A.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) - - if quantized.numel() % 2 == 1: - quantized = torch.cat([quantized, torch.zeros((1, 1), device=A.device, dtype=torch.uint8)]) + # TODO: Support when weight matrix is not divisible by blocksize + torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}") + + # Divide into blocks and normalize + blocks = A.reshape(-1, blocksize) + absmax = blocks.abs().max(dim=1).values.float() + scaled = blocks / absmax.unsqueeze(-1) + + # Quantize with the lookup table + quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) # Pack two quantized values per byte packed = quantized[::2] << 4 | quantized[1::2] @@ -172,52 +149,16 @@ def _( upper = (A >> 4).to(torch.int64) lower = (A & 0x0F).to(torch.int64) - # Calculate the total number of elements in the original tensor - n = 1 - for d in shape: - n *= d - - # Concatenate upper and lower nibbles - indices = torch.cat((upper, lower), dim=1).reshape(-1) - - if indices.numel() > n: - indices = indices[:n] - - blocks = n // blocksize - rem = n % blocksize - has_rem = rem > 0 - if has_rem: - blocks += 1 - - if has_rem: - out = torch.empty(shape, dtype=dtype, device=A.device) - out_reshaped = out.reshape(-1) - - padded_indices = torch.zeros(blocks * blocksize, dtype=indices.dtype, device=indices.device) - padded_indices[:n] = indices - blocks_data = padded_indices.reshape(-1, blocksize) - - # Dequantize full blocks - dequantized = _NF4_QUANT_TABLE[blocks_data] - - # Apply scales 
to full blocks - out_reshaped[:n - rem] = ( - dequantized[:blocks - 1].reshape(-1, blocksize) * absmax[:blocks - 1].view(-1, 1) - ).reshape(-1) - - # Apply scale to remainder block - out_reshaped[n - rem:] = dequantized[blocks - 1, :rem] * absmax[-1] - else: - # Expand to blocks - blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) + # Expand to blocks + blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) - # Dequantize - blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] + # Dequantize + blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] - # Reshape to original shape - out = blocks.reshape(-1, *shape[1:]) + # Reshape to original shape + blocks = blocks.reshape(-1, *shape[1:]) - return out.to(dtype) + return blocks.to(dtype) @register_kernel("bitsandbytes::gemv_4bit", "cpu") From 66dcfc407f59052fa9d5359cdebf619886100033 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 21:16:02 +0530 Subject: [PATCH 53/98] Update test_linear4bit.py --- tests/test_linear4bit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index c241a265d..1b7a7722c 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -17,7 +17,6 @@ "float32": torch.float32, } -@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) From b96905d26c63355884e7decc65297591e108679d Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 21:17:02 +0530 Subject: [PATCH 54/98] Update test_linear4bit.py From ef31c362e22b201551605bc6d808026ea33da59c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 
2025 23:55:14 +0530 Subject: [PATCH 55/98] Update python-package.yml --- .github/workflows/python-package.yml | 643 ++++++++++++++------------- 1 file changed, 343 insertions(+), 300 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fbaa27d56..10daf0f79 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -1,303 +1,346 @@ -name: Python package - -on: - push: {} - pull_request: - branches: [main] - paths: - - ".github/workflows/python-package.yml" - - "bitsandbytes/**" - - "csrc/**" - - "include/**" - - "tests/**" - - "CMakeLists.txt" - - "requirements*.txt" - - "setup.py" - - "pyproject.toml" - release: - types: [published] - workflow_dispatch: {} # Allow manual trigger - workflow_call: {} # Allow triggering from other worfkflows - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true - -jobs: - ## - # This job matrix builds the non-CUDA versions of the libraries for all supported platforms. 
- ## - build-shared-libs: - strategy: - matrix: - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - - os: macos-latest - arch: arm64 - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Setup MSVC - if: startsWith(matrix.os, 'windows') - uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl - - name: Build C++ - run: bash .github/scripts/build-cpu.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_${{ matrix.os }}_${{ matrix.arch }} - path: output/* - retention-days: 7 - ## - # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) - ## - build-shared-libs-cuda: - strategy: - fail-fast: false - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - cuda_version: - ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - # Windows: We install Cuda on the agent (slow) - - uses: Jimver/cuda-toolkit@v0.2.22 - if: startsWith(matrix.os, 'windows') - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda_version }} - method: "network" - sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' - linux-local-args: '["--toolkit"]' - use-github-cache: false - - name: Setup MSVC - if: startsWith(matrix.os, 'windows') - uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl - - name: Build C++ - run: bash .github/scripts/build-cuda.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - cuda_version: ${{ matrix.cuda_version }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: 
shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} - path: output/* - retention-days: 7 - - build-wheels: - needs: - - build-shared-libs - - build-shared-libs-cuda - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - - os: macos-latest - arch: arm64 - # The specific Python version is irrelevant in this context as we are only packaging non-C extension - # code. This ensures compatibility across Python versions, including Python 3.9, as compatibility is - # dictated by the packaged code itself, not the Python version used for packaging. - python-version: ["3.10"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Download build artifacts - uses: actions/download-artifact@v4 - with: - merge-multiple: true - pattern: "shared_library*_${{ matrix.os }}_${{ matrix.arch }}*" - path: output/ - - name: Copy correct platform shared library - shell: bash - run: | - ls -lR output/ - cp output/${{ matrix.os }}/${{ matrix.arch }}/* bitsandbytes/ - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: pip - - run: pip install build wheel - - run: python -m build . 
- - name: Determine and Set Platform Tag, then Tag Wheel - shell: bash - run: | - PLATFORM_TAG=$(python .github/scripts/set_platform_tag.py "${{ matrix.arch }}") - echo "PLATFORM_TAG=$PLATFORM_TAG" - wheel tags --remove --abi-tag=none --python-tag=py3 --platform-tag=$PLATFORM_TAG dist/bitsandbytes-*.whl - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} - path: dist/bitsandbytes-*.whl - retention-days: 7 - - upload-pre-release-wheels: - name: Create release and upload artifacts - runs-on: ubuntu-latest - if: github.ref_name == 'main' - permissions: - contents: write - needs: - - build-wheels - steps: - - name: Download and rename artifacts - uses: actions/download-artifact@v4 - with: - path: tmp/ - pattern: "bdist_wheel_*" - merge-multiple: true +name: Python package - - name: Inspect tmp directory after downloading artifacts - run: ls -alFR tmp/ +on: + push: {} + pull_request: + branches: [main] + paths: + - ".github/workflows/python-package.yml" + - "bitsandbytes/**" + - "csrc/**" + - "include/**" + - "tests/**" + - "CMakeLists.txt" + - "requirements*.txt" + - "setup.py" + - "pyproject.toml" + release: + types: [published] + workflow_dispatch: {} # Allow manual trigger + workflow_call: {} # Allow triggering from other worfkflows - - name: Move and rename wheel files with pattern replacement - run: | - mkdir -p wheels/ - - # The whole point of the continuous release is to have a stable download link and the only way to have a PEP 440–compliant wheel name - # is to use a stable placeholder version. Otherwise, pip won't let you install the wheel. The cool thing is that we can now install the - # wheel directly from the GH pre-release which gets updated continuously, e.g. 
- # `pip install https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl` - STABLE_PLACEHOLDER_VERSION="1.33.7.preview" - - # exclude macos wheels for now - find tmp/ -type f -name '*.whl' ! -name '*macos*' -print0 | while IFS= read -r -d '' wheel; do - wheel_filename=$(basename "$wheel") - - # Strip off the original version - rest=${wheel_filename#bitsandbytes-*-} - new_name="bitsandbytes-${STABLE_PLACEHOLDER_VERSION}-${rest}" - - echo "Renaming $wheel_filename → $new_name" - mv "$wheel" "wheels/${new_name}" - done - - - name: Inspect wheels directory after renaming files - run: ls -alFR wheels/ +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true - - name: Delete old pre-release (if exists) - run: | - gh release delete continuous-release_main --cleanup-tag -y || true - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Generate pip install commands for release body - run: | - cat > body.md << 'ENDOFMARKDOWN' - ## Latest `main` Wheel Pre-release - - This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. - - **How to install:** - Pick the correct command for your platform and run it in your terminal: - - ENDOFMARKDOWN - - for whl in wheels/*.whl; do - fname=$(basename "$whl") - url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/$fname" - echo "\`\`\`sh" >> body.md - echo "pip install $url" >> body.md - echo "\`\`\`" >> body.md - echo "" >> body.md - done - - cat >> body.md << 'ENDOFMARKDOWN' - > **Note:** - > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. 
- ENDOFMARKDOWN - - # for debugging: - cat body.md - - - name: Create new pre-release and upload artifacts - uses: softprops/action-gh-release@v2.2.1 - with: - files: wheels/*.whl - prerelease: true - name: Latest `main` wheel - body_path: body.md - tag_name: continuous-release_main - make_latest: false - draft: false - target_commitish: ${{ github.sha }} - - audit-wheels: - needs: build-wheels - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - runs-on: ${{ matrix.os }} - env: - PIP_DISABLE_PIP_VERSION_CHECK: 1 - steps: - - uses: actions/checkout@v4 - - name: Download wheel - uses: actions/download-artifact@v4 - with: - name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} - path: wheels/ - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - run: pip install auditwheel - - run: python ./.github/scripts/auditwheel_show.py wheels/* | tee $GITHUB_STEP_SUMMARY - - publish-wheels: - name: Publish wheels to PyPI - needs: [build-wheels, audit-wheels] - runs-on: ubuntu-latest - if: | - github.repository == 'bitsandbytes-foundation/bitsandbytes' - && github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - environment: - name: release - url: https://pypi.org/p/bitsandbytes - permissions: - id-token: write - steps: - - name: Download distribution artifacts - uses: actions/download-artifact@v4 - with: - path: dist/ - pattern: "bdist_wheel_*" - merge-multiple: true - - - name: Remove macOS wheels - run: rm dist/*macos* - - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - print-hash: true +jobs: + ## + # This job matrix builds the non-CUDA versions of the libraries for all supported platforms. 
+ ## + build-shared-libs: + strategy: + matrix: + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + - os: macos-latest + arch: arm64 + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cpu.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_${{ matrix.os }}_${{ matrix.arch }} + path: output/* + retention-days: 7 + ## + # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) + ## + build-shared-libs-cuda: + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + cuda_version: + ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + # Windows: We install Cuda on the agent (slow) + - uses: Jimver/cuda-toolkit@v0.2.22 + if: startsWith(matrix.os, 'windows') + id: cuda-toolkit + with: + cuda: ${{ matrix.cuda_version }} + method: "network" + sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' + linux-local-args: '["--toolkit"]' + use-github-cache: false + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cuda.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + cuda_version: ${{ matrix.cuda_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: 
shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} + path: output/* + retention-days: 7 + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-22.04] + arch: [x86_64] + rocm_version: + ["6.1.2", "6.2.4", "6.3.2"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + uses: docker/setup-qemu-action@v3 + - name: Clean up disk space + run: | + sudo rm -rf \ + /usr/share/dotnet \ + /opt/ghc \ + "/usr/local/share/boost" \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/hostedtoolcache \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/swift + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + rocm_version: ${{ matrix.rocm_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} + path: output/* + retention-days: 7 + build-wheels: + needs: + - build-shared-libs + - build-shared-libs-cuda + - build-shared-libs-rocm + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + - os: macos-latest + arch: arm64 + # The specific Python version is irrelevant in this context as we are only packaging non-C extension + # code. This ensures compatibility across Python versions, including Python 3.9, as compatibility is + # dictated by the packaged code itself, not the Python version used for packaging. 
+ python-version: ["3.10"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + merge-multiple: true + pattern: "shared_library*_${{ matrix.os }}_${{ matrix.arch }}*" + path: output/ + - name: Copy correct platform shared library + shell: bash + run: | + ls -lR output/ + cp output/${{ matrix.os }}/${{ matrix.arch }}/* bitsandbytes/ + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - run: pip install build wheel + - run: python -m build . + - name: Determine and Set Platform Tag, then Tag Wheel + shell: bash + run: | + PLATFORM_TAG=$(python .github/scripts/set_platform_tag.py "${{ matrix.arch }}") + echo "PLATFORM_TAG=$PLATFORM_TAG" + wheel tags --remove --abi-tag=none --python-tag=py3 --platform-tag=$PLATFORM_TAG dist/bitsandbytes-*.whl + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} + path: dist/bitsandbytes-*.whl + retention-days: 7 + + upload-pre-release-wheels: + name: Create release and upload artifacts + runs-on: ubuntu-latest + if: github.ref_name == 'main' + permissions: + contents: write + needs: + - build-wheels + steps: + - name: Download and rename artifacts + uses: actions/download-artifact@v4 + with: + path: tmp/ + pattern: "bdist_wheel_*" + merge-multiple: true + + - name: Inspect tmp directory after downloading artifacts + run: ls -alFR tmp/ + + - name: Move and rename wheel files with pattern replacement + run: | + mkdir -p wheels/ + + # The whole point of the continuous release is to have a stable download link and the only way to have a PEP 440–compliant wheel name + # is to use a stable placeholder version. Otherwise, pip won't let you install the wheel. 
The cool thing is that we can now install the + # wheel directly from the GH pre-release which gets updated continuously, e.g. + # `pip install https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl` + STABLE_PLACEHOLDER_VERSION="1.33.7.preview" + + # exclude macos wheels for now + find tmp/ -type f -name '*.whl' ! -name '*macos*' -print0 | while IFS= read -r -d '' wheel; do + wheel_filename=$(basename "$wheel") + + # Strip off the original version + rest=${wheel_filename#bitsandbytes-*-} + new_name="bitsandbytes-${STABLE_PLACEHOLDER_VERSION}-${rest}" + + echo "Renaming $wheel_filename → $new_name" + mv "$wheel" "wheels/${new_name}" + done + + - name: Inspect wheels directory after renaming files + run: ls -alFR wheels/ + + - name: Delete old pre-release (if exists) + run: | + gh release delete continuous-release_main --cleanup-tag -y || true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Generate pip install commands for release body + run: | + cat > body.md << 'ENDOFMARKDOWN' + ## Latest `main` Wheel Pre-release + + This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. + + **How to install:** + Pick the correct command for your platform and run it in your terminal: + + ENDOFMARKDOWN + + for whl in wheels/*.whl; do + fname=$(basename "$whl") + url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/$fname" + echo "\`\`\`sh" >> body.md + echo "pip install $url" >> body.md + echo "\`\`\`" >> body.md + echo "" >> body.md + done + + cat >> body.md << 'ENDOFMARKDOWN' + > **Note:** + > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. 
+ ENDOFMARKDOWN + + # for debugging: + cat body.md + + - name: Create new pre-release and upload artifacts + uses: softprops/action-gh-release@v2.2.1 + with: + files: wheels/*.whl + prerelease: true + name: Latest `main` wheel + body_path: body.md + tag_name: continuous-release_main + make_latest: false + draft: false + target_commitish: ${{ github.sha }} + + audit-wheels: + needs: build-wheels + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + runs-on: ${{ matrix.os }} + env: + PIP_DISABLE_PIP_VERSION_CHECK: 1 + steps: + - uses: actions/checkout@v4 + - name: Download wheel + uses: actions/download-artifact@v4 + with: + name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} + path: wheels/ + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install auditwheel + - run: python ./.github/scripts/auditwheel_show.py wheels/* | tee $GITHUB_STEP_SUMMARY + + publish-wheels: + name: Publish wheels to PyPI + needs: [build-wheels, audit-wheels] + runs-on: ubuntu-latest + if: | + github.repository == 'bitsandbytes-foundation/bitsandbytes' + && github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + environment: + name: release + url: https://pypi.org/p/bitsandbytes + permissions: + id-token: write + steps: + - name: Download distribution artifacts + uses: actions/download-artifact@v4 + with: + path: dist/ + pattern: "bdist_wheel_*" + merge-multiple: true + + - name: Remove macOS wheels + run: rm dist/*macos* + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + print-hash: true From e1435f01776137c3a253228b4234a23535532161 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Mon, 2 Jun 2025 23:57:25 +0530 Subject: [PATCH 56/98] Update python-package.yml --- .github/workflows/python-package.yml | 643 +++++++++++++-------------- 1 file 
changed, 300 insertions(+), 343 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 10daf0f79..fbaa27d56 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -1,346 +1,303 @@ -name: Python package +name: Python package + +on: + push: {} + pull_request: + branches: [main] + paths: + - ".github/workflows/python-package.yml" + - "bitsandbytes/**" + - "csrc/**" + - "include/**" + - "tests/**" + - "CMakeLists.txt" + - "requirements*.txt" + - "setup.py" + - "pyproject.toml" + release: + types: [published] + workflow_dispatch: {} # Allow manual trigger + workflow_call: {} # Allow triggering from other worfkflows + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + ## + # This job matrix builds the non-CUDA versions of the libraries for all supported platforms. + ## + build-shared-libs: + strategy: + matrix: + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + - os: macos-latest + arch: arm64 + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cpu.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_${{ matrix.os }}_${{ matrix.arch }} + path: output/* + retention-days: 7 + ## + # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) + ## + build-shared-libs-cuda: + strategy: + fail-fast: false + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: 
windows-latest + arch: x86_64 + cuda_version: + ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + # Windows: We install Cuda on the agent (slow) + - uses: Jimver/cuda-toolkit@v0.2.22 + if: startsWith(matrix.os, 'windows') + id: cuda-toolkit + with: + cuda: ${{ matrix.cuda_version }} + method: "network" + sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' + linux-local-args: '["--toolkit"]' + use-github-cache: false + - name: Setup MSVC + if: startsWith(matrix.os, 'windows') + uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl + - name: Build C++ + run: bash .github/scripts/build-cuda.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + cuda_version: ${{ matrix.cuda_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} + path: output/* + retention-days: 7 + + build-wheels: + needs: + - build-shared-libs + - build-shared-libs-cuda + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + - os: windows-latest + arch: x86_64 + - os: macos-latest + arch: arm64 + # The specific Python version is irrelevant in this context as we are only packaging non-C extension + # code. This ensures compatibility across Python versions, including Python 3.9, as compatibility is + # dictated by the packaged code itself, not the Python version used for packaging. 
+ python-version: ["3.10"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + merge-multiple: true + pattern: "shared_library*_${{ matrix.os }}_${{ matrix.arch }}*" + path: output/ + - name: Copy correct platform shared library + shell: bash + run: | + ls -lR output/ + cp output/${{ matrix.os }}/${{ matrix.arch }}/* bitsandbytes/ + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - run: pip install build wheel + - run: python -m build . + - name: Determine and Set Platform Tag, then Tag Wheel + shell: bash + run: | + PLATFORM_TAG=$(python .github/scripts/set_platform_tag.py "${{ matrix.arch }}") + echo "PLATFORM_TAG=$PLATFORM_TAG" + wheel tags --remove --abi-tag=none --python-tag=py3 --platform-tag=$PLATFORM_TAG dist/bitsandbytes-*.whl + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} + path: dist/bitsandbytes-*.whl + retention-days: 7 + + upload-pre-release-wheels: + name: Create release and upload artifacts + runs-on: ubuntu-latest + if: github.ref_name == 'main' + permissions: + contents: write + needs: + - build-wheels + steps: + - name: Download and rename artifacts + uses: actions/download-artifact@v4 + with: + path: tmp/ + pattern: "bdist_wheel_*" + merge-multiple: true -on: - push: {} - pull_request: - branches: [main] - paths: - - ".github/workflows/python-package.yml" - - "bitsandbytes/**" - - "csrc/**" - - "include/**" - - "tests/**" - - "CMakeLists.txt" - - "requirements*.txt" - - "setup.py" - - "pyproject.toml" - release: - types: [published] - workflow_dispatch: {} # Allow manual trigger - workflow_call: {} # Allow triggering from other worfkflows + - name: Inspect tmp directory after downloading artifacts + run: ls -alFR tmp/ -concurrency: - group: ${{ github.workflow 
}}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true + - name: Move and rename wheel files with pattern replacement + run: | + mkdir -p wheels/ + + # The whole point of the continuous release is to have a stable download link and the only way to have a PEP 440–compliant wheel name + # is to use a stable placeholder version. Otherwise, pip won't let you install the wheel. The cool thing is that we can now install the + # wheel directly from the GH pre-release which gets updated continuously, e.g. + # `pip install https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl` + STABLE_PLACEHOLDER_VERSION="1.33.7.preview" + + # exclude macos wheels for now + find tmp/ -type f -name '*.whl' ! -name '*macos*' -print0 | while IFS= read -r -d '' wheel; do + wheel_filename=$(basename "$wheel") + + # Strip off the original version + rest=${wheel_filename#bitsandbytes-*-} + new_name="bitsandbytes-${STABLE_PLACEHOLDER_VERSION}-${rest}" + + echo "Renaming $wheel_filename → $new_name" + mv "$wheel" "wheels/${new_name}" + done + + - name: Inspect wheels directory after renaming files + run: ls -alFR wheels/ -jobs: - ## - # This job matrix builds the non-CUDA versions of the libraries for all supported platforms. 
- ## - build-shared-libs: - strategy: - matrix: - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - - os: macos-latest - arch: arm64 - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Setup MSVC - if: startsWith(matrix.os, 'windows') - uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl - - name: Build C++ - run: bash .github/scripts/build-cpu.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_${{ matrix.os }}_${{ matrix.arch }} - path: output/* - retention-days: 7 - ## - # This job matrix builds the CUDA versions of the libraries for platforms that support CUDA (Linux x64/aarch64 + Windows x64) - ## - build-shared-libs-cuda: - strategy: - fail-fast: false - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - cuda_version: - ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - # Windows: We install Cuda on the agent (slow) - - uses: Jimver/cuda-toolkit@v0.2.22 - if: startsWith(matrix.os, 'windows') - id: cuda-toolkit - with: - cuda: ${{ matrix.cuda_version }} - method: "network" - sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]' - linux-local-args: '["--toolkit"]' - use-github-cache: false - - name: Setup MSVC - if: startsWith(matrix.os, 'windows') - uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl - - name: Build C++ - run: bash .github/scripts/build-cuda.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - cuda_version: ${{ matrix.cuda_version }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: 
shared_library_cuda_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.cuda_version }} - path: output/* - retention-days: 7 - build-shared-libs-rocm: - strategy: - matrix: - os: [ubuntu-22.04] - arch: [x86_64] - rocm_version: - ["6.1.2", "6.2.4", "6.3.2"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Set up Docker multiarch - uses: docker/setup-qemu-action@v3 - - name: Clean up disk space - run: | - sudo rm -rf \ - /usr/share/dotnet \ - /opt/ghc \ - "/usr/local/share/boost" \ - "$AGENT_TOOLSDIRECTORY" \ - /opt/hostedtoolcache \ - /opt/google/chrome \ - /opt/microsoft/msedge \ - /opt/microsoft/powershell \ - /opt/pipx \ - /usr/lib/mono \ - /usr/local/julia* \ - /usr/local/lib/android \ - /usr/local/lib/node_modules \ - /usr/local/share/chromium \ - /usr/local/share/powershell \ - /usr/share/swift - - name: Build C++ - run: bash .github/scripts/build-rocm.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - rocm_version: ${{ matrix.rocm_version }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} - path: output/* - retention-days: 7 - build-wheels: - needs: - - build-shared-libs - - build-shared-libs-cuda - - build-shared-libs-rocm - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - - os: windows-latest - arch: x86_64 - - os: macos-latest - arch: arm64 - # The specific Python version is irrelevant in this context as we are only packaging non-C extension - # code. This ensures compatibility across Python versions, including Python 3.9, as compatibility is - # dictated by the packaged code itself, not the Python version used for packaging. 
- python-version: ["3.10"] - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v4 - - name: Download build artifacts - uses: actions/download-artifact@v4 - with: - merge-multiple: true - pattern: "shared_library*_${{ matrix.os }}_${{ matrix.arch }}*" - path: output/ - - name: Copy correct platform shared library - shell: bash - run: | - ls -lR output/ - cp output/${{ matrix.os }}/${{ matrix.arch }}/* bitsandbytes/ - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: pip - - run: pip install build wheel - - run: python -m build . - - name: Determine and Set Platform Tag, then Tag Wheel - shell: bash - run: | - PLATFORM_TAG=$(python .github/scripts/set_platform_tag.py "${{ matrix.arch }}") - echo "PLATFORM_TAG=$PLATFORM_TAG" - wheel tags --remove --abi-tag=none --python-tag=py3 --platform-tag=$PLATFORM_TAG dist/bitsandbytes-*.whl - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} - path: dist/bitsandbytes-*.whl - retention-days: 7 - - upload-pre-release-wheels: - name: Create release and upload artifacts - runs-on: ubuntu-latest - if: github.ref_name == 'main' - permissions: - contents: write - needs: - - build-wheels - steps: - - name: Download and rename artifacts - uses: actions/download-artifact@v4 - with: - path: tmp/ - pattern: "bdist_wheel_*" - merge-multiple: true - - - name: Inspect tmp directory after downloading artifacts - run: ls -alFR tmp/ - - - name: Move and rename wheel files with pattern replacement - run: | - mkdir -p wheels/ - - # The whole point of the continuous release is to have a stable download link and the only way to have a PEP 440–compliant wheel name - # is to use a stable placeholder version. Otherwise, pip won't let you install the wheel. 
The cool thing is that we can now install the - # wheel directly from the GH pre-release which gets updated continuously, e.g. - # `pip install https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl` - STABLE_PLACEHOLDER_VERSION="1.33.7.preview" - - # exclude macos wheels for now - find tmp/ -type f -name '*.whl' ! -name '*macos*' -print0 | while IFS= read -r -d '' wheel; do - wheel_filename=$(basename "$wheel") - - # Strip off the original version - rest=${wheel_filename#bitsandbytes-*-} - new_name="bitsandbytes-${STABLE_PLACEHOLDER_VERSION}-${rest}" - - echo "Renaming $wheel_filename → $new_name" - mv "$wheel" "wheels/${new_name}" - done - - - name: Inspect wheels directory after renaming files - run: ls -alFR wheels/ - - - name: Delete old pre-release (if exists) - run: | - gh release delete continuous-release_main --cleanup-tag -y || true - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Generate pip install commands for release body - run: | - cat > body.md << 'ENDOFMARKDOWN' - ## Latest `main` Wheel Pre-release - - This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. - - **How to install:** - Pick the correct command for your platform and run it in your terminal: - - ENDOFMARKDOWN - - for whl in wheels/*.whl; do - fname=$(basename "$whl") - url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/$fname" - echo "\`\`\`sh" >> body.md - echo "pip install $url" >> body.md - echo "\`\`\`" >> body.md - echo "" >> body.md - done - - cat >> body.md << 'ENDOFMARKDOWN' - > **Note:** - > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. 
- ENDOFMARKDOWN - - # for debugging: - cat body.md - - - name: Create new pre-release and upload artifacts - uses: softprops/action-gh-release@v2.2.1 - with: - files: wheels/*.whl - prerelease: true - name: Latest `main` wheel - body_path: body.md - tag_name: continuous-release_main - make_latest: false - draft: false - target_commitish: ${{ github.sha }} - - audit-wheels: - needs: build-wheels - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-22.04-arm] - include: - - os: ubuntu-22.04 - arch: x86_64 - - os: ubuntu-22.04-arm - arch: aarch64 - runs-on: ${{ matrix.os }} - env: - PIP_DISABLE_PIP_VERSION_CHECK: 1 - steps: - - uses: actions/checkout@v4 - - name: Download wheel - uses: actions/download-artifact@v4 - with: - name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} - path: wheels/ - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - run: pip install auditwheel - - run: python ./.github/scripts/auditwheel_show.py wheels/* | tee $GITHUB_STEP_SUMMARY - - publish-wheels: - name: Publish wheels to PyPI - needs: [build-wheels, audit-wheels] - runs-on: ubuntu-latest - if: | - github.repository == 'bitsandbytes-foundation/bitsandbytes' - && github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - environment: - name: release - url: https://pypi.org/p/bitsandbytes - permissions: - id-token: write - steps: - - name: Download distribution artifacts - uses: actions/download-artifact@v4 - with: - path: dist/ - pattern: "bdist_wheel_*" - merge-multiple: true - - - name: Remove macOS wheels - run: rm dist/*macos* - - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - print-hash: true + - name: Delete old pre-release (if exists) + run: | + gh release delete continuous-release_main --cleanup-tag -y || true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Generate pip install commands for release body + run: | + cat > body.md << 'ENDOFMARKDOWN' + ## Latest `main` Wheel Pre-release + 
+ This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. + + **How to install:** + Pick the correct command for your platform and run it in your terminal: + + ENDOFMARKDOWN + + for whl in wheels/*.whl; do + fname=$(basename "$whl") + url="https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/$fname" + echo "\`\`\`sh" >> body.md + echo "pip install $url" >> body.md + echo "\`\`\`" >> body.md + echo "" >> body.md + done + + cat >> body.md << 'ENDOFMARKDOWN' + > **Note:** + > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. + ENDOFMARKDOWN + + # for debugging: + cat body.md + + - name: Create new pre-release and upload artifacts + uses: softprops/action-gh-release@v2.2.1 + with: + files: wheels/*.whl + prerelease: true + name: Latest `main` wheel + body_path: body.md + tag_name: continuous-release_main + make_latest: false + draft: false + target_commitish: ${{ github.sha }} + + audit-wheels: + needs: build-wheels + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-22.04-arm] + include: + - os: ubuntu-22.04 + arch: x86_64 + - os: ubuntu-22.04-arm + arch: aarch64 + runs-on: ${{ matrix.os }} + env: + PIP_DISABLE_PIP_VERSION_CHECK: 1 + steps: + - uses: actions/checkout@v4 + - name: Download wheel + uses: actions/download-artifact@v4 + with: + name: bdist_wheel_${{ matrix.os }}_${{ matrix.arch }} + path: wheels/ + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - run: pip install auditwheel + - run: python ./.github/scripts/auditwheel_show.py wheels/* | tee $GITHUB_STEP_SUMMARY + + publish-wheels: + name: Publish wheels to PyPI + needs: [build-wheels, audit-wheels] + runs-on: ubuntu-latest + if: | + github.repository == 'bitsandbytes-foundation/bitsandbytes' + && 
github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + environment: + name: release + url: https://pypi.org/p/bitsandbytes + permissions: + id-token: write + steps: + - name: Download distribution artifacts + uses: actions/download-artifact@v4 + with: + path: dist/ + pattern: "bdist_wheel_*" + merge-multiple: true + + - name: Remove macOS wheels + run: rm dist/*macos* + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + print-hash: true From da9a271446295e012cd61263836ab8fea0a06af8 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 3 Jun 2025 00:06:56 +0530 Subject: [PATCH 57/98] Update python-package.yml --- .github/workflows/python-package.yml | 53 +++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fbaa27d56..8b0bbb374 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -102,10 +102,55 @@ jobs: path: output/* retention-days: 7 - build-wheels: - needs: - - build-shared-libs - - build-shared-libs-cuda + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-22.04] + arch: [x86_64] + rocm_version: + ["6.1.2", "6.2.4", "6.3.2"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + uses: docker/setup-qemu-action@v3 + - name: Clean up disk space + run: | + sudo rm -rf \ + /usr/share/dotnet \ + /opt/ghc \ + "/usr/local/share/boost" \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/hostedtoolcache \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/swift + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + 
build_arch: ${{ matrix.arch }} + rocm_version: ${{ matrix.rocm_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} + path: output/* + retention-days: 7 + + build-wheels: + needs: + - build-shared-libs + - build-shared-libs-cuda + - build-shared-libs-rocm strategy: matrix: os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] From 08848daddb2ec6bd13f7b5a0720bd6d34988d818 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 3 Jun 2025 00:12:54 +0530 Subject: [PATCH 58/98] Update python-package.yml --- .github/workflows/python-package.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 8b0bbb374..a65d0f5bb 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -145,12 +145,12 @@ jobs: name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} path: output/* retention-days: 7 - - build-wheels: - needs: - - build-shared-libs - - build-shared-libs-cuda - - build-shared-libs-rocm + + build-wheels: + needs: + - build-shared-libs + - build-shared-libs-cuda + - build-shared-libs-rocm strategy: matrix: os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] From 978cba3825e3624bc39d594a2bd01c2444e1af69 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Tue, 3 Jun 2025 01:33:00 +0530 Subject: [PATCH 59/98] Create build-rocm.sh --- .github/scripts/build-rocm.sh | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 .github/scripts/build-rocm.sh diff --git a/.github/scripts/build-rocm.sh b/.github/scripts/build-rocm.sh new file mode 100644 index 000000000..b508fac69 --- /dev/null +++ b/.github/scripts/build-rocm.sh @@ -0,0 +1,21 
@@ +#!/bin/bash +declare build_arch +declare build_os +declare rocm_version + +set -xeuo pipefail +bnb_rocm_arch="gfx90a;gfx942;gfx1100" +if [ "${build_os:0:6}" == ubuntu ]; then + image=rocm/dev-ubuntu-22.04:${rocm_version}-complete + echo "Using image $image" + docker run --rm --platform "linux/$build_arch" -i \ + -w /src -v "$PWD:/src" "$image" sh -c \ + "apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ + && cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \ + && cmake --build ." +fi + +output_dir="output/${build_os}/${build_arch}" +mkdir -p "${output_dir}" +(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") From af6561aec6d7df66f58d4f667e1f1307aef57011 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 4 Jun 2025 00:34:30 +0530 Subject: [PATCH 60/98] Update cuda_specs.py --- bitsandbytes/cuda_specs.py | 48 +++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 61d03083c..bbdf457cc 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,6 +1,6 @@ import dataclasses -import logging -import re +import logging +import re import subprocess from functools import lru_cache from typing import Optional @@ -78,25 +78,25 @@ def get_cuda_specs() -> Optional[CUDASpecs]: return None -def get_rocm_gpu_arch() -> str: - """Get ROCm GPU architecture.""" - logger = logging.getLogger(__name__) - try: - if torch.version.hip: - result = subprocess.run(["rocminfo"], capture_output=True, text=True) - match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) - if match: - return "gfx" + match.group(1) - else: - return "unknown" - else: - return "unknown" - except Exception as e: - logger.error(f"Could not detect ROCm GPU architecture: {e}") - if torch.cuda.is_available(): - logger.warning( - """ -ROCm GPU 
architecture detection failed despite ROCm being available. - """, - ) - return "unknown" +def get_rocm_gpu_arch() -> str: + """Get ROCm GPU architecture.""" + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + if match: + return "gfx" + match.group(1) + else: + return "unknown" + else: + return "unknown" + except Exception as e: + logger.error(f"Could not detect ROCm GPU architecture: {e}") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm GPU architecture detection failed despite ROCm being available. + """, + ) + return "unknown" From 405b4843fe2dffc0ab8059f82a4e3fb399ed10f0 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 4 Jun 2025 00:54:11 +0530 Subject: [PATCH 61/98] Fix trailing whitespace --- .github/workflows/python-package.yml | 96 +++---- bitsandbytes/backends/cuda/ops.py | 36 +-- bitsandbytes/cextension.py | 16 +- bitsandbytes/cuda_specs.py | 2 +- bitsandbytes/diagnostics/cuda.py | 12 +- bitsandbytes/diagnostics/main.py | 3 +- bitsandbytes/functional.py | 10 +- bitsandbytes/nn/modules.py | 4 +- conflicts.diff | 382 +++++++++++++++++++++++++++ csrc/common_hip.cuh | 2 +- csrc/kernels.hip | 26 +- csrc/ops.hip | 10 +- tests/test_cuda_setup_evaluator.py | 2 + tests/test_functional.py | 15 +- tests/test_linear4bit.py | 1 + tests/test_ops.py | 2 +- 16 files changed, 506 insertions(+), 113 deletions(-) create mode 100644 conflicts.diff diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index a65d0f5bb..3673ac608 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -102,49 +102,49 @@ jobs: path: output/* retention-days: 7 - build-shared-libs-rocm: - strategy: - matrix: - os: [ubuntu-22.04] - arch: [x86_64] - rocm_version: - ["6.1.2", "6.2.4", "6.3.2"] - runs-on: ${{ matrix.os }} - steps: - - uses: 
actions/checkout@v4 - - name: Set up Docker multiarch - uses: docker/setup-qemu-action@v3 - - name: Clean up disk space - run: | - sudo rm -rf \ - /usr/share/dotnet \ - /opt/ghc \ - "/usr/local/share/boost" \ - "$AGENT_TOOLSDIRECTORY" \ - /opt/hostedtoolcache \ - /opt/google/chrome \ - /opt/microsoft/msedge \ - /opt/microsoft/powershell \ - /opt/pipx \ - /usr/lib/mono \ - /usr/local/julia* \ - /usr/local/lib/android \ - /usr/local/lib/node_modules \ - /usr/local/share/chromium \ - /usr/local/share/powershell \ - /usr/share/swift - - name: Build C++ - run: bash .github/scripts/build-rocm.sh - env: - build_os: ${{ matrix.os }} - build_arch: ${{ matrix.arch }} - rocm_version: ${{ matrix.rocm_version }} - - name: Upload build artifact - uses: actions/upload-artifact@v4 - with: - name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} - path: output/* - retention-days: 7 + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-22.04] + arch: [x86_64] + rocm_version: + ["6.1.2", "6.2.4", "6.3.2"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + uses: docker/setup-qemu-action@v3 + - name: Clean up disk space + run: | + sudo rm -rf \ + /usr/share/dotnet \ + /opt/ghc \ + "/usr/local/share/boost" \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/hostedtoolcache \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/swift + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + rocm_version: ${{ matrix.rocm_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} + path: output/* + 
retention-days: 7 build-wheels: needs: @@ -216,10 +216,10 @@ jobs: path: tmp/ pattern: "bdist_wheel_*" merge-multiple: true - + - name: Inspect tmp directory after downloading artifacts run: ls -alFR tmp/ - + - name: Move and rename wheel files with pattern replacement run: | mkdir -p wheels/ @@ -244,7 +244,7 @@ jobs: - name: Inspect wheels directory after renaming files run: ls -alFR wheels/ - + - name: Delete old pre-release (if exists) run: | gh release delete continuous-release_main --cleanup-tag -y || true @@ -258,7 +258,7 @@ jobs: This pre-release contains the latest development wheels for all supported platforms, rebuilt automatically on every commit to the `main` branch. - **How to install:** + **How to install:** Pick the correct command for your platform and run it in your terminal: ENDOFMARKDOWN @@ -273,7 +273,7 @@ jobs: done cat >> body.md << 'ENDOFMARKDOWN' - > **Note:** + > **Note:** > These wheels are updated automatically with every commit to `main` and become available as soon as the [python-package.yml](.github/workflows/python-package.yml) workflow finishes. 
ENDOFMARKDOWN diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index fd7b7b9a2..9089d6fc2 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -8,7 +8,7 @@ from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import lib, HIP_ENVIRONMENT +from ...cextension import HIP_ENVIRONMENT, lib @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -210,12 +210,12 @@ def _get_col_absmax( @register_kernel("bitsandbytes::quantize_blockwise", "cuda") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") n = A.numel() @@ -269,11 +269,11 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") torch._check( dtype in [torch.float16, torch.bfloat16, torch.float32], @@ -303,11 +303,11 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + if HIP_ENVIRONMENT: 
+ torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ -385,11 +385,11 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - if HIP_ENVIRONMENT: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) - else: + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) - + torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 108aa0c9a..5283df93e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -81,7 +81,7 @@ def get_available_cuda_binary_versions() -> list[str]: lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}" versions = [] for lib in Path(__file__).parent.glob(lib_pattern): - pattern = r"{}(\d+)".format(BNB_BACKEND.lower()) + pattern = rf"{BNB_BACKEND.lower()}(\d+)" match = re.search(pattern, lib.name) if match: ver_code = int(match.group(1)) @@ -199,18 +199,16 @@ def _format_lib_error_message( ) compile_instructions = ( - ( - "COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n" - ) if not no_cuda_lib_found - else - ( + ("COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n") + if not no_cuda_lib_found + else ( "You have two options:\n" "1. COMPILE FROM SOURCE (required if no binary exists):\n" " https://huggingface.co/docs/bitsandbytes/main/en/installation#cuda-compile\n" "2. 
Use BNB_CUDA_VERSION to specify a DIFFERENT CUDA version from the detected one, which is installed on your machine and matching an available pre-compiled version listed above\n\n" - ) if not HIP_ENVIRONMENT - else - ( + ) + if not HIP_ENVIRONMENT + else ( "You can COMPILE FROM SOURCE as mentioned here:\n" " https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n" ) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index bbdf457cc..32563a159 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,8 +1,8 @@ import dataclasses +from functools import lru_cache import logging import re import subprocess -from functools import lru_cache from typing import Optional import torch diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index b9de27fd7..b9db101ab 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -33,11 +33,13 @@ } CUDA_RUNTIME_LIB_PATTERNS = ( - "libamdhip64.so*", -) if HIP_ENVIRONMENT else ( - "cudart64*.dll", # Windows - "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. - "nvcuda*.dll", # Windows + ("libamdhip64.so*",) + if HIP_ENVIRONMENT + else ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. + "nvcuda*.dll", # Windows + ) ) logger = logging.getLogger(__name__) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 8e2bc2a7b..bf31d7978 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -43,7 +43,8 @@ def main(): print(f"{BNB_BACKEND} specs:{cuda_specs}") if not torch.cuda.is_available(): print(f"Torch says {BNB_BACKEND} is not available. 
Possible reasons:") - if not HIP_ENVIRONMENT: print(f"- {BNB_BACKEND} driver not installed") + if not HIP_ENVIRONMENT: + print(f"- {BNB_BACKEND} driver not installed") print(f"- {BNB_BACKEND} not installed") print(f"- You have multiple conflicting {BNB_BACKEND} libraries") if cuda_specs: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 03f6c323d..9b7ce2da9 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib, HIP_ENVIRONMENT +from .cextension import HIP_ENVIRONMENT, lib name2qmap = {} @@ -1007,10 +1007,10 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ - + if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -1114,10 +1114,10 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. 
""" - + if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + if quant_state is None: assert absmax is not None and out is not None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 2383f2c10..a2facac28 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -222,10 +222,10 @@ def __new__( ) -> "Params4bit": if data is None: data = torch.empty(0) - + if blocksize is None: blocksize = 64 if not HIP_ENVIRONMENT else 128 - + self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize self.compress_statistics = compress_statistics diff --git a/conflicts.diff b/conflicts.diff new file mode 100644 index 000000000..cab8c6ea7 --- /dev/null +++ b/conflicts.diff @@ -0,0 +1,382 @@ +diff --cc bitsandbytes/cextension.py +index 108aa0c,b112df2..0000000 +--- a/bitsandbytes/cextension.py ++++ b/bitsandbytes/cextension.py +@@@ -28,17 -28,10 +29,15 @@@ def get_cuda_bnb_library_path(cuda_spec + override_value = os.environ.get("BNB_CUDA_VERSION") + if override_value: + library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) + + if torch.version.hip: + + raise RuntimeError( + + f"BNB_CUDA_VERSION={override_value} detected for ROCm!! 
\n" + + f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" + + ) + logger.warning( + f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" +- "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" ++ "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n" + "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" +- "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" +- "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLi + return BNBNativeLibrary(dll) + + + +ROCM_GPU_ARCH = get_rocm_gpu_arch() + + + try: +++<<<<<<< HEAD + + if torch.version.hip: + + HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" + + else: + + HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" + + +++======= ++ # to support Intel CPU/GPU (XPU) backend ++ import intel_extension_for_pytorch as ipex ++ ++ ipex_cpu = ipex if ipex._C._has_cpu() else None ++ ipex_xpu = ipex if ipex._C._has_xpu() else None ++ except BaseException: ++ ipex_cpu = None ++ ipex_xpu = None ++ ++ ++ try: +++>>>>>>> upstream/main + lib = get_native_library() + except Exception as e: + error_msg = str(e) +diff --cc bitsandbytes/diagnostics/cuda.py +index b9de27f,e763ef2..0000000 +--- a/bitsandbytes/diagnostics/cuda.py ++++ b/bitsandbytes/diagnostics/cuda.py +@@@ -5,8 -5,7 +5,12 @@@ from pathlib import Pat + + import torch + +++<<<<<<< HEAD + +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path + +from bitsandbytes.consts import NONPYTORCH_DOC_URL +++======= ++ from bitsandbytes.cextension import get_cuda_bnb_library_path +++>>>>>>> upstream/main + from bitsandbytes.cuda_specs import CUDASpecs + from bitsandbytes.diagnostics.utils import print_dedented + +@@@ -146,42 -127,8 +134,38 @@@ def 
_print_cuda_diagnostics(cuda_specs + """, + ) + +- # TODO: +- # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) +- # (2) Multiple CUDA versions installed +- + + -def print_cuda_runtime_diagnostics() -> None: + +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: + + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") + + + + binary_path = get_cuda_bnb_library_path(cuda_specs) + + if not binary_path.exists(): + + print_dedented( + + f""" + + Library not found: {binary_path}. + + Maybe you need to compile it from source? If you compiled from source, check that ROCm version + + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version + + and rebuild bitsandbytes. + + """, + + ) + + + + hip_major, hip_minor = cuda_specs.cuda_version_tuple + + if (hip_major, hip_minor) < (6, 1): + + print_dedented( + + """ + + WARNING: bitsandbytes is fully supported only from ROCm 6.1. + + """, + + ) + + + + + +def print_diagnostics(cuda_specs: CUDASpecs) -> None: + + if HIP_ENVIRONMENT: + + _print_hip_diagnostics(cuda_specs) + + else: + + _print_cuda_diagnostics(cuda_specs) + + + + + +def _print_cuda_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("CUDA SETUP: WARNING! 
CUDA runtime files not found in any environmental path.") +diff --cc bitsandbytes/diagnostics/main.py +index 8e2bc2a,aa4cb30..0000000 +--- a/bitsandbytes/diagnostics/main.py ++++ b/bitsandbytes/diagnostics/main.py +@@@ -3,12 -5,11 +5,20 @@@ import tracebac + + import torch + +++<<<<<<< HEAD + +from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT + +from bitsandbytes.consts import PACKAGE_GITHUB_URL + +from bitsandbytes.cuda_specs import get_cuda_specs + +from bitsandbytes.diagnostics.cuda import ( + + print_diagnostics, + + print_runtime_diagnostics, +++======= ++ from bitsandbytes import __version__ as bnb_version ++ from bitsandbytes.consts import PACKAGE_GITHUB_URL ++ from bitsandbytes.cuda_specs import get_cuda_specs ++ from bitsandbytes.diagnostics.cuda import ( ++ print_cuda_diagnostics, +++>>>>>>> upstream/main + ) + from bitsandbytes.diagnostics.utils import print_dedented, print_header + +@@@ -28,52 -41,77 +50,122 @@@ def sanity_check() + assert p1 != p2 + + ++ def get_package_version(name: str) -> str: ++ try: ++ version = importlib.metadata.version(name) ++ except importlib.metadata.PackageNotFoundError: ++ version = "not found" ++ return version ++ ++ ++ def show_environment(): ++ """Simple utility to print out environment information.""" ++ ++ print(f"Platform: {platform.platform()}") ++ if platform.system() == "Linux": ++ print(f" libc: {'-'.join(platform.libc_ver())}") ++ ++ print(f"Python: {platform.python_version()}") ++ ++ print(f"PyTorch: {torch.__version__}") ++ print(f" CUDA: {torch.version.cuda or 'N/A'}") ++ print(f" HIP: {torch.version.hip or 'N/A'}") ++ print(f" XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}") ++ ++ print("Related packages:") ++ for pkg in _RELATED_PACKAGES: ++ version = get_package_version(pkg) ++ print(f" {pkg}: {version}") ++ ++ + def main(): +- print_header("") +- print_header("BUG REPORT INFORMATION") ++ print_header(f"bitsandbytes v{bnb_version}") ++ show_environment() + print_header("") + +- 
print_header("OTHER") + cuda_specs = get_cuda_specs() +++<<<<<<< HEAD + + if HIP_ENVIRONMENT: + + rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," + + rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" + + print(f"{BNB_BACKEND} specs:{rocm_specs}") + + else: + + print(f"{BNB_BACKEND} specs:{cuda_specs}") + + if not torch.cuda.is_available(): + + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") + + if not HIP_ENVIRONMENT: print(f"- {BNB_BACKEND} driver not installed") + + print(f"- {BNB_BACKEND} not installed") + + print(f"- You have multiple conflicting {BNB_BACKEND} libraries") + + if cuda_specs: + + print_diagnostics(cuda_specs) + + print_runtime_diagnostics() + + print_header("") + + print_header("DEBUG INFO END") + + print_header("") + + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") + + try: + + sanity_check() + + print("SUCCESS!") + + print("Installation was successful!") + + return + + except RuntimeError as e: + + if "not available in CPU-only" in str(e): + + print( + + f"WARNING: {__package__} is currently running as CPU-only!\n" + + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + + f"If you think that this is so erroneously,\nplease report an issue!", + + ) + + else: + + raise e + + except Exception: + + traceback.print_exc() + + print_dedented( + + f""" + + Above we output some debug information. + + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose + + WARNING: Please be sure to sanitize sensitive info from the output before posting it. + + """, + + ) + + sys.exit(1) +++======= ++ ++ if cuda_specs: ++ print_cuda_diagnostics(cuda_specs) ++ ++ # TODO: There's a lot of noise in this; needs improvement. ++ # print_cuda_runtime_diagnostics() ++ ++ if not torch.cuda.is_available(): ++ print("PyTorch says CUDA is not available. Possible reasons:") ++ print("1. 
CUDA driver not installed") ++ print("2. Using a CPU-only PyTorch build") ++ print("3. No GPU detected") ++ ++ else: ++ print("Checking that the library is importable and CUDA is callable...") ++ ++ try: ++ sanity_check() ++ print("SUCCESS!") ++ return ++ except RuntimeError as e: ++ if "not available in CPU-only" in str(e): ++ print( ++ f"WARNING: {__package__} is currently running as CPU-only!\n" ++ "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" ++ f"If you think that this is so erroneously,\nplease report an issue!", ++ ) ++ else: ++ raise e ++ except Exception: ++ traceback.print_exc() ++ ++ print_dedented( ++ f""" ++ Above we output some debug information. ++ Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ++ WARNING: Please be sure to sanitize sensitive info from the output before posting it. ++ """, ++ ) ++ sys.exit(1) +++>>>>>>> upstream/main +diff --cc bitsandbytes/functional.py +index 03f6c32,ffb6668..0000000 +mode 100644,100755..100755 +--- a/bitsandbytes/functional.py ++++ b/bitsandbytes/functional.py +@@@ -13,9 -13,9 +13,13 @@@ import torc + from torch import Tensor + from typing_extensions import deprecated + +- from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict ++ from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict + +++<<<<<<< HEAD + +from .cextension import lib, HIP_ENVIRONMENT +++======= ++ from .cextension import ipex_cpu, ipex_xpu, lib +++>>>>>>> upstream/main + + name2qmap = {} + +diff --cc bitsandbytes/nn/modules.py +index 2383f2c,ccd842c..0000000 +--- a/bitsandbytes/nn/modules.py ++++ b/bitsandbytes/nn/modules.py +@@@ -11,8 -11,7 +11,12 @@@ from torch import Tensor, device, dtype + import torch.nn.functional as F + + import bitsandbytes as bnb +++<<<<<<< HEAD + +from bitsandbytes.cextension import HIP_ENVIRONMENT + +from bitsandbytes.functional import QuantState +++======= ++ from 
bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu +++>>>>>>> upstream/main + from bitsandbytes.optim import GlobalOptimManager + from bitsandbytes.utils import ( + INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, +diff --cc tests/test_linear4bit.py +index 1b7a772,b5db2eb..0000000 +--- a/tests/test_linear4bit.py ++++ b/tests/test_linear4bit.py +@@@ -7,8 -8,14 +8,19 @@@ import pytes + import torch + + import bitsandbytes as bnb +++<<<<<<< HEAD + +from bitsandbytes.cextension import HIP_ENVIRONMENT + +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer +++======= ++ from tests.helpers import ( ++ TRUE_FALSE, ++ describe_dtype, ++ get_available_devices, ++ id_formatter, ++ torch_load_from_buffer, ++ torch_save_to_buffer, ++ ) +++>>>>>>> upstream/main + + storage = { + "uint8": torch.uint8, +@@@ -183,16 -185,10 +189,10 @@@ def test_linear_serialization(device, q + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) + -@pytest.mark.parametrize("blocksize", [64, 128]) + +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) + @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) + def test_copy_param(device, quant_type, blocksize, compress_statistics): +- if device == "cpu": +- if compress_statistics: +- pytest.skip("Currently segfaults on CPU") +- if quant_type == "fp4": +- pytest.xfail("FP4 not supported on CPU") +- +- tensor = torch.linspace(1, blocksize, blocksize) ++ tensor = torch.randn(300, 400) + param = bnb.nn.Params4bit( + data=tensor, + quant_type=quant_type, +@@@ -208,16 -204,10 +208,10 @@@ + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) + -@pytest.mark.parametrize("blocksize", [64, 128]) + +@pytest.mark.parametrize("blocksize", [64, 128] if not 
HIP_ENVIRONMENT else [128]) + @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) + def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): +- if device == "cpu": +- if compress_statistics: +- pytest.skip("Currently segfaults on CPU") +- if quant_type == "fp4": +- pytest.xfail("FP4 not supported on CPU") +- +- tensor = torch.linspace(1, blocksize, blocksize) ++ tensor = torch.randn(300, 400) + param = bnb.nn.Params4bit( + data=tensor, + quant_type=quant_type, +@@@ -240,16 -230,10 +234,10 @@@ + + @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) + -@pytest.mark.parametrize("blocksize", [64, 128]) + +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) + @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) + def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): +- if device == "cpu": +- if compress_statistics: +- pytest.skip("Currently segfaults on CPU") +- if quant_type == "fp4": +- pytest.xfail("FP4 not supported on CPU") +- +- original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32) ++ original_tensor = torch.randn(300, 400) + original_param = bnb.nn.Params4bit( + data=original_tensor, + quant_type=quant_type, diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh index e7fc4eb81..105179535 100644 --- a/csrc/common_hip.cuh +++ b/csrc/common_hip.cuh @@ -1,6 +1,6 @@ #pragma once -#define BNB_WARP_SIZE warpSize +#define BNB_WARP_SIZE warpSize // These are set based on current BNB support for CDNA 2 & RDNA 3. 
Update as needed for future archs #define BNB_MAX_THREADS_PER_SM 2048 diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 368788f39..56e1d54db 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -532,7 +532,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float absmax[i / BLOCK_SIZE] = local_abs_max; } __syncthreads(); - + local_abs_max = smem_absmax_value[0]; if(STOCHASTIC) @@ -610,7 +610,7 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs valid_items_load = min(TILE_SIZE, n - i); valid_items_store = valid_items_load; } - + // Since blocksize will always be a power-of-2, we avoid more expensive // division by the blocksize and instead use a shift operation. // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. @@ -811,7 +811,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p, LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); __syncthreads(); Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); - + // Load additional state1 data for AdEMAMix // TODO: Make constexpr after updating min compiler if (OPTIMIZER == ADEMAMIX) { @@ -1607,7 +1607,7 @@ kOptimizerStatic8bit2StateBlockwise( unsigned char c1s[N_PER_TH]; unsigned char c2s[N_PER_TH]; unsigned char c3s[N_PER_TH]; - + T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; typedef hipcub::BlockLoad LoadT; @@ -1712,7 +1712,7 @@ kOptimizerStatic8bit2StateBlockwise( new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); - + if (OPTIMIZER == ADEMAMIX) { new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); } @@ -1776,7 +1776,7 @@ kOptimizerStatic8bit2StateBlockwise( } else { p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); } - + if(weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } @@ -2148,27 +2148,27 @@ __global__ void 
kdequant_mm_int32_fp16( int local_values[ITEMS_PER_THREAD]; half local_output[ITEMS_PER_THREAD]; - + float local_rowStats[ITEMS_PER_THREAD]; float local_colStats[ITEMS_PER_THREAD]; float local_biasValue[ITEMS_PER_THREAD]; typedef hipcub::BlockLoad LoadInt32; __shared__ typename LoadInt32::TempStorage loadint32; - + int row_idx, col_idx; - + #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) { row_idx = (block_offset + thread_offset + j) / numCols; col_idx = (block_offset + thread_offset + j) % numCols; - + local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; - local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; local_biasValue[j] = ((bias == nullptr) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); } - + // Each block loads THREADS * ITEMS_PER_THREAD values from A int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD @@ -2188,7 +2188,7 @@ __global__ void kdequant_mm_int32_fp16( if (outIdx < n_out) { out[outIdx] = local_output[j]; } - } + } } #define DENORM 1.0f/127.0f diff --git a/csrc/ops.hip b/csrc/ops.hip index 4d077d19a..eef616d48 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -199,10 +199,10 @@ template void optimizerStatic8bit(T* p, T* g, } } -#define BLOCKSIZE_2STATE 256 -#define NUM_2STATE 1 -#define BLOCKSIZE_1STATE 256 -#define NUM_1STATE 1 +#define BLOCKSIZE_2STATE 256 +#define NUM_2STATE 1 +#define BLOCKSIZE_1STATE 256 +#define NUM_1STATE 1 template void optimizerStatic8bitBlockwise( T* p, @@ -443,7 +443,7 @@ static std::string hipError_to_string(const hipError_t ret) } template int igemmlt( - hipblasLtHandle_t ltHandle, + hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 1b2ea85db..3d8b688ee 100644 --- a/tests/test_cuda_setup_evaluator.py +++ 
b/tests/test_cuda_setup_evaluator.py @@ -12,11 +12,13 @@ def cuda120_spec() -> CUDASpecs: cuda_version_tuple=(12, 0), ) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): monkeypatch.setenv("BNB_CUDA_VERSION", "110") diff --git a/tests/test_functional.py b/tests/test_functional.py index 5f5ee488c..a2964c733 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -8,8 +8,8 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH from bitsandbytes import functional as F +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -92,7 +92,10 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) - @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128] ) + @pytest.mark.parametrize( + "blocksize", + [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128], + ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 @@ -796,6 +799,7 @@ def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): A[:, outlier_cols] = 0 torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, 
atol=1.5e-2) + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") class TestSpMMFunctional: @@ -1106,7 +1110,10 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize( + "blocksize", + [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], + ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1205,7 +1212,7 @@ def test_bench_4bit_dequant(self, quant_type): # torch.matmul(b, a.t()) # torch.cuda.synchronize() # print((time.time()-t0)/iters*1e6) - + @pytest.mark.skipif( HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 1b7a7722c..60c163477 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -17,6 +17,7 @@ "float32": torch.float32, } + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) diff --git a/tests/test_ops.py b/tests/test_ops.py index a99d080b3..a433a0c4b 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,8 +4,8 @@ import torch import bitsandbytes -from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter from bitsandbytes.cextension import HIP_ENVIRONMENT +from tests.helpers import 
TRUE_FALSE, get_available_devices, id_formatter class TestLLMInt8Ops: From 93768d07b1b753790a784f1472e5b6b1f9fa5c73 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 4 Jun 2025 01:24:09 +0530 Subject: [PATCH 62/98] Remove conflicts.diff --- conflicts.diff | 382 ------------------------------------------------- 1 file changed, 382 deletions(-) delete mode 100644 conflicts.diff diff --git a/conflicts.diff b/conflicts.diff deleted file mode 100644 index cab8c6ea7..000000000 --- a/conflicts.diff +++ /dev/null @@ -1,382 +0,0 @@ -diff --cc bitsandbytes/cextension.py -index 108aa0c,b112df2..0000000 ---- a/bitsandbytes/cextension.py -+++ b/bitsandbytes/cextension.py -@@@ -28,17 -28,10 +29,15 @@@ def get_cuda_bnb_library_path(cuda_spec - override_value = os.environ.get("BNB_CUDA_VERSION") - if override_value: - library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) - + if torch.version.hip: - + raise RuntimeError( - + f"BNB_CUDA_VERSION={override_value} detected for ROCm!! 
\n" - + f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" - + ) - logger.warning( - f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" -- "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" -+ "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n" - "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" -- "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" -- "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLi - return BNBNativeLibrary(dll) - - - +ROCM_GPU_ARCH = get_rocm_gpu_arch() - + - try: -++<<<<<<< HEAD - + if torch.version.hip: - + HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" - + else: - + HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" - + -++======= -+ # to support Intel CPU/GPU (XPU) backend -+ import intel_extension_for_pytorch as ipex -+ -+ ipex_cpu = ipex if ipex._C._has_cpu() else None -+ ipex_xpu = ipex if ipex._C._has_xpu() else None -+ except BaseException: -+ ipex_cpu = None -+ ipex_xpu = None -+ -+ -+ try: -++>>>>>>> upstream/main - lib = get_native_library() - except Exception as e: - error_msg = str(e) -diff --cc bitsandbytes/diagnostics/cuda.py -index b9de27f,e763ef2..0000000 ---- a/bitsandbytes/diagnostics/cuda.py -+++ b/bitsandbytes/diagnostics/cuda.py -@@@ -5,8 -5,7 +5,12 @@@ from pathlib import Pat - - import torch - -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path - +from bitsandbytes.consts import NONPYTORCH_DOC_URL -++======= -+ from bitsandbytes.cextension import get_cuda_bnb_library_path -++>>>>>>> upstream/main - from bitsandbytes.cuda_specs import CUDASpecs - from bitsandbytes.diagnostics.utils import print_dedented - -@@@ -146,42 -127,8 +134,38 @@@ def 
_print_cuda_diagnostics(cuda_specs - """, - ) - -- # TODO: -- # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) -- # (2) Multiple CUDA versions installed -- - - -def print_cuda_runtime_diagnostics() -> None: - +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: - + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") - + - + binary_path = get_cuda_bnb_library_path(cuda_specs) - + if not binary_path.exists(): - + print_dedented( - + f""" - + Library not found: {binary_path}. - + Maybe you need to compile it from source? If you compiled from source, check that ROCm version - + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version - + and rebuild bitsandbytes. - + """, - + ) - + - + hip_major, hip_minor = cuda_specs.cuda_version_tuple - + if (hip_major, hip_minor) < (6, 1): - + print_dedented( - + """ - + WARNING: bitsandbytes is fully supported only from ROCm 6.1. - + """, - + ) - + - + - +def print_diagnostics(cuda_specs: CUDASpecs) -> None: - + if HIP_ENVIRONMENT: - + _print_hip_diagnostics(cuda_specs) - + else: - + _print_cuda_diagnostics(cuda_specs) - + - + - +def _print_cuda_runtime_diagnostics() -> None: - cudart_paths = list(find_cudart_libraries()) - if not cudart_paths: - print("CUDA SETUP: WARNING! 
CUDA runtime files not found in any environmental path.") -diff --cc bitsandbytes/diagnostics/main.py -index 8e2bc2a,aa4cb30..0000000 ---- a/bitsandbytes/diagnostics/main.py -+++ b/bitsandbytes/diagnostics/main.py -@@@ -3,12 -5,11 +5,20 @@@ import tracebac - - import torch - -++<<<<<<< HEAD - +from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT - +from bitsandbytes.consts import PACKAGE_GITHUB_URL - +from bitsandbytes.cuda_specs import get_cuda_specs - +from bitsandbytes.diagnostics.cuda import ( - + print_diagnostics, - + print_runtime_diagnostics, -++======= -+ from bitsandbytes import __version__ as bnb_version -+ from bitsandbytes.consts import PACKAGE_GITHUB_URL -+ from bitsandbytes.cuda_specs import get_cuda_specs -+ from bitsandbytes.diagnostics.cuda import ( -+ print_cuda_diagnostics, -++>>>>>>> upstream/main - ) - from bitsandbytes.diagnostics.utils import print_dedented, print_header - -@@@ -28,52 -41,77 +50,122 @@@ def sanity_check() - assert p1 != p2 - - -+ def get_package_version(name: str) -> str: -+ try: -+ version = importlib.metadata.version(name) -+ except importlib.metadata.PackageNotFoundError: -+ version = "not found" -+ return version -+ -+ -+ def show_environment(): -+ """Simple utility to print out environment information.""" -+ -+ print(f"Platform: {platform.platform()}") -+ if platform.system() == "Linux": -+ print(f" libc: {'-'.join(platform.libc_ver())}") -+ -+ print(f"Python: {platform.python_version()}") -+ -+ print(f"PyTorch: {torch.__version__}") -+ print(f" CUDA: {torch.version.cuda or 'N/A'}") -+ print(f" HIP: {torch.version.hip or 'N/A'}") -+ print(f" XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}") -+ -+ print("Related packages:") -+ for pkg in _RELATED_PACKAGES: -+ version = get_package_version(pkg) -+ print(f" {pkg}: {version}") -+ -+ - def main(): -- print_header("") -- print_header("BUG REPORT INFORMATION") -+ print_header(f"bitsandbytes v{bnb_version}") -+ show_environment() - print_header("") - -- 
print_header("OTHER") - cuda_specs = get_cuda_specs() -++<<<<<<< HEAD - + if HIP_ENVIRONMENT: - + rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," - + rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" - + print(f"{BNB_BACKEND} specs:{rocm_specs}") - + else: - + print(f"{BNB_BACKEND} specs:{cuda_specs}") - + if not torch.cuda.is_available(): - + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") - + if not HIP_ENVIRONMENT: print(f"- {BNB_BACKEND} driver not installed") - + print(f"- {BNB_BACKEND} not installed") - + print(f"- You have multiple conflicting {BNB_BACKEND} libraries") - + if cuda_specs: - + print_diagnostics(cuda_specs) - + print_runtime_diagnostics() - + print_header("") - + print_header("DEBUG INFO END") - + print_header("") - + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") - + try: - + sanity_check() - + print("SUCCESS!") - + print("Installation was successful!") - + return - + except RuntimeError as e: - + if "not available in CPU-only" in str(e): - + print( - + f"WARNING: {__package__} is currently running as CPU-only!\n" - + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - + f"If you think that this is so erroneously,\nplease report an issue!", - + ) - + else: - + raise e - + except Exception: - + traceback.print_exc() - + print_dedented( - + f""" - + Above we output some debug information. - + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose - + WARNING: Please be sure to sanitize sensitive info from the output before posting it. - + """, - + ) - + sys.exit(1) -++======= -+ -+ if cuda_specs: -+ print_cuda_diagnostics(cuda_specs) -+ -+ # TODO: There's a lot of noise in this; needs improvement. -+ # print_cuda_runtime_diagnostics() -+ -+ if not torch.cuda.is_available(): -+ print("PyTorch says CUDA is not available. Possible reasons:") -+ print("1. 
CUDA driver not installed") -+ print("2. Using a CPU-only PyTorch build") -+ print("3. No GPU detected") -+ -+ else: -+ print("Checking that the library is importable and CUDA is callable...") -+ -+ try: -+ sanity_check() -+ print("SUCCESS!") -+ return -+ except RuntimeError as e: -+ if "not available in CPU-only" in str(e): -+ print( -+ f"WARNING: {__package__} is currently running as CPU-only!\n" -+ "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" -+ f"If you think that this is so erroneously,\nplease report an issue!", -+ ) -+ else: -+ raise e -+ except Exception: -+ traceback.print_exc() -+ -+ print_dedented( -+ f""" -+ Above we output some debug information. -+ Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose -+ WARNING: Please be sure to sanitize sensitive info from the output before posting it. -+ """, -+ ) -+ sys.exit(1) -++>>>>>>> upstream/main -diff --cc bitsandbytes/functional.py -index 03f6c32,ffb6668..0000000 -mode 100644,100755..100755 ---- a/bitsandbytes/functional.py -+++ b/bitsandbytes/functional.py -@@@ -13,9 -13,9 +13,13 @@@ import torc - from torch import Tensor - from typing_extensions import deprecated - -- from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -+ from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict - -++<<<<<<< HEAD - +from .cextension import lib, HIP_ENVIRONMENT -++======= -+ from .cextension import ipex_cpu, ipex_xpu, lib -++>>>>>>> upstream/main - - name2qmap = {} - -diff --cc bitsandbytes/nn/modules.py -index 2383f2c,ccd842c..0000000 ---- a/bitsandbytes/nn/modules.py -+++ b/bitsandbytes/nn/modules.py -@@@ -11,8 -11,7 +11,12 @@@ from torch import Tensor, device, dtype - import torch.nn.functional as F - - import bitsandbytes as bnb -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT - +from bitsandbytes.functional import QuantState -++======= -+ from 
bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu -++>>>>>>> upstream/main - from bitsandbytes.optim import GlobalOptimManager - from bitsandbytes.utils import ( - INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, -diff --cc tests/test_linear4bit.py -index 1b7a772,b5db2eb..0000000 ---- a/tests/test_linear4bit.py -+++ b/tests/test_linear4bit.py -@@@ -7,8 -8,14 +8,19 @@@ import pytes - import torch - - import bitsandbytes as bnb -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT - +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer -++======= -+ from tests.helpers import ( -+ TRUE_FALSE, -+ describe_dtype, -+ get_available_devices, -+ id_formatter, -+ torch_load_from_buffer, -+ torch_save_to_buffer, -+ ) -++>>>>>>> upstream/main - - storage = { - "uint8": torch.uint8, -@@@ -183,16 -185,10 +189,10 @@@ def test_linear_serialization(device, q - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_copy_param(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- tensor = torch.linspace(1, blocksize, blocksize) -+ tensor = torch.randn(300, 400) - param = bnb.nn.Params4bit( - data=tensor, - quant_type=quant_type, -@@@ -208,16 -204,10 +208,10 @@@ - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not 
HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- tensor = torch.linspace(1, blocksize, blocksize) -+ tensor = torch.randn(300, 400) - param = bnb.nn.Params4bit( - data=tensor, - quant_type=quant_type, -@@@ -240,16 -230,10 +234,10 @@@ - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32) -+ original_tensor = torch.randn(300, 400) - original_param = bnb.nn.Params4bit( - data=original_tensor, - quant_type=quant_type, From e119ff73efa8aa4d48c651e2d762e5107631f22d Mon Sep 17 00:00:00 2001 From: amcamd Date: Thu, 5 Jun 2025 17:13:30 -0400 Subject: [PATCH 63/98] update for hipblasVersionMajor >=3 --- csrc/ops.hip | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/csrc/ops.hip b/csrc/ops.hip index eef616d48..a9c3e0202 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -269,6 +269,15 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in const void * beta = &fbeta; hipblasStatus_t status; +#if hipblasVersionMajor >= 3 + status = hipblasGemmEx(context->m_handle, 
+ transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta, + C, HIP_R_32I, ldc, + HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); +#else status = hipblasGemmEx(context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, @@ -276,6 +285,7 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, C, HIPBLAS_R_32I, ldc, HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); +#endif if (status != HIPBLAS_STATUS_SUCCESS) { @@ -299,6 +309,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i //printf("%i %i %i\n", strideA, strideB, strideC); //printf("%i\n", batchCount); +#if hipblasVersionMajor >= 3 + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta, + C, HIP_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); +#else status = hipblasGemmStridedBatchedEx(context->m_handle, transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, transposeB ? 
HIPBLAS_OP_T : HIPBLAS_OP_N, @@ -306,6 +325,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); +#endif if (status != HIPBLAS_STATUS_SUCCESS) { From 8dc297d32adf90a079decd0a8649736dc5258089 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 6 Jun 2025 23:40:46 +0530 Subject: [PATCH 64/98] Update test_functional.py --- tests/test_functional.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_functional.py b/tests/test_functional.py index a2964c733..95f75d99f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -98,6 +98,9 @@ class Test8BitBlockwiseQuantizeFunctional: ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + iters = 100 if device == "cpu": @@ -150,6 +153,7 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, assert A2.dtype == dtype @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") + @pytest.mark.skipif(HIP_ENVIRONMENT, reason="CPU tests skipped when HIP_ENVIRONMENT is set") @pytest.mark.parametrize("hidden", [128]) @pytest.mark.parametrize("blocksize", [4096, 16384]) def test_blockwise_cpu_large(self, hidden, blocksize): @@ -176,6 +180,9 @@ def test_blockwise_cpu_large(self, hidden, blocksize): @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) def test_few_bit_quant(self, device, bits, method): + if HIP_ENVIRONMENT and 
device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and bits != 8: pytest.skip("CPU implementation only supports 8 bits") @@ -232,6 +239,9 @@ def test_few_bit_quant(self, device, bits, method): @pytest.mark.parametrize("device", get_available_devices()) def test_fp8_quant(self, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + # TODO if device == "cpu": pytest.skip("CPU implementation segfaults") @@ -570,6 +580,9 @@ class TestLLMInt8Functional: @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + for i in range(k): if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim3), dtype=torch.int8, device=device) @@ -588,6 +601,9 @@ def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + for i in range(k): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device=device).half() @@ -611,6 +627,9 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) def test_dequant_mm(self, device, dim1, dim4, dims, has_bias): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + inner = 128 bias = None if has_bias: @@ -734,6 +753,9 @@ def 
test_int8_double_quant(self, dim1, dim2): ), ) def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and inner > 2048: pytest.skip("Slow on CPU") @@ -767,6 +789,9 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(self, device, dim1, dim2): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + threshold = 2.00 for i in range(k): A = torch.randn(dim1, dim2, device=device).half() @@ -787,6 +812,9 @@ def test_coo_double_quant(self, device, dim1, dim2): @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + threshold = 3.00 for i in range(k): A = torch.randn(dim1, dim2, device=device).half() @@ -1115,6 +1143,9 @@ class TestQuantize4BitFunctional: [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1150,6 +1181,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, 
blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1228,6 +1262,9 @@ def test_bench_4bit_dequant(self, quant_type): ) @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim")) def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1384,6 +1421,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double reason="this test is not supported on ROCm with gfx90a architecture yet", ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From f7d8bf340bb9d36c3412cbbedc564c2edecc8308 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 6 Jun 2025 23:45:28 +0530 Subject: [PATCH 65/98] Update test_linear4bit.py --- tests/test_linear4bit.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 60c163477..546ed2681 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -25,6 +25,9 @@ @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + 
if device == "cpu": if quant_type == "fp4": pytest.xfail("FP4 is not supported for CPU") @@ -187,6 +190,9 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") @@ -212,6 +218,9 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") @@ -244,6 +253,9 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") From fd0a4d0fc4dc610fcf96a0469b41a68299d6daa5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 6 Jun 2025 23:52:40 +0530 Subject: [PATCH 66/98] Update test_ops.py --- tests/test_ops.py | 30 
++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_ops.py b/tests/test_ops.py index a433a0c4b..3879aa479 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -11,6 +11,9 @@ class TestLLMInt8Ops: @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul(self, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B) @@ -23,6 +26,9 @@ def test_int8_linear_matmul(self, device): @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul_out(self, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -38,6 +44,9 @@ def test_int8_linear_matmul_out(self, device): @pytest.mark.parametrize("threshold", [0.0, 6.0]) @pytest.mark.parametrize("device", get_available_devices()) def test_int8_vectorwise_quant(self, threshold, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randn(10, 20, dtype=torch.float16, device=device) A[1][0] = 1000.0 @@ -64,6 +73,9 @@ def test_int8_vectorwise_quant(self, threshold, device): @pytest.mark.parametrize("device", get_available_devices()) def test_int8_mm_dequant(self, device): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) row_stats = torch.randn(256, dtype=torch.float32, device=device) col_stats = torch.randn(256, dtype=torch.float32, device=device) @@ -79,6 +91,9 @@ def 
test_int8_mm_dequant(self, device): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("has_bias", TRUE_FALSE) def test_int8_scaled_mm(self, device, dtype, has_bias): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) row_stats = torch.randn(10, dtype=torch.float32, device=device) @@ -98,6 +113,9 @@ class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -122,6 +140,9 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -148,6 +169,9 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): + if HIP_ENVIRONMENT 
and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu" and quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -172,6 +196,9 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": if quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -209,6 +236,9 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): + if HIP_ENVIRONMENT and device == "cpu": + pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") + if device == "cpu": pytest.xfail("CPU implementation is not available") From 75487d38f59e6b3c6e05182ecc42330275f488f6 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Sat, 7 Jun 2025 00:58:16 +0530 Subject: [PATCH 67/98] Update main.py --- bitsandbytes/diagnostics/main.py | 33 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 7cd04e209..ed7999f14 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -77,29 +77,25 @@ def main(): print_header("") cuda_specs = get_cuda_specs() - if HIP_ENVIRONMENT: - rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," - rocm_specs += f" 
rocm_version_tuple={cuda_specs.cuda_version_tuple}" - print(f"{BNB_BACKEND} specs:{rocm_specs}") - else: - print(f"{BNB_BACKEND} specs:{cuda_specs}") - if not torch.cuda.is_available(): - print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") - if not HIP_ENVIRONMENT: - print(f"- {BNB_BACKEND} driver not installed") - print(f"- {BNB_BACKEND} not installed") - print(f"- You have multiple conflicting {BNB_BACKEND} libraries") + if cuda_specs: print_diagnostics(cuda_specs) - print_runtime_diagnostics() - print_header("") - print_header("DEBUG INFO END") - print_header("") - print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") + + # TODO: There's a lot of noise in this; needs improvement. + # print_cuda_runtime_diagnostics() + + if not torch.cuda.is_available(): + print(f"PyTorch says {BNB_BACKEND} is not available. Possible reasons:") + print(f"1. {BNB_BACKEND} driver not installed") + print(f"2. Using a CPU-only PyTorch build") + print(f"3. No GPU detected") + + else: + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") + try: sanity_check() print("SUCCESS!") - print("Installation was successful!") return except RuntimeError as e: if "not available in CPU-only" in str(e): @@ -112,6 +108,7 @@ def main(): raise e except Exception: traceback.print_exc() + print_dedented( f""" Above we output some debug information. 
From 3551457f987e834999b39f6df01868587e3233e3 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 02:32:43 +0530 Subject: [PATCH 68/98] Update test_functional.py --- tests/test_functional.py | 105 +++++++++++++++++++-------------------- 1 file changed, 52 insertions(+), 53 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 95f75d99f..719f21137 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -89,7 +89,10 @@ def reset(self): class Test8BitBlockwiseQuantizeFunctional: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize( @@ -98,9 +101,6 @@ class Test8BitBlockwiseQuantizeFunctional: ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - iters = 100 if device == "cpu": @@ -153,7 +153,6 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, assert A2.dtype == dtype @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required") - @pytest.mark.skipif(HIP_ENVIRONMENT, reason="CPU tests skipped when HIP_ENVIRONMENT is set") @pytest.mark.parametrize("hidden", [128]) @pytest.mark.parametrize("blocksize", [4096, 16384]) def test_blockwise_cpu_large(self, hidden, blocksize): @@ -176,13 +175,13 @@ def test_blockwise_cpu_large(self, hidden, blocksize): # print(sum(diffs)/len(diffs)) # print(sum(reldiffs)/len(reldiffs)) - @pytest.mark.parametrize("device", 
get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) def test_few_bit_quant(self, device, bits, method): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and bits != 8: pytest.skip("CPU implementation only supports 8 bits") @@ -237,11 +236,11 @@ def test_few_bit_quant(self, device, bits, method): else: torch.testing.assert_close(q1, q2) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_fp8_quant(self, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - # TODO if device == "cpu": pytest.skip("CPU implementation segfaults") @@ -572,7 +571,10 @@ def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): class TestLLMInt8Functional: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) @@ -580,9 +582,6 @@ class TestLLMInt8Functional: @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - for i in range(k): if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim3), 
dtype=torch.int8, device=device) @@ -594,16 +593,16 @@ def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): C2 = F.int8_linear_matmul(A, B) torch.testing.assert_close(C1, C2.float()) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - for i in range(k): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device=device).half() @@ -621,15 +620,15 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) def test_dequant_mm(self, device, dim1, dim4, dims, has_bias): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - inner = 128 bias = None if has_bias: @@ -740,7 +739,10 @@ def test_int8_double_quant(self, dim1, dim2): torch.testing.assert_close(Srow.flatten().float(), statsA) 
torch.testing.assert_close(Scol.flatten().float(), statsAt) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize( ("dim1", "dim4", "inner"), ( @@ -753,9 +755,6 @@ def test_int8_double_quant(self, dim1, dim2): ), ) def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and inner > 2048: pytest.skip("Slow on CPU") @@ -785,13 +784,13 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): err2 = torch.abs(out1 - out3).mean().item() assert err2 <= err1 * 1.025 - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(self, device, dim1, dim2): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - threshold = 2.00 for i in range(k): A = torch.randn(dim1, dim2, device=device).half() @@ -808,13 +807,13 @@ def test_coo_double_quant(self, device, dim1, dim2): A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU 
tests skipped when HIP_ENVIRONMENT is set") - threshold = 3.00 for i in range(k): A = torch.randn(dim1, dim2, device=device).half() @@ -1135,7 +1134,10 @@ def test_coo2csc(self): class TestQuantize4BitFunctional: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( @@ -1143,9 +1145,6 @@ class TestQuantize4BitFunctional: [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1177,13 +1176,13 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 assert err.item() < math.log2(blocksize) * 8e-2 - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and quant_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1250,7 +1249,10 @@ def test_bench_4bit_dequant(self, quant_type): @pytest.mark.skipif( HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others 
being fixed for warpsize 64" ) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) @@ -1262,9 +1264,6 @@ def test_bench_4bit_dequant(self, quant_type): ) @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim")) def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") @@ -1412,7 +1411,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @@ -1421,9 +1423,6 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double reason="this test is not supported on ROCm with gfx90a architecture yet", ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and storage_type != "nf4": pytest.xfail("fp4 quantization is not supported on CPU") From 
90437b94837529b7519e59c64a5d5774090fba80 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 02:44:33 +0530 Subject: [PATCH 69/98] Update test_linear4bit.py --- tests/test_linear4bit.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 546ed2681..fe3f4b13c 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -18,7 +18,10 @@ } -@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -185,7 +188,10 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua assert size_ratio < target_compression, ratio_error_msg -@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -213,7 +219,10 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() -@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT 
else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -248,7 +257,10 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): assert dict_keys_before == dict_keys_copy -@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) From a0bdc94db673238c0b3e12ff9ac03117f5f966f2 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 02:51:33 +0530 Subject: [PATCH 70/98] Update test_ops.py --- tests/test_ops.py | 80 +++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index 3879aa479..e3be5fd50 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -9,11 +9,11 @@ class TestLLMInt8Ops: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_linear_matmul(self, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B) @@ -24,11 +24,11 @@ def test_int8_linear_matmul(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in 
get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_linear_matmul_out(self, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -42,11 +42,11 @@ def test_int8_linear_matmul_out(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out)) @pytest.mark.parametrize("threshold", [0.0, 6.0]) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_vectorwise_quant(self, threshold, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randn(10, 20, dtype=torch.float16, device=device) A[1][0] = 1000.0 @@ -71,11 +71,11 @@ def test_int8_vectorwise_quant(self, threshold, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_mm_dequant(self, device): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) row_stats = torch.randn(256, dtype=torch.float32, device=device) col_stats = torch.randn(256, dtype=torch.float32, device=device) @@ -87,13 +87,13 @@ def test_int8_mm_dequant(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not 
(HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("has_bias", TRUE_FALSE) def test_int8_scaled_mm(self, device, dtype, has_bias): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) row_stats = torch.randn(10, dtype=torch.float32, device=device) @@ -109,13 +109,13 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -136,13 +136,13 @@ def test_quantize_blockwise(self, device, dtype, blocksize): torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): - if 
HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -163,15 +163,15 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): class Test4bitBlockwiseQuantOps: - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu" and quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -190,15 +190,15 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, 
storage_dtype, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if quant_type != "nf4": pytest.xfail("CPU implementation is only available for nf4") @@ -230,15 +230,15 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype) ) - @pytest.mark.parametrize("device", get_available_devices()) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": pytest.xfail("CPU implementation is not available") From 8a27346f8fd6ecf8eea4127ea13899618d9a921c Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 02:56:08 +0530 Subject: [PATCH 71/98] Update test_linear4bit.py --- tests/test_linear4bit.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index fe3f4b13c..760e4b8c9 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -21,16 +21,13 @@ @pytest.mark.parametrize( "device", [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) 
@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if quant_type == "fp4": pytest.xfail("FP4 is not supported for CPU") @@ -191,14 +188,11 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua @pytest.mark.parametrize( "device", [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") @@ -222,14 +216,11 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize( "device", [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is 
set") - if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") @@ -260,14 +251,11 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize( "device", [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): - if HIP_ENVIRONMENT and device == "cpu": - pytest.skip("CPU tests skipped when HIP_ENVIRONMENT is set") - if device == "cpu": if compress_statistics: pytest.skip("Currently segfaults on CPU") From c945dbb5c8b14bf54631d39f51b6c1d841981043 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 11 Jun 2025 03:05:38 +0530 Subject: [PATCH 72/98] Lint --- tests/test_functional.py | 70 ++++++++++++++++++------------------ tests/test_linear4bit.py | 26 +++++++------- tests/test_ops.py | 78 ++++++++++++++++++++-------------------- 3 files changed, 87 insertions(+), 87 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 719f21137..571eea55f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -236,9 +236,9 @@ def test_few_bit_quant(self, device, bits, method): else: torch.testing.assert_close(q1, q2) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) def test_fp8_quant(self, device): # TODO @@ -571,9 +571,9 @@ def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): class TestLLMInt8Functional: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == 
"cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @@ -593,9 +593,9 @@ def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): C2 = F.int8_linear_matmul(A, B) torch.testing.assert_close(C1, C2.float()) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) @@ -620,9 +620,9 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @@ -739,10 +739,10 @@ def test_int8_double_quant(self, dim1, dim2): torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Scol.flatten().float(), statsAt) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize( ("dim1", "dim4", "inner"), ( @@ -784,10 +784,10 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): err2 = torch.abs(out1 - 
out3).mean().item() assert err2 <= err1 * 1.025 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(self, device, dim1, dim2): @@ -807,9 +807,9 @@ def test_coo_double_quant(self, device, dim1, dim2): A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) @@ -1134,9 +1134,9 @@ def test_coo2csc(self): class TestQuantize4BitFunctional: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -1176,9 +1176,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 assert err.item() < math.log2(blocksize) * 8e-2 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) 
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) @@ -1249,9 +1249,9 @@ def test_bench_4bit_dequant(self, quant_type): @pytest.mark.skipif( HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @@ -1411,9 +1411,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 760e4b8c9..ddc609616 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -18,9 +18,9 @@ } -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @@ -185,10 +185,10 @@ def 
test_linear_serialization(device, quant_type, compress_statistics, bias, qua assert size_ratio < target_compression, ratio_error_msg -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -213,9 +213,9 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @@ -248,9 +248,9 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): assert dict_keys_before == dict_keys_copy -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], +@pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) diff --git a/tests/test_ops.py b/tests/test_ops.py index e3be5fd50..9d406b793 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -9,10 +9,10 @@ class TestLLMInt8Ops: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + 
"device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_linear_matmul(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -24,10 +24,10 @@ def test_int8_linear_matmul(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_linear_matmul_out(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -42,10 +42,10 @@ def test_int8_linear_matmul_out(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out)) @pytest.mark.parametrize("threshold", [0.0, 6.0]) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_vectorwise_quant(self, threshold, device): A = torch.randn(10, 20, dtype=torch.float16, device=device) A[1][0] = 1000.0 @@ -71,10 +71,10 @@ def test_int8_vectorwise_quant(self, threshold, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) def test_int8_mm_dequant(self, device): A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) row_stats = torch.randn(256, 
dtype=torch.float32, device=device) @@ -87,9 +87,9 @@ def test_int8_mm_dequant(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("has_bias", TRUE_FALSE) @@ -109,10 +109,10 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): @@ -136,10 +136,10 @@ def test_quantize_blockwise(self, device, dtype, blocksize): torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): @@ -163,10 +163,10 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): class 
Test4bitBlockwiseQuantOps: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -190,10 +190,10 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -230,10 +230,10 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype) ) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize( + "device", + [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], + ) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) From 
58e989ef989852e98ac11540d1e0b144b4f68783 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 11 Jun 2025 11:43:45 +0530 Subject: [PATCH 73/98] Lint --- bitsandbytes/diagnostics/main.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index ed7999f14..9a0447433 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -6,12 +6,11 @@ import torch from bitsandbytes import __version__ as bnb_version -from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT +from bitsandbytes.cextension import BNB_BACKEND from bitsandbytes.consts import PACKAGE_GITHUB_URL from bitsandbytes.cuda_specs import get_cuda_specs from bitsandbytes.diagnostics.cuda import ( print_diagnostics, - print_runtime_diagnostics, ) from bitsandbytes.diagnostics.utils import print_dedented, print_header @@ -77,18 +76,18 @@ def main(): print_header("") cuda_specs = get_cuda_specs() - + if cuda_specs: print_diagnostics(cuda_specs) # TODO: There's a lot of noise in this; needs improvement. # print_cuda_runtime_diagnostics() - + if not torch.cuda.is_available(): print(f"PyTorch says {BNB_BACKEND} is not available. Possible reasons:") print(f"1. {BNB_BACKEND} driver not installed") - print(f"2. Using a CPU-only PyTorch build") - print(f"3. No GPU detected") + print("2. Using a CPU-only PyTorch build") + print("3. No GPU detected") else: print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") @@ -108,7 +107,7 @@ def main(): raise e except Exception: traceback.print_exc() - + print_dedented( f""" Above we output some debug information. 
From 2cce3366b363de9499220a00a62e34e88183ced4 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:03:19 +0530 Subject: [PATCH 74/98] Update helpers.py --- tests/helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index fbc4af071..671ea39eb 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,6 +7,8 @@ import torch +from bitsandbytes.cextension import HIP_ENVIRONMENT + test_dims_rng = random.Random(42) @@ -21,7 +23,7 @@ def get_available_devices(): # If the environment variable is set, use it directly. return [os.environ["BNB_TEST_DEVICE"]] - devices = ["cpu"] + devices = [] if HIP_ENVIRONMENT else ["cpu"] if hasattr(torch, "accelerator"): # PyTorch 2.6+ - determine accelerator using agnostic API. From 5eb0316802d87c25e0c850a13c7cec77e9648583 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:15:09 +0530 Subject: [PATCH 75/98] Update test_functional.py --- tests/test_functional.py | 65 ++++++++-------------------------------- 1 file changed, 13 insertions(+), 52 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 571eea55f..a2964c733 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -89,10 +89,7 @@ def reset(self): class Test8BitBlockwiseQuantizeFunctional: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) @pytest.mark.parametrize( @@ -175,10 +172,7 @@ def test_blockwise_cpu_large(self, hidden, blocksize): # print(sum(diffs)/len(diffs)) # print(sum(reldiffs)/len(reldiffs)) - @pytest.mark.parametrize( - 
"device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) def test_few_bit_quant(self, device, bits, method): @@ -236,10 +230,7 @@ def test_few_bit_quant(self, device, bits, method): else: torch.testing.assert_close(q1, q2) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_fp8_quant(self, device): # TODO if device == "cpu": @@ -571,10 +562,7 @@ def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): class TestLLMInt8Functional: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) @@ -593,10 +581,7 @@ def test_int8_linear_matmul(self, device, dim1, dim2, dim3, dim4, dims, ldb): C2 = F.int8_linear_matmul(A, B) torch.testing.assert_close(C1, C2.float()) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) @@ -620,10 +605,7 @@ def test_int8_linear_matmul_half(self, device, dim1, dim2, dim3, dim4, dims): torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) - @pytest.mark.parametrize( - "device", - [d 
for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) @@ -739,10 +721,7 @@ def test_int8_double_quant(self, dim1, dim2): torch.testing.assert_close(Srow.flatten().float(), statsA) torch.testing.assert_close(Scol.flatten().float(), statsAt) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize( ("dim1", "dim4", "inner"), ( @@ -784,10 +763,7 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): err2 = torch.abs(out1 - out3).mean().item() assert err2 <= err1 * 1.025 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_double_quant(self, device, dim1, dim2): @@ -807,10 +783,7 @@ def test_coo_double_quant(self, device, dim1, dim2): A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): @@ -1134,10 +1107,7 @@ def test_coo2csc(self): class TestQuantize4BitFunctional: - 
@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( @@ -1176,10 +1146,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 assert err.item() < math.log2(blocksize) * 8e-2 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, blocksize): @@ -1249,10 +1216,7 @@ def test_bench_4bit_dequant(self, quant_type): @pytest.mark.skipif( HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" ) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) @@ -1411,10 +1375,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double assert relratio < 1.04 and relratio > 0.96 assert maxratio < 1.02 and maxratio > 0.98 - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) 
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) From dcdf2c54ffe023295a1b6f60edab18d60b073552 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:15:41 +0530 Subject: [PATCH 76/98] Update test_linear4bit.py --- tests/test_linear4bit.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index ddc609616..60c163477 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -18,10 +18,7 @@ } -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"]) @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -185,10 +182,7 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua assert size_ratio < target_compression, ratio_error_msg -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -213,10 +207,7 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): assert param.data.data_ptr() == shallow_copy_param.data.data_ptr() -@pytest.mark.parametrize( - "device", - [d for d in 
get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) @@ -248,10 +239,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): assert dict_keys_before == dict_keys_copy -@pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], -) +@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) From 6bba74052813946d28b238755d43756ff0e6c4f5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:16:22 +0530 Subject: [PATCH 77/98] Update test_ops.py --- tests/test_ops.py | 50 ++++++++++------------------------------------- 1 file changed, 10 insertions(+), 40 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index 9d406b793..a433a0c4b 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -9,10 +9,7 @@ class TestLLMInt8Ops: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -24,10 +21,7 @@ def test_int8_linear_matmul(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B)) - @pytest.mark.parametrize( - 
"device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_linear_matmul_out(self, device): A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) @@ -42,10 +36,7 @@ def test_int8_linear_matmul_out(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out)) @pytest.mark.parametrize("threshold", [0.0, 6.0]) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_vectorwise_quant(self, threshold, device): A = torch.randn(10, 20, dtype=torch.float16, device=device) A[1][0] = 1000.0 @@ -71,10 +62,7 @@ def test_int8_vectorwise_quant(self, threshold, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) def test_int8_mm_dequant(self, device): A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) row_stats = torch.randn(256, dtype=torch.float32, device=device) @@ -87,10 +75,7 @@ def test_int8_mm_dequant(self, device): torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("has_bias", TRUE_FALSE) def test_int8_scaled_mm(self, device, dtype, has_bias): @@ -109,10 +94,7 @@ def test_int8_scaled_mm(self, device, 
dtype, has_bias): class TestInt8BlockwiseQuantOps: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): @@ -136,10 +118,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): @@ -163,10 +142,7 @@ def test_dequantize_blockwise(self, device, dtype, blocksize): class Test4bitBlockwiseQuantOps: - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -190,10 +166,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() 
if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @@ -230,10 +203,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype) ) - @pytest.mark.parametrize( - "device", - [d for d in get_available_devices() if not (HIP_ENVIRONMENT and d == "cpu")], - ) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) From bdd67545ed1c66d1fc7cf5b84118b6bba107755e Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 11 Jun 2025 15:18:19 +0530 Subject: [PATCH 78/98] Lint --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 671ea39eb..54eec95dc 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,7 +7,7 @@ import torch -from bitsandbytes.cextension import HIP_ENVIRONMENT +from bitsandbytes.cextension import HIP_ENVIRONMENT test_dims_rng = random.Random(42) From 3db3196e18ac46496dbc60569b0efd1d603b1f53 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 18 Jun 2025 07:08:21 +0530 Subject: [PATCH 79/98] Update pythonInterface.cpp --- csrc/pythonInterface.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 66e96b07f..a8d47b8de 100644 --- a/csrc/pythonInterface.cpp +++ 
b/csrc/pythonInterface.cpp @@ -315,12 +315,6 @@ void spmm_coo_very_sparse_naive_int8( extern "C" { #if BUILD_CUDA || BUILD_HIP -void cestimate_quantiles_fp32(float* A, float* code, float offset, int n) { - estimateQuantiles_fp32(A, code, offset, n); -} - -void cestimate_quantiles_fp16(half* A, float* code, float offset, int n) { estimateQuantiles_fp16(A, code, offset, n); } - void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); } void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) { From 75a654e3e1eacb6ba78b98bc153925377e530bd8 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 18 Jun 2025 07:11:20 +0530 Subject: [PATCH 80/98] lint fix --- csrc/common_hip.cuh | 2 +- csrc/kernels_hip.cuh | 236 +++++++++++++++++---------------- csrc/ops_hip.cuh | 302 +++++++++++++++++++++++-------------------- 3 files changed, 288 insertions(+), 252 deletions(-) diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh index 105179535..1d9d9afe0 100644 --- a/csrc/common_hip.cuh +++ b/csrc/common_hip.cuh @@ -1,6 +1,6 @@ #pragma once -#define BNB_WARP_SIZE warpSize +#define BNB_WARP_SIZE warpSize // These are set based on current BNB support for CDNA 2 & RDNA 3. 
Update as needed for future archs #define BNB_MAX_THREADS_PER_SM 2048 diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh index 2895012f8..811299d05 100644 --- a/csrc/kernels_hip.cuh +++ b/csrc/kernels_hip.cuh @@ -11,122 +11,136 @@ #ifndef kernels #define kernels - -template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); - -__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); -__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); - -template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); - -template -__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); - -template -__global__ void kOptimizer32bit2State(T* g, T* p, - float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float beta3, const float alpha, - const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -template -__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); - -template -__global__ void kOptimizer32bit1State(T* g, T* p, - float* state1, float *unorm, const float max_unorm, const float param_norm, - const float 
beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); - -template -__global__ void -kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - const float weight_decay, - const float gnorm_scale, const int n); - - -template +template __global__ void -kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, - const float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* max1, float* new_max1, - float weight_decay, const float gnorm_scale, const int n); - + kEstimateQuantiles(T* __restrict__ const A, float* code, const float offset, const T max_val, const int n); +__global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n); +__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n); -template +template +__global__ void kQuantizeBlockwise( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +); +template __global__ void -kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, - float *unorm, - const float beta1, const float beta2, - const float eps, const int step, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - const float gnorm_scale, const int n); - + kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, 
const int n); + +template +__global__ void kPreconditionOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); + +template +__global__ void kPreconditionOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, + const float beta2, const float eps, const float weight_decay, const int step, const float lr, + const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kPreconditionOptimizerStatic8bit1State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1, + const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, + float* new_max1, const float weight_decay, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit1State( + T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, + 
const int n +); + +template +__global__ void kPreconditionOptimizerStatic8bit2State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2, + float* unorm, const float beta1, const float beta2, const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit2State( + T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, + const float beta3, const float alpha, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, + const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n); -template __global__ void -kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, - 
const float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, const float gnorm_scale, const int n); - -template __global__ void kOptimizerStatic8bit2StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, - const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); - -template __global__ void kOptimizerStatic8bit1StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, - float* absmax1, - float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n); - - -template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); - - -__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); - - -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); - -template __global__ void kdequant_mm_int32_fp16( - int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, - half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); - -template __global__ void kgetRowStats(T * __restrict__ A, float 
*rowStats, float threshold, int rows, int cols); -template __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); - -template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); - -template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); -template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); - -template __global__ void kfunc(T *A, T *B, T value, long n); + kHistogramScatterAdd2D(float* histogram, int* index1, int* index2, float* src, const int maxidx1, const int n); + +template +__global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); + +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, + half* __restrict__ const bias, const int numRows, const int numCols, const int n +); + +template +__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols); +template +__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols); + +template +__global__ void kTransformRowToFormat( + char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols +); + +template +__global__ void gemm_device(int M, int N, int K, 
T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc); +template +__global__ void kgemm_4bit_inference( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, + int blocksize +); +template +__global__ void kgemm_4bit_inference_naive( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, + int lda, int ldb, int ldc, int blocksize +); + +template __global__ void kfunc(T* A, T* B, T value, long n); #endif diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index bcfc73e99..624ebe326 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -4,42 +4,42 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. - #ifndef ops_H #define ops_H +#include #include -#include #include +#include #include -#include -#include +#include #include -#include +#include #include #include +#include #include -#include - -#define CUDA_CHECK_RETURN(value) { \ - hipError_t _m_cudaStat = value; \ - if (_m_cudaStat != hipSuccess) { \ - fprintf(stderr, "Error %s at line %d in file %s\n", \ - hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ - exit(1); \ - } } - - -#define CHECK_HIPSPARSE(value) { \ - hipsparseStatus_t _m_hipStat = value; \ - if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ - fprintf(stderr, "Error %s at line %d in file %s\n", \ - hipsparseGetErrorString(_m_hipStat), __LINE__, __FILE__); \ - exit(1); \ - } } +#define CUDA_CHECK_RETURN(value) \ + { \ + hipError_t _m_cudaStat = value; \ + if (_m_cudaStat != hipSuccess) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } \ + } +#define CHECK_HIPSPARSE(value) \ + { \ + hipsparseStatus_t _m_hipStat = value; \ + if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \ + fprintf( \ + stderr, "Error %s at line %d in file %s\n", hipsparseGetErrorString(_m_hipStat), 
__LINE__, __FILE__ \ + ); \ + exit(1); \ + } \ + } inline void checkHipStatus(hipError_t status) { if (status != hipSuccess) { @@ -51,145 +51,167 @@ inline void checkHipStatus(hipError_t status) { inline int checkHipblasStatus(hipblasStatus_t status) { if (status != HIPBLAS_STATUS_SUCCESS) { printf("hipBLAS API failed with status %d\n", status); - //throw std::logic_error("cuBLAS API failed"); + // throw std::logic_error("cuBLAS API failed"); return 1; } return 0; } -typedef enum Operations_t -{ - ksmul = 0, +typedef enum Operations_t { + ksmul = 0, } Operations_t; -typedef enum Optimizer_t -{ - ADAM = 0, - MOMENTUM = 1, - RMSPROP = 2, - LARS = 3, - ADAGRAD = 4, - LION = 5, - ADEMAMIX = 6, +typedef enum Optimizer_t { + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, + ADEMAMIX = 6, } Optimizer_t; -typedef enum Transform_t -{ - ROW = 0, - COL = 1, - COL32 = 2, - COL_TURING = 3, - COL_AMPERE = 4, +typedef enum Transform_t { + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, } Transform_t; -typedef enum DataType_t -{ - General8bit = 0, - FP4 = 1, - NF4 = 2, +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, } DataType_t; -typedef enum Funcs_t -{ - FILL = 0, - ARANGE = 1, - _MUL = 2, +typedef enum Funcs_t { + FILL = 0, + ARANGE = 1, + _MUL = 2, } Funcs_t; -class Context -{ - public: - rocblas_handle m_handle; - - Context() - { - rocblas_handle handle; - rocblas_create_handle(&handle); - m_handle = handle; - } - -}; +class Context { + public: + rocblas_handle m_handle; -class ContextLt -{ - public: - hipblasLtHandle_t m_handle; - - ContextLt() - { - hipblasLtHandle_t handle; - hipblasLtCreate(&handle); - m_handle = handle; - } + Context() { + rocblas_handle handle; + rocblas_create_handle(&handle); + m_handle = handle; + } }; -class ContextHipsparse -{ - public: - hipsparseHandle_t m_handle; - - ContextHipsparse() - { - hipsparseHandle_t handle; - hipsparseCreate(&handle); - m_handle = handle; - } 
+class ContextLt { + public: + hipblasLtHandle_t m_handle; + ContextLt() { + hipblasLtHandle_t handle; + hipblasLtCreate(&handle); + m_handle = handle; + } }; +class ContextHipsparse { + public: + hipsparseHandle_t m_handle; -template void estimateQuantiles(T *A, float *code, float offset, int n); - -void quantize(float *code, float *A, unsigned char *out, int n); -void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, hipStream_t stream); - -template void optimizer32bit(T* g, T* p, - float* state1, float* state2, float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, - int step, float lr, const float gnorm_scale, bool skip_zeros, int n); - -template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, - float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, - float eps, int step, float lr, - float* quantiles1, float* quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, - const float gnorm_scale, int n); - -template void optimizerStatic8bitBlockwise(T* p, T* g, - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, - bool skip_zeros, int n); - -template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); - -void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); - -void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int 
k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, - long long int strideA, long long int strideB, long long int strideC, int batchCount); - - -template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); - -void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, hipStream_t stream); -void getRowStats(half * A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); -void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream); - -void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); - -template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); - -void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); - -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); -template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); + ContextHipsparse() { + 
hipsparseHandle_t handle; + hipsparseCreate(&handle); + m_handle = handle; + } +}; -template void func(T *A, T *B, T value, long n); +template void estimateQuantiles(T* A, float* code, float offset, int n); + +void quantize(float* code, float* A, unsigned char* out, int n); +void dequantize(float* code, unsigned char* A, float* out, int n, hipStream_t stream); +template +void quantizeBlockwise( + float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n +); +template +void dequantizeBlockwise( + float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, hipStream_t stream +); + +template +void optimizer32bit( + T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2, + float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale, + bool skip_zeros, int n +); + +template +void optimizerStatic8bit( + T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, + float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n +); + +template +void optimizerStatic8bitBlockwise( + T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, + float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, bool skip_zeros, int n +); + +template void percentileClipping(T* g, float* gnorm_vec, int step, const int n); + +void histogramScatterAdd2D(float* histogram, int* index1, int* index2, float* src, int maxidx1, int n); + +void gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc +); +void 
strided_gemmex( + Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, + int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount +); + +template +int igemmlt( + hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale, + int lda, int ldb, int ldc, hipStream_t stream +); + +void cutlass_igemm( + bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc +); +void dequant_mm_int32_fp16( + int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, hipStream_t stream +); +void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, hipStream_t stream); +void int8VectorQuant( + half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, hipStream_t stream +); + +void spmm_coo( + hipsparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, + int ldb, half* B, int ldc, half* C, bool transposed_B +); + +template +void spmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB +); + +void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits); +template +void gemm_4bit_inference( + int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize +); +template +void gemm_4bit_inference_naive( + int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc, + int blocksize, hipStream_t stream +); + +template void func(T* A, T* B, T value, long n); 
#endif From 562473620ee321fd3d74ed42b06775f245b35282 Mon Sep 17 00:00:00 2001 From: MISHANMAUYRA Date: Wed, 18 Jun 2025 14:58:01 +0530 Subject: [PATCH 81/98] lint --- conflicts.diff | 431 ------------------------------------------------- 1 file changed, 431 deletions(-) delete mode 100644 conflicts.diff diff --git a/conflicts.diff b/conflicts.diff deleted file mode 100644 index bab359251..000000000 --- a/conflicts.diff +++ /dev/null @@ -1,431 +0,0 @@ -diff --cc .github/workflows/python-package.yml -index 3673ac6,d3deb26..0000000 ---- a/.github/workflows/python-package.yml -+++ b/.github/workflows/python-package.yml -@@@ -218,7 -173,14 +218,18 @@@ jobs - merge-multiple: true - - - name: Inspect tmp directory after downloading artifacts -++<<<<<<< HEAD - + run: ls -alFR tmp/ -++======= -+ run: | -+ ls -alFR tmp/ -+ WHEEL_COUNT=$(find tmp/ -type f -name "*.whl" | wc -l) -+ echo "Found $WHEEL_COUNT wheel files" -+ if [ "$WHEEL_COUNT" -eq 0 ]; then -+ echo "::error::No wheel files found in tmp directory! Cannot proceed with release." 
-+ exit 1 -+ fi -++>>>>>>> upstream/main - - - name: Move and rename wheel files with pattern replacement - run: | -@@@ -245,9 -207,20 +256,23 @@@ - - name: Inspect wheels directory after renaming files - run: ls -alFR wheels/ - -++<<<<<<< HEAD -++======= -+ - uses: actions/checkout@v4 -+ with: -+ path: repo -++>>>>>>> upstream/main - - name: Delete old pre-release (if exists) - run: | -- gh release delete continuous-release_main --cleanup-tag -y || true -+ cd repo && gh release delete continuous-release_main --cleanup-tag -y -+ env: -+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -+ -+ - name: Ensure tag exists -+ run: | -+ cd repo -+ git tag -f continuous-release_main -+ git push -f origin continuous-release_main - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - -diff --cc bitsandbytes/cextension.py -index 5283df9,b112df2..0000000 ---- a/bitsandbytes/cextension.py -+++ b/bitsandbytes/cextension.py -@@@ -28,17 -28,10 +29,15 @@@ def get_cuda_bnb_library_path(cuda_spec - override_value = os.environ.get("BNB_CUDA_VERSION") - if override_value: - library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) - + if torch.version.hip: - + raise RuntimeError( - + f"BNB_CUDA_VERSION={override_value} detected for ROCm!! 
\n" - + f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" - + ) - logger.warning( - f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" -- "This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n" -+ "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n" - "If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n" -- "If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n" -- "For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: BNBNativeLi - return BNBNativeLibrary(dll) - - - +ROCM_GPU_ARCH = get_rocm_gpu_arch() - + - try: -++<<<<<<< HEAD - + if torch.version.hip: - + HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" - + else: - + HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" - + -++======= -+ # to support Intel CPU/GPU (XPU) backend -+ import intel_extension_for_pytorch as ipex -+ -+ ipex_cpu = ipex if ipex._C._has_cpu() else None -+ ipex_xpu = ipex if ipex._C._has_xpu() else None -+ except BaseException: -+ ipex_cpu = None -+ ipex_xpu = None -+ -+ -+ try: -++>>>>>>> upstream/main - lib = get_native_library() - except Exception as e: - error_msg = str(e) -diff --cc bitsandbytes/diagnostics/cuda.py -index b9db101,e763ef2..0000000 ---- a/bitsandbytes/diagnostics/cuda.py -+++ b/bitsandbytes/diagnostics/cuda.py -@@@ -5,8 -5,7 +5,12 @@@ from pathlib import Pat - - import torch - -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path - +from bitsandbytes.consts import NONPYTORCH_DOC_URL -++======= -+ from bitsandbytes.cextension import get_cuda_bnb_library_path -++>>>>>>> upstream/main - from bitsandbytes.cuda_specs import CUDASpecs - from bitsandbytes.diagnostics.utils import print_dedented - -@@@ -148,42 -127,8 +136,38 @@@ def 
_print_cuda_diagnostics(cuda_specs - """, - ) - -- # TODO: -- # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) -- # (2) Multiple CUDA versions installed -- - - -def print_cuda_runtime_diagnostics() -> None: - +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: - + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") - + - + binary_path = get_cuda_bnb_library_path(cuda_specs) - + if not binary_path.exists(): - + print_dedented( - + f""" - + Library not found: {binary_path}. - + Maybe you need to compile it from source? If you compiled from source, check that ROCm version - + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version - + and rebuild bitsandbytes. - + """, - + ) - + - + hip_major, hip_minor = cuda_specs.cuda_version_tuple - + if (hip_major, hip_minor) < (6, 1): - + print_dedented( - + """ - + WARNING: bitsandbytes is fully supported only from ROCm 6.1. - + """, - + ) - + - + - +def print_diagnostics(cuda_specs: CUDASpecs) -> None: - + if HIP_ENVIRONMENT: - + _print_hip_diagnostics(cuda_specs) - + else: - + _print_cuda_diagnostics(cuda_specs) - + - + - +def _print_cuda_runtime_diagnostics() -> None: - cudart_paths = list(find_cudart_libraries()) - if not cudart_paths: - print("CUDA SETUP: WARNING! 
CUDA runtime files not found in any environmental path.") -diff --cc bitsandbytes/diagnostics/main.py -index bf31d79,aa4cb30..0000000 ---- a/bitsandbytes/diagnostics/main.py -+++ b/bitsandbytes/diagnostics/main.py -@@@ -3,12 -5,11 +5,20 @@@ import tracebac - - import torch - -++<<<<<<< HEAD - +from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT - +from bitsandbytes.consts import PACKAGE_GITHUB_URL - +from bitsandbytes.cuda_specs import get_cuda_specs - +from bitsandbytes.diagnostics.cuda import ( - + print_diagnostics, - + print_runtime_diagnostics, -++======= -+ from bitsandbytes import __version__ as bnb_version -+ from bitsandbytes.consts import PACKAGE_GITHUB_URL -+ from bitsandbytes.cuda_specs import get_cuda_specs -+ from bitsandbytes.diagnostics.cuda import ( -+ print_cuda_diagnostics, -++>>>>>>> upstream/main - ) - from bitsandbytes.diagnostics.utils import print_dedented, print_header - -@@@ -28,53 -41,77 +50,123 @@@ def sanity_check() - assert p1 != p2 - - -+ def get_package_version(name: str) -> str: -+ try: -+ version = importlib.metadata.version(name) -+ except importlib.metadata.PackageNotFoundError: -+ version = "not found" -+ return version -+ -+ -+ def show_environment(): -+ """Simple utility to print out environment information.""" -+ -+ print(f"Platform: {platform.platform()}") -+ if platform.system() == "Linux": -+ print(f" libc: {'-'.join(platform.libc_ver())}") -+ -+ print(f"Python: {platform.python_version()}") -+ -+ print(f"PyTorch: {torch.__version__}") -+ print(f" CUDA: {torch.version.cuda or 'N/A'}") -+ print(f" HIP: {torch.version.hip or 'N/A'}") -+ print(f" XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}") -+ -+ print("Related packages:") -+ for pkg in _RELATED_PACKAGES: -+ version = get_package_version(pkg) -+ print(f" {pkg}: {version}") -+ -+ - def main(): -- print_header("") -- print_header("BUG REPORT INFORMATION") -+ print_header(f"bitsandbytes v{bnb_version}") -+ show_environment() - print_header("") - -- 
print_header("OTHER") - cuda_specs = get_cuda_specs() -++<<<<<<< HEAD - + if HIP_ENVIRONMENT: - + rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," - + rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" - + print(f"{BNB_BACKEND} specs:{rocm_specs}") - + else: - + print(f"{BNB_BACKEND} specs:{cuda_specs}") - + if not torch.cuda.is_available(): - + print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") - + if not HIP_ENVIRONMENT: - + print(f"- {BNB_BACKEND} driver not installed") - + print(f"- {BNB_BACKEND} not installed") - + print(f"- You have multiple conflicting {BNB_BACKEND} libraries") - + if cuda_specs: - + print_diagnostics(cuda_specs) - + print_runtime_diagnostics() - + print_header("") - + print_header("DEBUG INFO END") - + print_header("") - + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") - + try: - + sanity_check() - + print("SUCCESS!") - + print("Installation was successful!") - + return - + except RuntimeError as e: - + if "not available in CPU-only" in str(e): - + print( - + f"WARNING: {__package__} is currently running as CPU-only!\n" - + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - + f"If you think that this is so erroneously,\nplease report an issue!", - + ) - + else: - + raise e - + except Exception: - + traceback.print_exc() - + print_dedented( - + f""" - + Above we output some debug information. - + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose - + WARNING: Please be sure to sanitize sensitive info from the output before posting it. - + """, - + ) - + sys.exit(1) -++======= -+ -+ if cuda_specs: -+ print_cuda_diagnostics(cuda_specs) -+ -+ # TODO: There's a lot of noise in this; needs improvement. -+ # print_cuda_runtime_diagnostics() -+ -+ if not torch.cuda.is_available(): -+ print("PyTorch says CUDA is not available. Possible reasons:") -+ print("1. 
CUDA driver not installed") -+ print("2. Using a CPU-only PyTorch build") -+ print("3. No GPU detected") -+ -+ else: -+ print("Checking that the library is importable and CUDA is callable...") -+ -+ try: -+ sanity_check() -+ print("SUCCESS!") -+ return -+ except RuntimeError as e: -+ if "not available in CPU-only" in str(e): -+ print( -+ f"WARNING: {__package__} is currently running as CPU-only!\n" -+ "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" -+ f"If you think that this is so erroneously,\nplease report an issue!", -+ ) -+ else: -+ raise e -+ except Exception: -+ traceback.print_exc() -+ -+ print_dedented( -+ f""" -+ Above we output some debug information. -+ Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose -+ WARNING: Please be sure to sanitize sensitive info from the output before posting it. -+ """, -+ ) -+ sys.exit(1) -++>>>>>>> upstream/main -diff --cc bitsandbytes/functional.py -index 9b7ce2d,ffb6668..0000000 -mode 100644,100755..100755 ---- a/bitsandbytes/functional.py -+++ b/bitsandbytes/functional.py -@@@ -13,9 -13,9 +13,13 @@@ import torc - from torch import Tensor - from typing_extensions import deprecated - -- from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -+ from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict - -++<<<<<<< HEAD - +from .cextension import HIP_ENVIRONMENT, lib -++======= -+ from .cextension import ipex_cpu, ipex_xpu, lib -++>>>>>>> upstream/main - - name2qmap = {} - -diff --cc bitsandbytes/nn/modules.py -index a2facac,e349cc8..0000000 ---- a/bitsandbytes/nn/modules.py -+++ b/bitsandbytes/nn/modules.py -@@@ -11,8 -11,7 +11,12 @@@ from torch import Tensor, device, dtype - import torch.nn.functional as F - - import bitsandbytes as bnb -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT - +from bitsandbytes.functional import QuantState -++======= -+ from 
bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu -++>>>>>>> upstream/main - from bitsandbytes.optim import GlobalOptimManager - from bitsandbytes.utils import ( - INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, -diff --cc tests/test_linear4bit.py -index 60c1634,b5db2eb..0000000 ---- a/tests/test_linear4bit.py -+++ b/tests/test_linear4bit.py -@@@ -7,8 -8,14 +8,19 @@@ import pytes - import torch - - import bitsandbytes as bnb -++<<<<<<< HEAD - +from bitsandbytes.cextension import HIP_ENVIRONMENT - +from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer -++======= -+ from tests.helpers import ( -+ TRUE_FALSE, -+ describe_dtype, -+ get_available_devices, -+ id_formatter, -+ torch_load_from_buffer, -+ torch_save_to_buffer, -+ ) -++>>>>>>> upstream/main - - storage = { - "uint8": torch.uint8, -@@@ -184,16 -185,10 +190,10 @@@ def test_linear_serialization(device, q - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_copy_param(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- tensor = torch.linspace(1, blocksize, blocksize) -+ tensor = torch.randn(300, 400) - param = bnb.nn.Params4bit( - data=tensor, - quant_type=quant_type, -@@@ -209,16 -204,10 +209,10 @@@ - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not 
HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- tensor = torch.linspace(1, blocksize, blocksize) -+ tensor = torch.randn(300, 400) - param = bnb.nn.Params4bit( - data=tensor, - quant_type=quant_type, -@@@ -241,16 -230,10 +235,10 @@@ - - @pytest.mark.parametrize("device", get_available_devices()) - @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) - -@pytest.mark.parametrize("blocksize", [64, 128]) - +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) - @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) - def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): -- if device == "cpu": -- if compress_statistics: -- pytest.skip("Currently segfaults on CPU") -- if quant_type == "fp4": -- pytest.xfail("FP4 not supported on CPU") -- -- original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32) -+ original_tensor = torch.randn(300, 400) - original_param = bnb.nn.Params4bit( - data=original_tensor, - quant_type=quant_type, From c75fdb7d52feb7d4b11a0e1141b91c50a1c04c4e Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Wed, 18 Jun 2025 15:02:59 +0530 Subject: [PATCH 82/98] Update pythonInterface.cpp --- csrc/pythonInterface.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index a8d47b8de..9c4cab9cc 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -37,11 +37,6 @@ //=================================================================================== #if BUILD_CUDA || BUILD_HIP 
-void estimateQuantiles_fp32(float* A, float* code, float offset, int n) { - estimateQuantiles(A, code, offset, n); -} - -void estimateQuantiles_fp16(half* A, float* code, float offset, int n) { estimateQuantiles(A, code, offset, n); } // void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) //{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } From 3936ca40bffa149bb871b753e5536dcd3ab96817 Mon Sep 17 00:00:00 2001 From: Prasanth Nunna Date: Wed, 18 Jun 2025 12:09:27 -0500 Subject: [PATCH 83/98] revert permissions change --- bitsandbytes/functional.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 bitsandbytes/functional.py diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py old mode 100755 new mode 100644 From b4fd5942b07d65a2084656bc79221caab4d7f3fa Mon Sep 17 00:00:00 2001 From: pnunna93 <104791500+pnunna93@users.noreply.github.com> Date: Wed, 18 Jun 2025 12:31:24 -0500 Subject: [PATCH 84/98] Fix indentation --- bitsandbytes/diagnostics/main.py | 48 ++++++++++++++++---------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 9a0447433..74da662b6 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -92,27 +92,27 @@ def main(): else: print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") - try: - sanity_check() - print("SUCCESS!") - return - except RuntimeError as e: - if "not available in CPU-only" in str(e): - print( - f"WARNING: {__package__} is currently running as CPU-only!\n" - "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" - f"If you think that this is so erroneously,\nplease report an issue!", - ) - else: - raise e - except Exception: - traceback.print_exc() - - print_dedented( - f""" - Above we output some debug information. 
- Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose - WARNING: Please be sure to sanitize sensitive info from the output before posting it. - """, - ) - sys.exit(1) + try: + sanity_check() + print("SUCCESS!") + return + except RuntimeError as e: + if "not available in CPU-only" in str(e): + print( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report an issue!", + ) + else: + raise e + except Exception: + traceback.print_exc() + + print_dedented( + f""" + Above we output some debug information. + Please provide this info when creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose + WARNING: Please be sure to sanitize sensitive info from the output before posting it. + """, + ) + sys.exit(1) From 3228ca86d74a50d4f7c5170bc473d29c30f3dec5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 14:15:25 +0530 Subject: [PATCH 85/98] Update kernels_hip.cuh --- csrc/kernels_hip.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh index 811299d05..d902129a3 100644 --- a/csrc/kernels_hip.cuh +++ b/csrc/kernels_hip.cuh @@ -103,9 +103,6 @@ __global__ void kOptimizerStatic8bit1StateBlockwise( template __global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n); -__global__ void - kHistogramScatterAdd2D(float* histogram, int* index1, int* index2, float* src, const int maxidx1, const int n); - template __global__ void kspmm_coo_very_sparse_naive( int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, From 94c1b7751bdd1d10014cf861a4e28ede66262530 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 14:21:11 +0530 Subject: [PATCH 86/98] 
Update kernels.hip --- csrc/kernels.hip | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 56e1d54db..53b2725a3 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -346,18 +346,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran } } -__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) -{ - const int tid = threadIdx.x + (blockDim.x*blockIdx.x); - const int numThreads = blockDim.x*gridDim.x; - - for(int i = tid; i < n; i+=numThreads) - { - int idx = (index1[i]*maxidx1) + index2[i]; - atomicAdd(&histogram[idx], src[i]); - } -} - #define THREADS_ESTIMATE 512 #define NUM_ESTIMATE 8 #define BLOCK_ESTIMATE 4096 From cd3f0b779f6c285cd969689dd509ad08698e0964 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 14:23:14 +0530 Subject: [PATCH 87/98] Update ops.hip --- csrc/ops.hip | 9 --------- 1 file changed, 9 deletions(-) diff --git a/csrc/ops.hip b/csrc/ops.hip index a9c3e0202..ccdbc1026 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -24,15 +24,6 @@ using namespace BinSearch; using std::cout; using std::endl; -void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) -{ - int threads = 512; - int num_blocks = n/threads; - num_blocks = n % threads == 0 ? 
num_blocks : num_blocks + 1; - hipLaunchKernelGGL(( kHistogramScatterAdd2D), dim3(num_blocks), dim3(512), 0, 0, histogram, index1, index2, src, maxidx1, n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - template void estimateQuantiles(T *A, float *code, float offset, int n) { int num_blocks = n/4096; From 98bb06ed6245da3af44497c1df04c8da06f00d2a Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 14:25:32 +0530 Subject: [PATCH 88/98] Update ops_hip.cuh --- csrc/ops_hip.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index 624ebe326..ebae292c4 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -160,8 +160,6 @@ void optimizerStatic8bitBlockwise( template void percentileClipping(T* g, float* gnorm_vec, int step, const int n); -void histogramScatterAdd2D(float* histogram, int* index1, int* index2, float* src, int maxidx1, int n); - void gemmex( Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc From 3bad4541e3d9fc186cf680009bfef7c980bb0aaa Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:17:59 +0530 Subject: [PATCH 89/98] Update kernels_hip.cuh --- csrc/kernels_hip.cuh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh index d902129a3..00718071c 100644 --- a/csrc/kernels_hip.cuh +++ b/csrc/kernels_hip.cuh @@ -11,10 +11,6 @@ #ifndef kernels #define kernels -template -__global__ void - kEstimateQuantiles(T* __restrict__ const A, float* code, const float offset, const T max_val, const int n); - __global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n); __global__ void kDequantize(float* code, unsigned char* A, float* out, const int n); From e0c766dcc34b6147d5a6e8aa505dbb15c08233a5 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA 
<118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:20:37 +0530 Subject: [PATCH 90/98] Update kernels.hip --- csrc/kernels.hip | 73 ------------------------------------------------ 1 file changed, 73 deletions(-) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 53b2725a3..6b0f1eac5 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -346,79 +346,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran } } -#define THREADS_ESTIMATE 512 -#define NUM_ESTIMATE 8 -#define BLOCK_ESTIMATE 4096 - -template -__launch_bounds__(THREADS_ESTIMATE, 1) -__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) -{ - const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); - int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; - const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); - const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); - - T vals[NUM_ESTIMATE]; - - typedef hipcub::BlockRadixSort BlockRadixSort; - typedef hipcub::BlockLoad LoadFloat; - - __shared__ union { - typename LoadFloat::TempStorage loadf; - typename BlockRadixSort::TempStorage sort; - int smem_qidx[BLOCK_ESTIMATE]; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) - { - valid_items = n - i > BLOCK_ESTIMATE ? 
BLOCK_ESTIMATE : n - i; - - // do not process half-blocks - if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } - - #pragma unroll 4 - for(int j = 0; j < NUM_ESTIMATE; j++) - vals[j] = max_val; - - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); - - #pragma unroll 4 - for(int j = 0; j < NUM_ESTIMATE; j++) - vals[j] = ((float)vals[j]) * reciprocal_num_blocks; - - - __syncthreads(); - // sort into striped pattern to mitigate bank conflicts - // striped pattern index for thread 0 [0, 1024, 2048, 3096] - // striped pattern index for thread 1 [1, 1025, 2049, 3097] - BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); - - __syncthreads(); - for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) - temp_storage.smem_qidx[j] = -1; - - __syncthreads(); - - if(threadIdx.x < 256) - { - float q_interval = (1.0f-(2.0f*offset))/255.0f; - int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); - temp_storage.smem_qidx[local_idx] = threadIdx.x; - } - - __syncthreads(); - - for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) - { - if(temp_storage.smem_qidx[i] != -1) - atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); - } - } -} - - __launch_bounds__(TH, 4) __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) { From f35a063db5bd5fb87c0ccf70df2687b7079b33af Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:22:55 +0530 Subject: [PATCH 91/98] Update kernels.hip --- csrc/kernels.hip | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/kernels.hip b/csrc/kernels.hip index 6b0f1eac5..ec3f7f025 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -2899,9 +2899,6 @@ template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template 
__device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); -template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n); -template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n); - #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ From fca01f310358169d49b686bce1fae7a9c4d37c93 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:30:34 +0530 Subject: [PATCH 92/98] Update ops.hip --- csrc/ops.hip | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/ops.hip b/csrc/ops.hip index ccdbc1026..1840b7864 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -743,9 +743,6 @@ template int igemmlt<32, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, con template int igemmlt<8, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); template int igemmlt<8, 1>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); -template void estimateQuantiles(half *A, float *code, float offset, int n); -template void estimateQuantiles(float *A, float *code, float offset, int n); - template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int 
n); From 5569c2de672006ed6353cf85e0a34b4ddeec59a1 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:34:01 +0530 Subject: [PATCH 93/98] Update ops_hip.cuh --- csrc/ops_hip.cuh | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index ebae292c4..0f8db2ee4 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -124,8 +124,6 @@ class ContextHipsparse { } }; -template void estimateQuantiles(T* A, float* code, float offset, int n); - void quantize(float* code, float* A, unsigned char* out, int n); void dequantize(float* code, unsigned char* A, float* out, int n, hipStream_t stream); template From 7a17f2d6f7ecfb78cf72d94de4b3f3f3ef4e1453 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:44:51 +0530 Subject: [PATCH 94/98] Update ops.hip --- csrc/ops.hip | 9 --------- 1 file changed, 9 deletions(-) diff --git a/csrc/ops.hip b/csrc/ops.hip index 1840b7864..260b74b30 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -24,15 +24,6 @@ using namespace BinSearch; using std::cout; using std::endl; -template void estimateQuantiles(T *A, float *code, float offset, int n) -{ - int num_blocks = n/4096; - num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; - CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float))); - hipLaunchKernelGGL(( kEstimateQuantiles), dim3(num_blocks), dim3(512), 0, 0, A, code, offset, std::numeric_limits::max(), n); - CUDA_CHECK_RETURN(hipPeekAtLastError()); -} - void quantize(float *code, float *A, unsigned char *out, int n) { int num_blocks = n/1024; From 6b8239e707ba7e63bdf3abbac7d365c0a6a0dbfb Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 16:33:09 +0530 Subject: [PATCH 95/98] Update CMakeLists.txt --- CMakeLists.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8a7583279..770b4ba30 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -195,9 +195,6 @@ elseif(BUILD_HIP) string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}") string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}") - if(HIP_VERSION VERSION_LESS "6.1") - string(APPEND BNB_OUTPUT_NAME "_nohipblaslt") - endif() add_compile_definitions(__HIP_PLATFORM_AMD__) add_compile_definitions(__HIP_PLATFORM_HCC__) add_compile_definitions(BUILD_HIP) From 00ac146878bf64ac12c923aaae7ec00283f0ecde Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 16:48:31 +0530 Subject: [PATCH 96/98] Update functional.py --- bitsandbytes/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 3c0a41351..9b446a2de 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -908,7 +908,7 @@ def quantize_4bit( absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): - The size of the blocks. Defaults to 64. + The size of the blocks. Defaults to 128 on ROCm and 64 otherwise. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. 
compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. @@ -1019,7 +1019,7 @@ def dequantize_4bit( Required if `quant_state` is not provided and ignored otherwise. out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): - The size of the blocks. Defaults to 64. + The size of the blocks. Defaults to 128 on ROCm and 64 otherwise. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. From 77f4c7747c6354b841f75442f13c2b595bee1a96 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 17:29:43 +0530 Subject: [PATCH 97/98] Update cextension.py --- bitsandbytes/cextension.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 7f5483531..1c5197647 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -23,8 +23,6 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: """ prefix = "rocm" if torch.version.hip else "cuda" - blas_suffix = "_nohipblaslt" if torch.version.hip and cuda_specs.cuda_version_tuple < (6, 1) else "" - library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{blas_suffix}{DYNAMIC_LIBRARY_SUFFIX}" override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: From c9fe2845a5bf440dfb32ccd0680f6dda41ad8096 Mon Sep 17 00:00:00 2001 From: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com> Date: Fri, 20 Jun 2025 17:43:24 +0530 Subject: [PATCH 98/98] Update cextension.py --- bitsandbytes/cextension.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 1c5197647..bb301e712 100644 --- a/bitsandbytes/cextension.py +++ 
b/bitsandbytes/cextension.py @@ -23,6 +23,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: """ prefix = "rocm" if torch.version.hip else "cuda" + library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}" override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: