diff --git a/.github/scripts/build-rocm.sh b/.github/scripts/build-rocm.sh new file mode 100644 index 000000000..b508fac69 --- /dev/null +++ b/.github/scripts/build-rocm.sh @@ -0,0 +1,21 @@ +#!/bin/bash +declare build_arch +declare build_os +declare rocm_version + +set -xeuo pipefail +bnb_rocm_arch="gfx90a;gfx942;gfx1100" +if [ "${build_os:0:6}" == ubuntu ]; then + image=rocm/dev-ubuntu-22.04:${rocm_version}-complete + echo "Using image $image" + docker run --rm --platform "linux/$build_arch" -i \ + -w /src -v "$PWD:/src" "$image" sh -c \ + "apt-get update \ + && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ + && cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \ + && cmake --build ." +fi + +output_dir="output/${build_os}/${build_arch}" +mkdir -p "${output_dir}" +(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}") diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index d3deb26ee..827c2ffbf 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -102,10 +102,55 @@ jobs: path: output/* retention-days: 7 + build-shared-libs-rocm: + strategy: + matrix: + os: [ubuntu-22.04] + arch: [x86_64] + rocm_version: + ["6.1.2", "6.2.4", "6.3.2"] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Docker multiarch + uses: docker/setup-qemu-action@v3 + - name: Clean up disk space + run: | + sudo rm -rf \ + /usr/share/dotnet \ + /opt/ghc \ + "/usr/local/share/boost" \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/hostedtoolcache \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/swift + - name: Build C++ + run: bash .github/scripts/build-rocm.sh + env: + build_os: ${{ matrix.os }} + build_arch: ${{ matrix.arch }} + rocm_version: ${{ matrix.rocm_version }} + - name: Upload build artifact + uses: actions/upload-artifact@v4 + with: + name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }} + path: output/* + retention-days: 7 + build-wheels: needs: - build-shared-libs - build-shared-libs-cuda + - build-shared-libs-rocm strategy: matrix: os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest] @@ -173,6 +218,7 @@ jobs: merge-multiple: true - name: Inspect tmp directory after downloading artifacts + run: | ls -alFR tmp/ WHEEL_COUNT=$(find tmp/ -type f -name "*.whl" | wc -l) @@ -210,6 +256,7 @@ jobs: - uses: actions/checkout@v4 with: path: repo + - name: Delete old pre-release (if exists) run: | cd repo && gh release delete continuous-release_main --cleanup-tag -y diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b462c45d..770b4ba30 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,13 +25,14 @@ endif() # Define included source files set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp) set(CUDA_FILES csrc/ops.cu csrc/kernels.cu) +set(HIP_FILES csrc/ops.hip csrc/kernels.hip) set(MPS_FILES csrc/mps_ops.mm) set(METAL_FILES csrc/mps_kernels.metal) # C++ sources are always included list(APPEND SRC_FILES ${CPP_FILES}) -set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)") -set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps) +set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)") 
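+# Example invocation (mirrors .github/scripts/build-rocm.sh):
+#   cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH="gfx90a;gfx942;gfx1100" . && cmake --build .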
+set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps) option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF) if(APPLE) @@ -47,15 +48,25 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") message(FATAL_ERROR "CUDA is not supported on macOS" ) endif() set(BUILD_CUDA ON) + set(BUILD_HIP OFF) + set(BUILD_MPS OFF) +elseif(${COMPUTE_BACKEND} STREQUAL "hip") + if(APPLE) + message(FATAL_ERROR "HIP is not supported on macOS" ) + endif() + set(BUILD_CUDA OFF) + set(BUILD_HIP ON) set(BUILD_MPS OFF) elseif(${COMPUTE_BACKEND} STREQUAL "mps") if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) endif() set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) set(BUILD_MPS ON) else() set(BUILD_CUDA OFF) + set(BUILD_HIP OFF) set(BUILD_MPS OFF) endif() @@ -160,6 +171,33 @@ if(BUILD_CUDA) string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") add_compile_definitions(BUILD_CUDA) +elseif(BUILD_HIP) + enable_language(HIP) + message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}") + if(DEFINED BNB_ROCM_ARCH) + set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH}) + else() + if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100") + elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) + endif() + endif() + message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}") + + list(APPEND SRC_FILES ${HIP_FILES}) + + string(APPEND BNB_OUTPUT_NAME "_rocm") + + # get hip version + execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION) + string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}") + string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}") + + string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}") + add_compile_definitions(__HIP_PLATFORM_AMD__) + add_compile_definitions(__HIP_PLATFORM_HCC__) + add_compile_definitions(BUILD_HIP) elseif(BUILD_MPS) if(NOT APPLE) message(FATAL_ERROR "MPS is only supported on macOS" ) @@ -208,6 +246,41 @@ if(BUILD_CUDA) CUDA_SEPARABLE_COMPILATION ON ) endif() +if(BUILD_HIP) + if(NOT DEFINED ENV{ROCM_PATH}) + set(ROCM_PATH /opt/rocm) + else() + set(ROCM_PATH $ENV{ROCM_PATH}) + endif() + list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + macro(find_package_and_print_version PACKAGE_NAME) + find_package("${PACKAGE_NAME}" ${ARGN}) + message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") + endmacro() + find_package_and_print_version(hipblas REQUIRED) + find_package_and_print_version(hiprand REQUIRED) + find_package_and_print_version(hipsparse REQUIRED) + + ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. 
adam8bit because of inaccuracies) + set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "") + set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "") + + target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include) + target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib) + target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse) + + target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP) + set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP) + set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX) + + if(HIP_VERSION VERSION_LESS "6.1") + target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT) + else() + find_package(hipblaslt) + target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt) + endif() +endif() if(BUILD_MPS) add_dependencies(bitsandbytes metallib) target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index c266f61a0..13359bbd8 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -8,7 +8,7 @@ from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr from ..._ops import register_kernel -from ...cextension import lib +from ...cextension import HIP_ENVIRONMENT, lib @register_kernel("bitsandbytes::int8_linear_matmul", "cuda") @@ -210,7 +210,12 @@ def _get_col_absmax( @register_kernel("bitsandbytes::quantize_blockwise", "cuda") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") n = A.numel() @@ -264,7 +269,11 @@ def _( def _dequantize_blockwise_impl( A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") torch._check( dtype in [torch.float16, torch.bfloat16, torch.float32], @@ -294,7 +303,11 @@ def _dequantize_blockwise_impl( def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( A.dtype in [torch.bfloat16, torch.float16, torch.float32], @@ -372,7 +385,11 @@ def _dequantize_4bit_impl( dtype: torch.dtype, out: torch.Tensor, ) -> None: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + if HIP_ENVIRONMENT: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) + else: + 
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) torch._check( dtype in [torch.bfloat16, torch.float16, torch.float32], diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index b112df2f7..bb301e712 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -9,7 +9,7 @@ import torch from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR -from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple +from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch logger = logging.getLogger(__name__) @@ -28,6 +28,11 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1) + if torch.version.hip: + raise RuntimeError( + f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n" + f"Clear the variable and retry: export BNB_CUDA_VERSION=\n" + ) logger.warning( f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n" "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n" @@ -75,10 +80,11 @@ def __init__(self, lib: ct.CDLL): def get_available_cuda_binary_versions() -> list[str]: """Get formatted CUDA versions from existing library files using cuda_specs logic""" - lib_pattern = f"libbitsandbytes_cuda*{DYNAMIC_LIBRARY_SUFFIX}" + lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}" versions = [] for lib in Path(__file__).parent.glob(lib_pattern): - match = re.search(r"cuda(\d{3})", lib.name) + pattern = rf"{BNB_BACKEND.lower()}(\d+)" + match = re.search(pattern, lib.name) if match: ver_code = int(match.group(1)) major = ver_code // 10 @@ -89,8 +95,8 @@ def get_available_cuda_binary_versions() -> list[str]: def parse_cuda_version(version_str: str) -> str: """Convert raw version string (e.g. '118' from env var) to formatted version (e.g. '11.8')""" - if version_str.isdigit() and len(version_str) == 3: - return f"{version_str[:2]}.{version_str[2]}" + if version_str.isdigit(): + return f"{version_str[:-1]}.{version_str[-1]}" return version_str # fallback as safety net @@ -151,7 +157,7 @@ def _format_lib_error_message( """Format detailed error message for library loading failures""" analysis = "" no_cpu_lib_found = "libbitsandbytes_cpu.so: cannot open" in original_error - no_cuda_lib_found = "CUDA binary not found" in original_error + no_cuda_lib_found = f"{BNB_BACKEND} binary not found" in original_error if no_cpu_lib_found: analysis = "\n🚨 Failed to load CPU-only bitsandbytes library 🚨\n\n" @@ -160,9 +166,9 @@ def _format_lib_error_message( version_list_str = "\n - " + "\n - ".join(available_versions) if available_versions else "NONE" analysis = ( ( - f"\n🚨 CUDA VERSION MISMATCH 🚨\n" - f"Requested CUDA version: {requested_version}\n" - f"Detected PyTorch CUDA version: {user_cuda_version}\n" + f"\n🚨 {BNB_BACKEND} VERSION MISMATCH 🚨\n" + f"Requested {BNB_BACKEND} version: {requested_version}\n" + f"Detected PyTorch {BNB_BACKEND} version: {user_cuda_version}\n" f"Available pre-compiled versions: {version_list_str}\n\n" "This means:\n" "The version you're trying to use is NOT distributed with this package\n\n" @@ -177,42 +183,47 @@ def _format_lib_error_message( troubleshooting = ( ( - "This typically happens when:\n" - "1. 
bitsandbytes doesn't ship with a pre-compiled binary for your CUDA version\n" - "2. The library wasn't compiled properly during installation from source\n\n" + f"This typically happens when:\n" + f"1. bitsandbytes doesn't ship with a pre-compiled binary for your {BNB_BACKEND} version\n" + f"2. The library wasn't compiled properly during installation from source\n\n" ) if no_cuda_lib_found - else "This typically happens when you checked the code out from source and your torch installation doesn't detect CUDA on your machine.\n\n" + else f"This typically happens when you checked the code out from source and your torch installation doesn't detect {BNB_BACKEND} on your machine.\n\n" ) note = ( ( - "To make bitsandbytes work, the compiled library version MUST exactly match the linked CUDA version.\n" - "If your CUDA version doesn't have a pre-compiled binary, you MUST compile from source.\n\n" + f"To make bitsandbytes work, the compiled library version MUST exactly match the linked {BNB_BACKEND} version.\n" + f"If your {BNB_BACKEND} version doesn't have a pre-compiled binary, you MUST compile from source.\n\n" ) if no_cuda_lib_found else "" ) compile_instructions = ( - ( + ("COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n") + if not no_cuda_lib_found + else ( "You have two options:\n" "1. COMPILE FROM SOURCE (required if no binary exists):\n" " https://huggingface.co/docs/bitsandbytes/main/en/installation#cuda-compile\n" "2. Use BNB_CUDA_VERSION to specify a DIFFERENT CUDA version from the detected one, which is installed on your machine and matching an available pre-compiled version listed above\n\n" ) - if no_cuda_lib_found - else "COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n" + if not HIP_ENVIRONMENT + else ( + "You can COMPILE FROM SOURCE as mentioned here:\n" + " https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n" + ) ) diagnostics = ( - "šŸ” Run this command for detailed diagnostics:\n" - "python -m bitsandbytes\n\n" - "If you've tried everything and still have issues:\n" - "1. Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n" - "2. Describe what you've tried in detail\n" - "3. Open an issue with this information:\n" - " https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n" + f"šŸ” Run this command for detailed diagnostics:\n" + f"python -m bitsandbytes\n\n" + f"If you've tried everything and still have issues:\n" + f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n" + f"2. Describe what you've tried in detail\n" + f"3. Open an issue with this information:\n" + f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n" ) return f"{analysis}{base_msg}{troubleshooting}{note}{compile_instructions}{original_error}\n{diagnostics}" @@ -227,18 +238,19 @@ def _format_dependency_error(self) -> str: ) return ( - f"\n🚨 CUDA SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n" - f"CUDA {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n" + f"\n🚨 {BNB_BACKEND} SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n" + f"{BNB_BACKEND} {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n" f"To fix this, make sure that:\n" - f"1. You have installed CUDA {cuda_major_version}.x toolkit on your system\n" - f"2. The CUDA runtime libraries are in your LD_LIBRARY_PATH\n\n" + f"1. 
You have installed {BNB_BACKEND} {cuda_major_version}.x toolkit on your system\n" + f"2. The {BNB_BACKEND} runtime libraries are in your LD_LIBRARY_PATH\n\n" f"You can add them with (and persist the change by adding the line to your .bashrc):\n" - f" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/cuda-{cuda_major_version}.x/lib64\n\n" + f" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/{BNB_BACKEND.lower()}-{cuda_major_version}.x/\ + {'lib64' if not HIP_ENVIRONMENT else 'lib'}\n\n" f"Original error: {self.error_msg}\n\n" f"šŸ” Run this command for detailed diagnostics:\n" f"python -m bitsandbytes\n\n" f"If you've tried everything and still have issues:\n" - f"1. Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n" + f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n" f"2. Describe what you've tried in detail\n" f"3. Open an issue with this information:\n" f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n" @@ -267,7 +279,7 @@ def get_native_library() -> BNBNativeLibrary: cuda_binary_path = get_cuda_bnb_library_path(cuda_specs) if not cuda_binary_path.exists(): - raise RuntimeError(f"Configured CUDA binary not found at {cuda_binary_path}") + raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}") binary_path = cuda_binary_path @@ -286,6 +298,8 @@ def get_native_library() -> BNBNativeLibrary: return BNBNativeLibrary(dll) +ROCM_GPU_ARCH = get_rocm_gpu_arch() + try: # to support Intel CPU/GPU (XPU) backend import intel_extension_for_pytorch as ipex @@ -296,8 +310,12 @@ def get_native_library() -> BNBNativeLibrary: ipex_cpu = None ipex_xpu = None - try: + if torch.version.hip: + HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm" + else: + HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA" + lib = get_native_library() except Exception as e: error_msg = str(e) diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 64903cd49..32563a159 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -1,5 +1,8 @@ import dataclasses from functools import lru_cache +import logging +import re +import subprocess from typing import Optional import torch @@ -73,3 +76,27 @@ def get_cuda_specs() -> Optional[CUDASpecs]: ) except Exception: return None + + +def get_rocm_gpu_arch() -> str: + """Get ROCm GPU architecture.""" + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + if match: + return "gfx" + match.group(1) + else: + return "unknown" + else: + return "unknown" + except Exception as e: + logger.error(f"Could not detect ROCm GPU architecture: {e}") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm GPU architecture detection failed despite ROCm being available. 
+ """, + ) + return "unknown" diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index e763ef206..29a9a66e1 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -5,7 +5,7 @@ import torch -from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path from bitsandbytes.cuda_specs import CUDASpecs from bitsandbytes.diagnostics.utils import print_dedented @@ -32,9 +32,13 @@ } CUDA_RUNTIME_LIB_PATTERNS = ( - "cudart64*.dll", # Windows - "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. - "nvcuda*.dll", # Windows + ("libamdhip64.so*",) + if HIP_ENVIRONMENT + else ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. + "nvcuda*.dll", # Windows + ) ) logger = logging.getLogger(__name__) @@ -56,7 +60,7 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path pass for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS: for pth in dir.glob(lib_pattern): - if pth.is_file(): + if pth.is_file() and not pth.is_symlink(): yield pth except (OSError, PermissionError): pass @@ -103,7 +107,7 @@ def find_cudart_libraries() -> Iterator[Path]: yield from find_cuda_libraries_in_path_list(value) -def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: +def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: print( f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", @@ -128,7 +132,37 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: ) -def print_cuda_runtime_diagnostics() -> None: +def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: + print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") + + binary_path = get_cuda_bnb_library_path(cuda_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. + Maybe you need to compile it from source? If you compiled from source, check that ROCm version + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version + and rebuild bitsandbytes. + """, + ) + + hip_major, hip_minor = cuda_specs.cuda_version_tuple + if (hip_major, hip_minor) < (6, 1): + print_dedented( + """ + WARNING: bitsandbytes is fully supported only from ROCm 6.1. + """, + ) + + +def print_diagnostics(cuda_specs: CUDASpecs) -> None: + if HIP_ENVIRONMENT: + _print_hip_diagnostics(cuda_specs) + else: + _print_cuda_diagnostics(cuda_specs) + + +def _print_cuda_runtime_diagnostics() -> None: cudart_paths = list(find_cudart_libraries()) if not cudart_paths: print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.") @@ -153,3 +187,33 @@ def print_cuda_runtime_diagnostics() -> None: ) for pth in cudart_paths: print(f"* Found CUDA runtime at: {pth}") + + +def _print_hip_runtime_diagnostics() -> None: + cudart_paths = list(find_cudart_libraries()) + if not cudart_paths: + print("WARNING! ROCm runtime files not found in any environmental path.") + elif len(cudart_paths) > 1: + print_dedented( + f""" + Found duplicate ROCm runtime files (see below). + + We select the PyTorch default ROCm runtime, which is {torch.version.hip}, + but this might mismatch with the ROCm version that is needed for bitsandbytes. 
+ + To resolve it, install PyTorch built for the ROCm version you want to use + + and set LD_LIBRARY_PATH to your ROCm install path, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib, + """, + ) + + for pth in cudart_paths: + print(f"* Found ROCm runtime at: {pth}") + + +def print_runtime_diagnostics() -> None: + if HIP_ENVIRONMENT: + _print_hip_runtime_diagnostics() + else: + _print_cuda_runtime_diagnostics() diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index aa4cb3042..74da662b6 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -6,10 +6,11 @@ import torch from bitsandbytes import __version__ as bnb_version +from bitsandbytes.cextension import BNB_BACKEND from bitsandbytes.consts import PACKAGE_GITHUB_URL from bitsandbytes.cuda_specs import get_cuda_specs from bitsandbytes.diagnostics.cuda import ( - print_cuda_diagnostics, + print_diagnostics, ) from bitsandbytes.diagnostics.utils import print_dedented, print_header @@ -77,19 +78,19 @@ def main(): cuda_specs = get_cuda_specs() if cuda_specs: - print_cuda_diagnostics(cuda_specs) + print_diagnostics(cuda_specs) # TODO: There's a lot of noise in this; needs improvement. # print_cuda_runtime_diagnostics() if not torch.cuda.is_available(): - print("PyTorch says CUDA is not available. Possible reasons:") - print("1. CUDA driver not installed") + print(f"PyTorch says {BNB_BACKEND} is not available. Possible reasons:") + print(f"1. {BNB_BACKEND} driver not installed") print("2. Using a CPU-only PyTorch build") print("3. No GPU detected") else: - print("Checking that the library is importable and CUDA is callable...") + print(f"Checking that the library is importable and {BNB_BACKEND} is callable...") try: sanity_check() diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py old mode 100755 new mode 100644 index 6893752c9..9b446a2de --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,7 +15,7 @@ from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import ipex_cpu, ipex_xpu, lib +from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib name2qmap = {} @@ -868,10 +868,12 @@ def quantize_fp4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage) @@ -879,10 +881,12 @@ def quantize_nf4( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_storage=torch.uint8, ): + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage) @@ -890,7 +894,7 @@ def quantize_4bit( A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize=64, + blocksize=None, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8, @@ -904,7 +908,7 @@ def quantize_4bit( absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values. out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): - The size of the blocks. Defaults to 64. + The size of the blocks. 
Defaults to 128 on ROCm and 64 otherwise. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. @@ -918,6 +922,10 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. """ + + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + input_shape = A.shape _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( @@ -968,8 +976,10 @@ def dequantize_fp4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") @@ -978,8 +988,10 @@ def dequantize_nf4( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, ) -> torch.Tensor: + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") @@ -988,7 +1000,7 @@ def dequantize_4bit( quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, quant_type="fp4", ) -> torch.Tensor: """Dequantizes a packed 4-bit quantized tensor. @@ -1007,7 +1019,7 @@ def dequantize_4bit( Required if `quant_state` is not provided and ignored otherwise. out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): - The size of the blocks. Defaults to 64. + The size of the blocks. Defaults to 128 on ROCm and 64 otherwise. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. @@ -1017,6 +1029,10 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. 
""" + + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + if quant_state is None: assert absmax is not None and out is not None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 1aed09219..ba134f52a 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -11,6 +11,7 @@ import torch.nn.functional as F import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import ( @@ -213,7 +214,7 @@ def __new__( data: Optional[torch.Tensor] = None, requires_grad=False, # quantized weights should be frozen by default quant_state: Optional[QuantState] = None, - blocksize: int = 64, + blocksize: Optional[int] = None, compress_statistics: bool = True, quant_type: str = "fp4", quant_storage: torch.dtype = torch.uint8, @@ -223,6 +224,9 @@ def __new__( if data is None: data = torch.empty(0) + if blocksize is None: + blocksize = 64 if not HIP_ENVIRONMENT else 128 + self = torch.Tensor._make_subclass(cls, data, requires_grad) self.blocksize = blocksize self.compress_statistics = compress_statistics diff --git a/csrc/common_hip.cuh b/csrc/common_hip.cuh new file mode 100644 index 000000000..1d9d9afe0 --- /dev/null +++ b/csrc/common_hip.cuh @@ -0,0 +1,7 @@ +#pragma once + +#define BNB_WARP_SIZE warpSize + +// These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs +#define BNB_MAX_THREADS_PER_SM 2048 +#define BNB_BF16_AVAILABLE true diff --git a/csrc/kernels.hip b/csrc/kernels.hip new file mode 100644 index 000000000..ec3f7f025 --- /dev/null +++ b/csrc/kernels.hip @@ -0,0 +1,3165 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include "kernels_hip.cuh" +#include "common_hip.cuh" +#include +#include +#include + +//#include + + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0}; + +// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +// Luckily we have atomicmax and atomicmin in ROCm + + +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + + +__device__ __forceinline__ float dDequantizeNF4(unsigned char val) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) // 011 + 
return 0b0111; + else + return 0b0110; + else + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; + else + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; + else + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; +} +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template __device__ int sgn(T val) +{ + return (T(0) < val) - (val < T(0)); +} + +template +__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = smem_code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(upper_pivot == 255) + upper = smem_code[upper_pivot]; + if(lower_pivot == 0) + lower = smem_code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabsf(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = fabsf(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +template +__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = SIGNED ? -1.0f : 0.0f; + float upper = 1.0f; + float midpoint; + float val = quadrants[1]; + int local_pivot = 1; + int offset = 1; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + //val = i == 64 ? quadrants[2] : smem_code[pivot]; + local_pivot += offset; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + //val = i == 64 ? quadrants[0] : smem_code[pivot]; + local_pivot -= offset; + } + val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; + offset -= 1; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +__launch_bounds__(TH, 4) +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) +{ + const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (blockIdx.x+1 == gridDim.x) ? 
n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK; + const int base_idx = (blockIdx.x * NUM_BLOCK); + + float vals[NUM]; + unsigned char qvals[NUM]; + //const int lane_id = threadIdx.x % 2; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreChar; + + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ float smem_code[256]; + //__shared__ float smem_code[2][257]; + + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + //smem_code[0][threadIdx.x] = code[threadIdx.x]; + //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK) + { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + + __syncthreads(); + LoadFloat(loadf).Load(&(A[i]), vals, valid_items); + + + #pragma unroll 4 + for(int j = 0; j < NUM; j++) + qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); + + __syncthreads(); + StoreChar(storec).Store(&(out[i]), qvals, valid_items); + } +} + +template +//__launch_bounds__(TH, 4) +__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) +{ + const int n_full = gridDim.x * BLOCK_SIZE; + int valid_items = 0; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; + //float local_abs_max = -FLT_MAX; + float local_abs_max = 0.0f; + int local_rand_idx = 0; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadFloat; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename BlockReduce::TempStorage reduce; + __shared__ float smem_code[256]; + __shared__ float smem_absmax_value[1]; + + if(DATA_TYPE == General8bit) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; + + for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_abs_max = -FLT_MAX; + + __syncthreads(); + LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + + // 1. compute local max + // 2. broadcast local max + // 3. 
normalize inputs and quantize + + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items); + + if(threadIdx.x == 0) { + smem_absmax_value[0] = 1.0f / local_abs_max; + absmax[i / BLOCK_SIZE] = local_abs_max; + } + __syncthreads(); + + local_abs_max = smem_absmax_value[0]; + + if(STOCHASTIC) + { + local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + } + + __syncthreads(); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); + } +} + +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) +{ + + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + typedef hipcub::BlockLoad LoadChar; + typedef hipcub::BlockStore 0) ? 2 : 1), hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) + { + if (DATA_TYPE > 0) + { + valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i); + valid_items_store = min(TILE_SIZE * 2, n - i * 2); + } + else + { + valid_items_load = min(TILE_SIZE, n - i); + valid_items_store = valid_items_load; + } + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. 
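+    // For example, with blocksize = 4096 the shift amount is 31 - __clz(4096) = 12,
+    // so (i + threadIdx.x*NUM_PER_TH) >> 12 yields the index of this block's absmax entry.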
+ local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]); + + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); + + switch (DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } + + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); + } +} + +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) +{ + const unsigned int numThreads = blockDim.x * gridDim.x; + const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; + + __shared__ float smem_code[256]; + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + } + + __syncthreads(); + + for (int i = idx;i < n; i += numThreads) + { + out[i] = smem_code[A[i]]; + } +} + + + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f/(1.0f - powf(beta1, step)); + const float correction2 = 1.0f/(1.0f - powf(beta2, step)); + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case ADAM: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + + # pragma unroll NUM_VALS-1 + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + + + +#define NUM_PER_THREAD 4 + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + // AdEMAMix has an additional state buffer, which we packed + // into state1. We need thread-local storage here for these. + // TODO: Mark with [[maybe_unused]] after upgrade to min compiler. + float s3_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + // Load additional state1 data for AdEMAMix + // TODO: Make constexpr after updating min compiler + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + switch(OPTIMIZER) + { + case ADEMAMIX: + // m1 update: m1 = beta1 * m1 + (1-beta1) * g + s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]); + + // m2 update: m2 = m2 * beta3 + (1-beta3) * g + s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]); + + // nu update: nu = beta2 * nu + (1-beta2) * g^2 + s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]); + + p_vals[j] = (float)p_vals[j] - lr * ( + ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( + (sqrtf(s2_vals[j]) / correction2) + eps + ) + ); + + if (weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay)); + + break; + case ADAM: + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + break; + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items); + } + } +} + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? 
(BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + } + } + + # pragma unroll + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + //__syncwarp(); + } +} + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; + + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + g_vals[j] = gnorm_scale*((float)g_vals[j]); + if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); + break; + } + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} + + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + if(threadIdx.x < 256) + { + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? 
(TH*NUM_PER_THREAD) : n - i; + + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + __syncthreads(); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; + s1_vals[j] += (1.0f-beta1)*g_val; + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; + s2_vals[j] += (1.0f-beta2)*g_val*g_val; + local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); + } + + if(unorm != NULL) + { + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); + float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + local_unorm += update_val*update_val; + } + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + __syncthreads(); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, hipcub::Max(), valid_items); + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + } + + if(threadIdx.x == 0) + { + atomicMax(&new_max1[0], local_max_s1); + atomicMax(&new_max2[0], local_max_s2); + if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS2, 1) +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + //const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val2 = 1.0f/new_max2[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 512) + { + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + else + smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j]*max1[0]; + + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j]*max2[0]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > 
NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; + switch(OPTIMIZER) + { + case ADAGRAD: + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + if(unorm != NULL) + local_unorm += s1_vals[j]*s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); + if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); + if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } + } + +} + +template +__global__ void +__launch_bounds__(1024, 1) +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f/new_max1[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? 
sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case ADAGRAD: + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + + switch(OPTIMIZER){ + case ADAGRAD: + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n) +{ + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadT; + + __shared__ typename BlockReduce::TempStorage reduce; + + __shared__ typename LoadT::TempStorage loadT; + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + local_sum = 0.0f; + + __syncthreads(); + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + + #pragma unroll NUM_VALS + for(int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j])*((float)vals[j]); + + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if(threadIdx.x == 0) + { + if(step == 1) + { + // initialize with the same norm for all positions + //#pragma unroll 10 + for(int j = 0; j < 100; j++) + atomicAdd(&gnorm_vec[j], local_sum); + } + else + atomicAdd(&gnorm_vec[step % 100], local_sum); + } + + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit2StateBlockwise( + T* p, + T* __restrict__ const g, + unsigned char* state1, + unsigned char* state2, + const float beta1, + const float beta2, + const float beta3, + const float alpha, + const float eps, + const int step, + const float lr, + float* __restrict__ const quantiles1, + float* __restrict__ const quantiles2, + float* absmax1, + float* absmax2, + float weight_decay, + const float gnorm_scale, + const bool skip_zeros, + const int n +) { + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + float s3_vals[N_PER_TH]; + + // 2-5% + const float correction1 = 1.0f - __powf(beta1, step); + const float correction2 = sqrtf(1.0f -__powf(beta2, step)); + const float step_size = __fdividef(-lr*correction2,correction1); + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float new_local_abs_max3 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + unsigned char c3s[N_PER_TH]; + + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + __shared__ float smem_quantiles2[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + typedef hipcub::BlockReduce BlockReduce2; + typedef hipcub::BlockReduce BlockReduce3; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ typename BlockReduce2::TempStorage reduce2; + __shared__ typename BlockReduce2::TempStorage reduce3; + __shared__ float smem_exchange1[1]; + __shared__ float smem_exchange2[1]; + __shared__ float smem_exchange3[1]; // [[maybe_unused]] + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + { + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; + } + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + { + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + 
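+ // Each pass of this loop handles one BLOCK_SIZE chunk; the last chunk may be partial (valid_items < BLOCK_SIZE),
+ // so the loads below pad out-of-range slots with defaults (0 for g, 128 for state1, 0 for state2) and the matching
+ // stores are likewise bounded by valid_items. For ADEMAMIX, a third quantized state is additionally loaded from state1[n + i].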
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + + // AdEMAMix has an additional state packed into state1. + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128); + } + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + new_local_abs_max3 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + if (OPTIMIZER == ADEMAMIX) { + // The absmax for the third state is appended to absmax1 + s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i)/BLOCK_SIZE]; + s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val)); + } + } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + + if (OPTIMIZER == ADEMAMIX) { + s3_vals[j] = 0.0f; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j])); + } + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max()); + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, hipcub::Max()); + } + + if(threadIdx.x == 0) + { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + + if (OPTIMIZER == ADEMAMIX) { + smem_exchange3[0] = new_local_abs_max3; + } + } + + __syncthreads(); + + if(threadIdx.x == 0) + { + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + + if (OPTIMIZER == ADEMAMIX) { + absmax1[(n + i)/BLOCK_SIZE] = new_local_abs_max3; + } + } + else + { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + + if (OPTIMIZER == ADEMAMIX) { + new_local_abs_max3 = smem_exchange3[0]; + } + } + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + if (OPTIMIZER == ADEMAMIX) { + p_vals[j] = T((float)p_vals[j] - lr * ( + ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( + (sqrtf(s2_vals[j]) / correction2) + eps + ) + )); + } else { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + } + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + } + + // store: 
0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + if (OPTIMIZER == ADEMAMIX) { + c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j],new_local_abs_max3)); + + if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) { + c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1; + } + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + + if (OPTIMIZER == ADEMAMIX) { + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items); + } + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; + + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + typedef hipcub::BlockReduce BlockReduce1; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ float smem_exchange1[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? 
BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + + new_local_abs_max1 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j]*beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val*g_val); + break; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + + if(threadIdx.x == 0) + smem_exchange1[0] = new_local_abs_max1; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + else + new_local_abs_max1 = smem_exchange1[0]; + + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + } + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + } +} + +// Inputs: +// A [rows, cols] +// Outputs: +// rowStats [rows] +// out [rows, cols] +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) { + + // For sm50/sm52 and CUDA < 
12.2 we need to do the reduction in fp32. + // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped. +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR >= 2) || BNB_FP16_AVAILABLE + using TReduction = T; +#else + using TReduction = float; +#endif + + using BlockReduceT = hipcub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + __shared__ TReduction smem_row_absmax; + + const int row_id = blockIdx.x; + const T* row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + TReduction row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const TReduction absval = fabsf(__ldcs(&(row_data[i]))); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } + + // Reduce thread-local absmax across the block. + const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = smem_row_absmax = row_absmax; + } + __syncthreads(); + + // Quantize row-wise. + const float scale = __fdividef(127.0f, smem_row_absmax); + for (int i = threadIdx.x; i < cols; i += THREADS) { + float val = row_data[i]; + + if constexpr (SPARSE_DECOMP) { + // For sparse decomposition, we do not want to quantize the outliers. + // Instead they're zeroed out. + out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0; + } else { + out[row_id * cols + i] = __float2int_rn(val * scale); + } + } +} + +template +__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) +__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) { + using BlockReduceT = hipcub::BlockReduce; + + // One block per row. + // Threads load column values in a striped arrangement. + // e.g. t0 reads row[0], row[0+nthreads], .. + // and t1 reads row[1], row[1+nthreads], .. + // Each thread will determine its local absmax. + // We then do a blockwise reduction to determine the row's absmax. + + __shared__ typename BlockReduceT::TempStorage temp_storage; + + const int row_id = blockIdx.x; + const T* __restrict__ row_data = A + (row_id * cols); + + // Threads will read the row values in a striped access pattern and find a local absmax. + float row_local_absmax = -FLT_MIN; + for (int i = threadIdx.x; i < cols; i += THREADS) { + const float absval = fabsf(row_data[i]); + + // For sparse decomposition, values outside of the threshold are not to be + // included when calculating the row's absmax. + if constexpr (SPARSE_DECOMP) { + row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax); + } else { + row_local_absmax = fmaxf(row_local_absmax, absval); + } + } + + // Reduce thread-local absmax across the block. 
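+ // Note: hipcub::BlockReduce returns the valid aggregate only on thread 0, which is why only threadIdx.x == 0 writes the row statistic below.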
+ // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY + const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols); + if (threadIdx.x == 0) { + // Save our block's absmax to shared memory for the quantization step. + rowStats[row_id] = row_absmax; + } +} + +template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); +template __global__ void kgetRowStats(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols); + +template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); +template __global__ void kInt8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols); + + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, + float *__restrict__ const rowStats, + float *__restrict__ const colStats, + half *out, + half *__restrict__ const bias, + const int numRows, + const int numCols, + const int n +) { + const int n_out = numRows * numCols; + + int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD; + int thread_offset = threadIdx.x * ITEMS_PER_THREAD; + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + + float local_rowStats[ITEMS_PER_THREAD]; + float local_colStats[ITEMS_PER_THREAD]; + float local_biasValue[ITEMS_PER_THREAD]; + + typedef hipcub::BlockLoad LoadInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + + int row_idx, col_idx; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + row_idx = (block_offset + thread_offset + j) / numCols; + col_idx = (block_offset + thread_offset + j) % numCols; + + local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx]; + local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx]; + local_biasValue[j] = ((bias == nullptr) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]); + } + + // Each block loads THREADS * ITEMS_PER_THREAD values from A + int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out + ? THREADS * ITEMS_PER_THREAD + : n_out - block_offset; + LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0); + + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; ++j) { + local_output[j] = __float2half( + fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j]) + ); + } + + #pragma unroll ITEMS_PER_THREAD + for (int j = 0; j < ITEMS_PER_THREAD; j++) { + int outIdx = block_offset + thread_offset + j; + if (outIdx < n_out) { + out[outIdx] = local_output[j]; + } + } +} + +#define DENORM 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +#define WARP_SIZE warpSize +template +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. 
Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int warp_idx = threadIdx.x % WARP_SIZE; + const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192 + // we expect each warp to be SPMM_ITEMS*WARP_SIZE apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + + __syncthreads(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. 
Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += blockDim.x*SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; + } +} + +#define WARPS 3 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 
3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + + +template __device__ void printnonzero(T *A, int num_values, const char * strval) +{ + for(int i = 0; i < num_values; i++) + if((float)A[i] != 0.0) + printf("%s %i %f\n", strval, i, (float)A[i]); +} + +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + //// element-wise kernel + //// 1. Load batch x k into registers + //// 2. Load k x k into registers + //// 3. dequantize and store in second pair of k x k + //// 4. matmul + //// 5. sum with cub + //// 6. store outputs + //// TC kernel + //// use k warps per thread block + //// 1. threadblock use read-only cache to read in register tile for A into shared memory + //// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments + //// 3. each warp reads a segment of values 16x32 from B + //// 4. do dequantization from register of B into second pair of registers + //// 5. store (4) into fragment + //// 6. matmul aggregate into fragment C + //// 7. aggregate files of C into shared memory block C + //// 8. sum (7) + //// 9. 
write outputs to matmul output matrix +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + + T quant_map[16]; + + #pragma unroll 16 + for(int i = 0; i < 16; i++) + quant_map[i] = nf4_data[i]; + //__shared__ T quant_map[16*160]; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_C[8*32]; + + rocwmma::fragment a_frag; + rocwmma::fragment b_frag; + rocwmma::fragment c_frag; + rocwmma::fill_fragment(c_frag, 0.0f); + + for(int i = threadIdx.x; i < (8*32); i+=blockDim.x) + smem_C[i] = 0.0f; + + __syncthreads(); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + //local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + //local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0); + //local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0); + + //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0); + //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0); + local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0); + local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 
1 : 0; + //if(threadIdx.x == 0) + //printf("aa %i %i\n", idx, loaded_values); + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + //if(threadIdx.x == 0) + //printf("%i %i\n", idx, loaded_values); + + //__syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + //local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + //local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx); + //local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx); + + //local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax); + //local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax); + local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx); + local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx); + } + //printnonzero(local_B, 128, ""); + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + //if(threadIdx.x == 0) + //{ + // printnonzero(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: "); + // printnonzero(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: "); + //} + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 
1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + //if(warp_lane == 0) + //printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x); + rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major); + + //printnonzero(smem_C, 32, ""); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_C[warp_lane]; +#endif +} + +// No of 4bit values processed by each thread +#define num_values_4bit 32 +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + // per threadblock: + // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block + typedef hipcub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize]; + + const int warp_idx = threadIdx.x / warpSize; + const int warp_lane = threadIdx.x % warpSize; + const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx; + const int offset_B = ldb*row_B; + const int num_values_8bit = num_values_4bit/2; + float local_C = 0.0f; + + unsigned char local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit/4]; + T local_A[num_values_4bit/4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + if (threadIdx.x < 16) + quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x])); + //for(int i = threadIdx.x; i < 16; i++) + //quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [M, K] + for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit) + { + const int inner_idx_halved = inner_idx/2; + + // Since blocksize will always be a power-of-2, we avoid more expensive + // division by the blocksize and instead use a shift operation. + // This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize. 
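+ // e.g. with blocksize = 64: __clz(64) = 25, so the shift amount is 31 - 25 = 6 and (x >> 6) == x / 64 for non-negative x; the same holds for any power-of-two blocksize.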
+ const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize)); + + local_absmax = __ldg(&(absmax[absidx])); + + if(row_B < M) + { + if((inner_idx_halved + num_values_8bit) < (K/2)) + { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = reinterpret_cast(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)]; + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + if((inner_idx_halved) + j < (K/2)) + local_B_4bit[j] = B[offset_B+inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } + else + { + #pragma unroll + for(int j = 0; j < (num_values_8bit); j++) + local_B_4bit[j] = 0b01110111; + } + + for(int i = 0; i < 4; i++) + { + #pragma unroll + for(int k = 0; k < num_values_8bit/4; k++) + { + #if BNB_BF16_AVAILABLE + local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; + local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; + #else + // bf16 multipliation not supported + local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax); + local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax); + #endif + } + + if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K) + { + // this is also relatively important for performance + if(BITS==16) + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/4) + i]; + } + else + { + reinterpret_cast(local_A)[0] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0]; + reinterpret_cast(local_A)[1] = reinterpret_cast(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1]; + } + + } + else + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + if(inner_idx + (i*num_values_4bit/4) + k < K) + local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)]; + else + local_A[k] = T(0.0f); + + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for(int k = 0; k < num_values_4bit/4; k++) + { + #if BNB_BF16_AVAILABLE + local_C += (float)(local_A[k]*local_B[k]); + #else + // bf16 multipliation not supported + local_C += ((float)local_A[k]*(float)local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if(row_B < M && warp_lane == 0) + out[row_B] = T(local_C); + +} + + +template __global__ void kfunc(T *A, T *B, T value, long n) +{ + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); 
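+// The kernel templates are defined only in this translation unit, so each specialization referenced from the host-side ops code presumably has to be instantiated explicitly here.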
+template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); + +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * 
__restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); + +template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); +template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); + +#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ + float* state1, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, hip_bfloat16) + +#define MAKE_Optimizer32bit1State(oname, gtype) \ +template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_Optimizer32bit1State(MOMENTUM, half) 
+MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(MOMENTUM, hip_bfloat16) +MAKE_Optimizer32bit1State(RMSPROP, half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(RMSPROP, hip_bfloat16) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, hip_bfloat16) +MAKE_Optimizer32bit1State(ADAGRAD, half) +MAKE_Optimizer32bit1State(ADAGRAD, float) +MAKE_Optimizer32bit1State(ADAGRAD, hip_bfloat16) + +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ + float* state1, float* state2, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, half) +MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) +MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, hip_bfloat16) + +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ +template __global__ void 
kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ + float *unorm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + const float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit1State(MOMENTUM, half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) +MAKE_PreconditionStatic8bit1State(ADAGRAD, half) +MAKE_PreconditionStatic8bit1State(ADAGRAD, float) + +#define MAKE_optimizerStatic8bit1State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit1State(MOMENTUM, half) +MAKE_optimizerStatic8bit1State(MOMENTUM, float) +MAKE_optimizerStatic8bit1State(RMSPROP, half) +MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) +MAKE_optimizerStatic8bit1State(LION, float) +MAKE_optimizerStatic8bit1State(ADAGRAD, half) +MAKE_optimizerStatic8bit1State(ADAGRAD, float) + +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ + float *unorm, \ + const float beta1, const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit2State(ADAM, half) +MAKE_PreconditionStatic8bit2State(ADAM, float) + +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit2State(ADAM, half) +MAKE_optimizerStatic8bit2State(ADAM, float) + +template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); +template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); + +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, 
General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4) + +MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4) +//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half 
*out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ + const float beta1, const float beta2, const float beta3, const float alpha, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* absmax1, float* absmax2, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1) +MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, hip_bfloat16, 256, 1) + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, hip_bfloat16, 256, 1) + +template __device__ void printnonzero(float *A, int num_values, const char*strval); +template 
__device__ void printnonzero(half *A, int num_values, const char*strval); diff --git a/csrc/kernels_hip.cuh b/csrc/kernels_hip.cuh new file mode 100644 index 000000000..00718071c --- /dev/null +++ b/csrc/kernels_hip.cuh @@ -0,0 +1,139 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#ifndef kernels +#define kernels + +__global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n); +__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n); + +template +__global__ void kQuantizeBlockwise( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +); +template +__global__ void + kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n); + +template +__global__ void kPreconditionOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizer32bit2State( + T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, + const int n +); + +template +__global__ void kPreconditionOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps, + const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizer32bit1State( + T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1, + const float beta2, const float eps, const float weight_decay, const int step, const float lr, + const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kPreconditionOptimizerStatic8bit1State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1, + const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1, + float* new_max1, const float weight_decay, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit1State( + T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale, + const int n +); + +template +__global__ void kPreconditionOptimizerStatic8bit2State( + T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2, + float* unorm, const float beta1, const float beta2, const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, const float gnorm_scale, const int n +); + 
+template +__global__ void kOptimizerStatic8bit2State( + T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm, + const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, + float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n +); + +template +__global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, + const float beta3, const float alpha, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, + float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps, + const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n +); + +template +__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n); + +template +__global__ void kspmm_coo_very_sparse_naive( + int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out, + float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB +); + +template +__global__ void kdequant_mm_int32_fp16( + int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out, + half* __restrict__ const bias, const int numRows, const int numCols, const int n +); + +template +__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols); +template +__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols); + +template +__global__ void kTransformRowToFormat( + char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols +); + +template +__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc); +template +__global__ void kgemm_4bit_inference( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, + int blocksize +); +template +__global__ void kgemm_4bit_inference_naive( + int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out, + int lda, int ldb, int ldc, int blocksize +); + +template __global__ void kfunc(T* A, T* B, T value, long n); + +#endif diff --git a/csrc/ops.hip b/csrc/ops.hip new file mode 100644 index 000000000..260b74b30 --- /dev/null +++ b/csrc/ops.hip @@ -0,0 +1,835 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
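+// Launch convention used by the host-side wrappers in this file: grid sizes are computed as a
+// ceiling division of the element count by the per-block workload (e.g. n = 5000 with 4096
+// elements per block gives 5000/4096 = 1, plus 1 for the nonzero remainder, so 2 blocks), and
+// kernels are dispatched through hipLaunchKernelGGL(kernel, dim3(grid), dim3(block), shared_mem, stream, args...).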
+ +#include +#include +#include +#include +#include +#ifndef NO_HIPBLASLT +#include +#endif +#include +#include +#include +#include + +#define ERR_NOT_IMPLEMENTED 100 + +using namespace BinSearch; +using std::cout; +using std::endl; + +void quantize(float *code, float *A, unsigned char *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kQuantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kDequantize), dim3(num_blocks), dim3(1024), 0, stream, code, A, out, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + + if(blocksize == 4096) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(1024), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(512), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 256) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 128) + hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n); + //else if(blocksize == 64) + // hipLaunchKernelGGL(( kQuantizeBlockwise), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n); + + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, hipStream_t stream) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + + if(DATA_TYPE > 0) + hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize/2, n); + else + hipLaunchKernelGGL(( kDequantizeBlockwise), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize, n); + + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + + + + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + const float beta1, const float beta2, const float beta3, const float alpha, + const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? 
 num_blocks : num_blocks + 1; + switch(OPTIMIZER) + { + case ADAM: + case ADEMAMIX: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + hipLaunchKernelGGL(( kOptimizer32bit2State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + + hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case LION: + // in lion, the momentum update happens after the parameter update + hipLaunchKernelGGL(( kOptimizer32bit1State), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + } + break; + } +} + +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; + + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); } + + switch(OPTIMIZER) + { + case ADAM: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit2State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit2State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case LION: + // in lion, the momentum update happens after the parameter update + hipLaunchKernelGGL(( kOptimizerStatic8bit1State), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + default: + break; + } +} + +#define BLOCKSIZE_2STATE 256 +#define NUM_2STATE 1 +#define BLOCKSIZE_1STATE 256 +#define NUM_1STATE 1 + +template void optimizerStatic8bitBlockwise( + T* p, + T* g, + unsigned char* state1, + unsigned char* state2, + float beta1, + float beta2, + float beta3, + float alpha, + float eps, + int step, + float lr, + float* quantiles1, + float* quantiles2, + float* absmax1, + float* absmax2, + float weight_decay, + const float gnorm_scale, + bool skip_zeros, + int n +) { + + int num_blocks = 0; + switch(OPTIMIZER) + { + case ADAM: + case ADEMAMIX: + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? 
num_blocks : num_blocks + 1; + hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr, + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); + break; + } +} + + + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + hipLaunchKernelGGL(( kPercentileClipping), dim3(num_blocks), dim3(512), 0, 0, g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + +#if hipblasVersionMajor >= 3 + status = hipblasGemmEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta, + C, HIP_R_32I, ldc, + HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); +#else + status = hipblasGemmEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, + C, HIPBLAS_R_32I, ldc, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); +#endif + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + hipblasStatus_t status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + +#if hipblasVersionMajor >= 3 + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta, + C, HIP_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT); +#else + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? 
HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, + C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); +#endif + + if (status != HIPBLAS_STATUS_SUCCESS) + { + std::cout << "HIPBLAS ERROR: Status " << status << std::endl; + } + +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + +#ifdef NO_HIPBLASLT +#else +template hipblasLtOrder_t get_order() +{ + switch(ORDER) + { + case ROW: + return HIPBLASLT_ORDER_ROW; + break; + case COL: + return HIPBLASLT_ORDER_COL; + break; + case COL32: + //return HIPBLASLT_ORDER_COL32; + return HIPBLASLT_ORDER_COL; + break; + case COL_TURING: + //return HIPBLASLT_ORDER_COL4_4R2_8C; + return HIPBLASLT_ORDER_COL; + break; + case COL_AMPERE: + //return HIPBLASLT_ORDER_COL32_2R_4R4; + return HIPBLASLT_ORDER_COL; + break; + default: + break; + } + + return HIPBLASLT_ORDER_ROW; +} + +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +//template hipblasLtOrder_t get_order(); +#endif + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + default: + return dim1; + break; + /*case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + default: + return 0; + break; + */ + } +} + +static std::string hipError_to_string(const hipError_t ret) +{ + switch(ret) + { + case hipSuccess: + return "hipSuccess"; + case hipErrorInvalidContext: + return "hipErrorInvalidContext"; + case hipErrorInvalidKernelFile: + return "hipErrorInvalidKernelFile"; + case hipErrorMemoryAllocation: + return "hipErrorMemoryAllocation"; + case hipErrorInitializationError: + return "hipErrorInitializationError"; + case hipErrorLaunchFailure: + return "hipErrorLaunchFailure"; + case hipErrorLaunchOutOfResources: + return "hipErrorLaunchOutOfResources"; + case hipErrorInvalidDevice: + return "hipErrorInvalidDevice"; + case hipErrorInvalidValue: + return "hipErrorInvalidValue"; + case hipErrorInvalidDevicePointer: + return "hipErrorInvalidDevicePointer"; + case hipErrorInvalidMemcpyDirection: + return "hipErrorInvalidMemcpyDirection"; + case hipErrorUnknown: + return "hipErrorUnknown"; + case hipErrorInvalidResourceHandle: + return "hipErrorInvalidResourceHandle"; + case hipErrorNotReady: + return "hipErrorNotReady"; + case hipErrorNoDevice: + return "hipErrorNoDevice"; + case hipErrorPeerAccessAlreadyEnabled: + return "hipErrorPeerAccessAlreadyEnabled"; + case hipErrorPeerAccessNotEnabled: + return "hipErrorPeerAccessNotEnabled"; + case hipErrorRuntimeMemory: + return "hipErrorRuntimeMemory"; + case hipErrorRuntimeOther: + return "hipErrorRuntimeOther"; + case hipErrorHostMemoryAlreadyRegistered: + return "hipErrorHostMemoryAlreadyRegistered"; + case hipErrorHostMemoryNotRegistered: + return "hipErrorHostMemoryNotRegistered"; + case hipErrorMapBufferObjectFailed: + return "hipErrorMapBufferObjectFailed"; + case hipErrorTbd: + return "hipErrorTbd"; + default: + throw std::runtime_error("unknown hipError"); + } +} + +template int igemmlt( + hipblasLtHandle_t ltHandle, + int m, int n, int k, + const int8_t *A, + const int8_t *B, + void *C, + float *row_scale, + int lda, int ldb, int ldc, + hipStream_t stream +) { 
+#ifdef NO_HIPBLASLT + return ERR_NOT_IMPLEMENTED; +#else + + // Calculate C = A^T @ B in col-major layout. + // + // Using the IMMA kernels requires: + // * A must be transposed and B must be non-transposed. + // * Dimensions m and k must be multiples of 4. + // * All pointers must be 4-byte aligned; 16-byte alignment preferred. + + int has_error = 0; + const int64_t max_workspace_size = 0; // set to 0 to avoid choosing a GSU kernel + + hipblasLtMatmulDesc_t matmulDesc; + hipblasLtMatrixLayout_t aDesc, bDesc, cDesc; + hipblasOperation_t opT = HIPBLAS_OP_T; + + hipDataType outType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_8I; + hipDataType scaleType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_32F; + + hipblasLtPointerMode_t pointerMode = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&aDesc, HIP_R_8I, m, k, lda)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&bDesc, HIP_R_8I, m, n, ldb)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc)); + + // Default layout order is col major + + has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, scaleType)); + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT))); + + if (DTYPE_OUT == 32) { + + /* Algo and workspace TODO: need to rework to not be duplicated */ + // Set User Preference attributes + hipblasLtMatmulPreference_t pref; + checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref)); + checkHipblasStatus( + hipblasLtMatmulPreferenceSetAttribute(pref, + HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &max_workspace_size, + sizeof(max_workspace_size))); + + const int request_solutions = 1; + hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions]; + int returnedAlgoCount = 0; + checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle, + matmulDesc, + aDesc, + bDesc, + cDesc, + cDesc, + pref, + request_solutions, + heuristicResult, + &returnedAlgoCount)); + + if (returnedAlgoCount == 0) + { + has_error = 1; + fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n"); + } else { + int alpha = 1, beta = 0; + has_error |= checkHipblasStatus(hipblasLtMatmul( + ltHandle, matmulDesc, + &alpha, A, aDesc, + B, bDesc, &beta, + (int32_t*)C, cDesc, + (int32_t*)C, cDesc, + &heuristicResult[0].algo, NULL, 0, stream + )); + } + } else { + // This path is unlikely to be used, as 8-bit accumulation can easily overflow. 
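+    // Rough intuition: an int8 output can only hold values in [-128, 127], so even a short
+    // dot product such as 50*50 + 50*50 = 5000 cannot be represented without scaling, which is
+    // why the 32-bit output path above is the one normally taken.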
+ + if (!SCALE_ROWS) { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkHipblasStatus(hipblasLtMatmul( + ltHandle, matmulDesc, + &alpha, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, stream + )); + } else { + hipblasLtPointerMode_t alphaVec = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST; + float beta = 0.0f; + has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute( + matmulDesc, + HIPBLASLT_MATMUL_DESC_POINTER_MODE, + &pointerMode, + sizeof(alphaVec) + )); + has_error |= checkHipblasStatus(hipblasLtMatmul( + ltHandle, matmulDesc, + row_scale, A, aDesc, + B, bDesc, &beta, + (int8_t*)C, cDesc, + (int8_t*)C, cDesc, + NULL, NULL, 0, stream + )); + } + } + + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(cDesc)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(bDesc)); + has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(aDesc)); + has_error |= checkHipblasStatus(hipblasLtMatmulDescDestroy(matmulDesc)); + + if(has_error == 1) + printf("error detected"); + + return has_error; +#endif // NO_HIPBLASLT +} + +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +} + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, hipStream_t stream) +{ + const int threads = 512; + const int num_per_thread = 4; + const int num_per_block = threads * num_per_thread; + const int n = numRows*numCols; + const int num_blocks = (n + num_per_block - 1) / num_per_block; + + hipLaunchKernelGGL(( kdequant_mm_int32_fp16), dim3(num_blocks), dim3(threads), 0, stream, A, rowStats, colStats, out, bias, numRows, numCols, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) { if (threshold == 0.0) { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } else { + kInt8VectorQuant<<>>(A, out, rowStats, threshold, rows, cols); + } + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) { + if (threshold == 0.0) + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + else + kgetRowStats<<>>(A, rowStats, threshold, rows, cols); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) +{ + +#ifdef NO_HIPBLASLT +#else + + hipsparseSpMatDescr_t descA; + hipsparseDnMatDescr_t descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + CHECK_HIPSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + A_rowidx, A_colidx, A_vals, + HIPSPARSE_INDEX_32I, + HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) ); + // Create dense matrix C + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + CHECK_HIPSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, + HIP_R_16F, HIPSPARSE_ORDER_ROW) ); + // allocate an external buffer if needed + CHECK_HIPSPARSE( hipsparseSpMM_bufferSize( + handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? 
HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); + CUDA_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) ); + + // execute SpMM + CHECK_HIPSPARSE( hipsparseSpMM(handle, + HIPSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, HIP_R_32F, + HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer)); + + // destroy matrix/vector descriptors + CHECK_HIPSPARSE( hipsparseDestroySpMat(descA) ); + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) ); + CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) ); + CUDA_CHECK_RETURN( hipFree(dBuffer) ); +#endif +} + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ + + hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + if(bits == 32) + hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(32), 0, 0, m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + hipLaunchKernelGGL(( gemm_device), dim3(num_blocks), dim3(160), 0, 0, m, n, k, A, B, out, lda, ldb, ldc); +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + hipLaunchKernelGGL(( kgemm_4bit_inference), dim3(num_blocks), dim3(96), 0, 0, m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream) +{ + + //warpsize - 32 + int num_blocks = (m+3)/4; + //warpsize - 64 + if (warpSize == 64) { + num_blocks = (m+1)/2; + } + + hipLaunchKernelGGL(( kgemm_4bit_inference_naive), dim3(num_blocks), dim3(128), 0, stream, m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 
65535 : blocks; + hipLaunchKernelGGL(( kfunc), dim3(blocks), dim3(512), 0, 0, A, B, value, n); + CUDA_CHECK_RETURN(hipPeekAtLastError()); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); +template void gemm_4bit_inference_naive(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream); + +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template int igemmlt<32, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); +template int igemmlt<8, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); +template int igemmlt<8, 1>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream); + +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * 
code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream); + +#define MAKE_optimizer32bit(name, gtype) \ +template void optimizer32bit(gtype* g, gtype* p, \ + float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +MAKE_optimizer32bit(ADAM, half) +MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, hip_bfloat16) +MAKE_optimizer32bit(MOMENTUM, half) +MAKE_optimizer32bit(MOMENTUM, float) +MAKE_optimizer32bit(MOMENTUM, hip_bfloat16) +MAKE_optimizer32bit(RMSPROP, half) +MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(RMSPROP, hip_bfloat16) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, hip_bfloat16) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) +MAKE_optimizer32bit(ADAGRAD, hip_bfloat16) +MAKE_optimizer32bit(ADEMAMIX, half) +MAKE_optimizer32bit(ADEMAMIX, hip_bfloat16) +MAKE_optimizer32bit(ADEMAMIX, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ +template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ 
+ float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, int n); \ + +MAKE_optimizerStatic8bit(ADAM, half) +MAKE_optimizerStatic8bit(ADAM, float) +MAKE_optimizerStatic8bit(MOMENTUM, half) +MAKE_optimizerStatic8bit(MOMENTUM, float) +MAKE_optimizerStatic8bit(RMSPROP, half) +MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) +MAKE_optimizerStatic8bit(ADAGRAD, half) +MAKE_optimizerStatic8bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ +template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + +MAKE_optimizerStatic8bitBlockwise(half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM); +MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADEMAMIX); +MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); + +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh new file mode 100644 index 000000000..0f8db2ee4 --- /dev/null +++ b/csrc/ops_hip.cuh @@ -0,0 +1,213 @@ +// !!! This is a file automatically generated by hipify!!! +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
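+// Error-handling convention for this header: HIP runtime and hipSPARSE calls are wrapped in the
+// CUDA_CHECK_RETURN / CHECK_HIPSPARSE macros defined below, e.g.
+//   CUDA_CHECK_RETURN(hipMemset(ptr, 0, num_bytes));
+// which print the failing file/line and exit, while hipBLAS calls go through checkHipblasStatus(),
+// which reports the failure and returns 1 instead of terminating.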
+
+#ifndef ops_H
+#define ops_H
+
+#include <stdio.h>
+#include <iostream>
+#include <assert.h>
+#include <vector>
+#include <functional>
+
+#include <hip/hip_runtime_api.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
+#include <hipblas/hipblas.h>
+#include <hipblaslt/hipblaslt.h>
+#include <hipsparse/hipsparse.h>
+#include <rocblas/rocblas.h>
+
+#define CUDA_CHECK_RETURN(value) \
+    { \
+        hipError_t _m_cudaStat = value; \
+        if (_m_cudaStat != hipSuccess) { \
+            fprintf(stderr, "Error %s at line %d in file %s\n", hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
+            exit(1); \
+        } \
+    }
+
+#define CHECK_HIPSPARSE(value) \
+    { \
+        hipsparseStatus_t _m_hipStat = value; \
+        if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \
+            fprintf( \
+                stderr, "Error %s at line %d in file %s\n", hipsparseGetErrorString(_m_hipStat), __LINE__, __FILE__ \
+            ); \
+            exit(1); \
+        } \
+    }
+
+inline void checkHipStatus(hipError_t status) {
+    if (status != hipSuccess) {
+        printf("hip API failed with status %d: %s\n", status, hipGetErrorString(status));
+        throw std::logic_error("hip API failed");
+    }
+}
+
+inline int checkHipblasStatus(hipblasStatus_t status) {
+    if (status != HIPBLAS_STATUS_SUCCESS) {
+        printf("hipBLAS API failed with status %d\n", status);
+        // throw std::logic_error("cuBLAS API failed");
+        return 1;
+    }
+    return 0;
+}
+
+typedef enum Operations_t {
+    ksmul = 0,
+} Operations_t;
+
+typedef enum Optimizer_t {
+    ADAM = 0,
+    MOMENTUM = 1,
+    RMSPROP = 2,
+    LARS = 3,
+    ADAGRAD = 4,
+    LION = 5,
+    ADEMAMIX = 6,
+} Optimizer_t;
+
+typedef enum Transform_t {
+    ROW = 0,
+    COL = 1,
+    COL32 = 2,
+    COL_TURING = 3,
+    COL_AMPERE = 4,
+} Transform_t;
+
+typedef enum DataType_t {
+    General8bit = 0,
+    FP4 = 1,
+    NF4 = 2,
+} DataType_t;
+
+typedef enum Funcs_t {
+    FILL = 0,
+    ARANGE = 1,
+    _MUL = 2,
+} Funcs_t;
+
+class Context {
+  public:
+    rocblas_handle m_handle;
+
+    Context() {
+        rocblas_handle handle;
+        rocblas_create_handle(&handle);
+        m_handle = handle;
+    }
+};
+
+class ContextLt {
+  public:
+    hipblasLtHandle_t m_handle;
+
+    ContextLt() {
+        hipblasLtHandle_t handle;
+        hipblasLtCreate(&handle);
+        m_handle = handle;
+    }
+};
+
+class ContextHipsparse {
+  public:
+    hipsparseHandle_t m_handle;
+
+    ContextHipsparse() {
+        hipsparseHandle_t handle;
+        hipsparseCreate(&handle);
+        m_handle = handle;
+    }
+};
+
+void quantize(float* code, float* A, unsigned char* out, int n);
+void dequantize(float* code, unsigned char* A, float* out, int n, hipStream_t stream);
+template <typename T, int STOCHASTIC, int DATA_TYPE>
+void quantizeBlockwise(
+    float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
+);
+template <typename T, int DATA_TYPE>
+void dequantizeBlockwise(
+    float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, hipStream_t stream
+);
+
+template <typename T, int OPTIMIZER>
+void optimizer32bit(
+    T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2,
+    float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale,
+    bool skip_zeros, int n
+);
+
+template <typename T, int OPTIMIZER>
+void optimizerStatic8bit(
+    T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm,
+    float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1,
+    float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n
+);
+
+template <typename T, int OPTIMIZER>
+void optimizerStatic8bitBlockwise(
+    T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
+    float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
+    float weight_decay, const float gnorm_scale, bool skip_zeros, int n
+);
+
+template <typename T> void percentileClipping(T* g, float* gnorm_vec, int step, const int n);
+
+void gemmex(
+    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
+    int ldb, int ldc
+);
+void strided_gemmex(
+    Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
+    int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount
+);
+
+template <int DTYPE_OUT, int SCALE_ROWS>
+int igemmlt(
+    hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
+    int lda, int ldb, int ldc, hipStream_t stream
+);
+
+void cutlass_igemm(
+    bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc
+);
+void dequant_mm_int32_fp16(
+    int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, hipStream_t stream
+);
+void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, hipStream_t stream);
+void int8VectorQuant(
+    half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, hipStream_t stream
+);
+
+void spmm_coo(
+    hipsparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
+    int ldb, half* B, int ldc, half* C, bool transposed_B
+);
+
+template <typename T, int BITS>
+void spmm_coo_very_sparse_naive(
+    int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
+    float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
+);
+
+void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB);
+
+template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits);
+template <typename T>
+void gemm_4bit_inference(
+    int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
+);
+template <typename T, int BITS>
+void gemm_4bit_inference_naive(
+    int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
+    int blocksize, hipStream_t stream
+);
+
+template <typename T, int FUNC> void func(T* A, T* B, T value, long n);
+
+#endif
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index 63f46a20c..9c4cab9cc 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -6,11 +6,29 @@
 #if BUILD_CUDA
 #include <ops.cuh>
 #endif
+#if BUILD_HIP
+#include <ops_hip.cuh>
+#endif
 #if BUILD_MPS
 // #include <mps_ops.h>
 #endif
 #include <cpu_ops.h>
 
+// Compatibility between HIP/CUDA APIs
+#if BUILD_HIP
+#define cudaStream_t hipStream_t
+#define __nv_bfloat16 hip_bfloat16
+#define cublasLtHandle_t hipblasLtHandle_t
+#define ContextCusparse ContextHipsparse
+#define cusparseHandle_t hipsparseHandle_t
+#define cudaMallocManaged hipMallocManaged
+#define cudaMemAttachHost hipMemAttachHost
+#define cudaPeekAtLastError hipPeekAtLastError
+#define cudaDeviceGetAttribute hipDeviceGetAttribute
+#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
+#define cudaMemPrefetchAsync hipMemPrefetchAsync
+#endif
+
 // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
 // We use macro functions to expand all the different optimizers.
Looks ugly, and is ugly, but its better than to // maintain all that boilerplate @@ -18,7 +36,7 @@ // UNMANGLED CALLS //=================================================================================== -#if BUILD_CUDA +#if BUILD_CUDA || BUILD_HIP // void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) //{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } @@ -291,7 +309,7 @@ void spmm_coo_very_sparse_naive_int8( #endif extern "C" { -#if BUILD_CUDA +#if BUILD_CUDA || BUILD_HIP void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); } void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) { diff --git a/tests/helpers.py b/tests/helpers.py index 02613bb75..a87bc5d08 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,6 +7,8 @@ import torch +from bitsandbytes.cextension import HIP_ENVIRONMENT + test_dims_rng = random.Random(42) @@ -21,7 +23,7 @@ def get_available_devices(): # If the environment variable is set, use it directly. return [os.environ["BNB_TEST_DEVICE"]] - devices = ["cpu"] + devices = [] if HIP_ENVIRONMENT else ["cpu"] if hasattr(torch, "accelerator"): # PyTorch 2.6+ - determine accelerator using agnostic API. diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 79406472e..3d8b688ee 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,6 +1,6 @@ import pytest -from bitsandbytes.cextension import get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path from bitsandbytes.cuda_specs import CUDASpecs @@ -13,11 +13,13 @@ def cuda120_spec() -> CUDASpecs: ) +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" +@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): monkeypatch.setenv("BNB_CUDA_VERSION", "110") assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" diff --git a/tests/test_functional.py b/tests/test_functional.py index 4fb0a0d2f..b84db6502 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -9,6 +9,7 @@ import bitsandbytes as bnb from bitsandbytes import functional as F +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -92,7 +93,10 @@ class Test8BitBlockwiseQuantizeFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) - @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) + @pytest.mark.parametrize( + "blocksize", + [4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128], + ) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 @@ -823,6 +827,7 @@ def test_coo_int8_vectorwise_quant(self, device, dim1, dim2): torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) 
+@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") class TestSpMMFunctional: @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) @@ -1100,7 +1105,10 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize( + "blocksize", + [64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096], + ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): pytest.skip("This configuration is not supported on HPU.") @@ -1135,7 +1143,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize")) + @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize")) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): @@ -1201,6 +1209,9 @@ def test_bench_4bit_dequant(self, quant_type): # torch.cuda.synchronize() # print((time.time()-t0)/iters*1e6) + @pytest.mark.skipif( + HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64" + ) @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @@ -1361,6 +1372,10 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) + @pytest.mark.skipif( + HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a", + reason="this test is not supported on ROCm with gfx90a architecture yet", + ) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3): pytest.skip("eye doe not support bfloat16 on CPU in torch < 2.3") diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 9fcde695d..e07b54d2d 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -8,6 +8,7 @@ import torch import bitsandbytes as bnb +from bitsandbytes.cextension import HIP_ENVIRONMENT from tests.helpers import ( TRUE_FALSE, describe_dtype, @@ -191,7 +192,7 @@ def test_linear_serialization( @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, 
quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -213,7 +214,7 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -242,7 +243,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128]) +@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): diff --git a/tests/test_ops.py b/tests/test_ops.py index 52f26fb05..8aa0560fd 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -4,6 +4,7 @@ import torch import bitsandbytes +from bitsandbytes.cextension import HIP_ENVIRONMENT from bitsandbytes.functional import ipex_xpu from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu @@ -102,7 +103,7 @@ def test_int8_scaled_mm(self, device, dtype, has_bias): class TestInt8BlockwiseQuantOps: @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_blockwise(self, device, dtype, blocksize): if device == "cpu": if dtype != torch.float32: @@ -126,7 +127,7 @@ def test_quantize_blockwise(self, device, dtype, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_blockwise(self, device, dtype, blocksize): if device == "cpu" and dtype != torch.float32: pytest.skip("CPU implementation is only available for float32") @@ -156,7 +157,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -180,7 +181,7 @@ def 
test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -214,7 +215,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.")