|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import ctypes as ct |
| 4 | +from typing import Sequence, Tuple |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +from ..._ops import register_kernel |
| 9 | +from ...cextension import lib |
# Block sizes the native NF4 kernels accept.
_ALLOWED_BLOCKS = (64, 128, 256, 512, 1024, 2048, 4096)
# Tensor dtypes the MPS NF4 path handles natively.
_SUPPORTED_DTYPES = (torch.float16, torch.float32)


# All four NF4 tensor entry points share one C signature:
# (input PyObject, absmax PyObject, output PyObject, int32 blocksize) -> void.
for _fn in (
    lib.cquantize_blockwise_fp16_nf4_tensor,
    lib.cquantize_blockwise_fp32_nf4_tensor,
    lib.cdequantize_blockwise_fp16_nf4_tensor,
    lib.cdequantize_blockwise_fp32_nf4_tensor,
):
    _fn.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
    _fn.restype = None
del _fn
| 22 | + |
| 23 | + |
def _quantize_nf4(
    A: torch.Tensor, blocksize: int, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Blockwise-quantize a float16/float32 tensor to packed NF4 on MPS.

    Args:
        A: Input tensor; must be float16 or float32 (checked below).
        blocksize: Elements per quantization block; must be in ``_ALLOWED_BLOCKS``.
        quant_storage: Storage dtype for the packed output; only ``torch.uint8``
            is supported here.

    Returns:
        ``(out, absmax)`` where ``out`` has shape ``((numel + 1) // 2, 1)``
        (two 4-bit codes packed per byte) and ``absmax`` holds one float32
        scale per block.
    """
    # Fix: this check previously had no error message, unlike every other
    # torch._check in this module — supply one for a diagnosable failure.
    torch._check(
        blocksize in _ALLOWED_BLOCKS,
        lambda: f"blocksize must be one of {list(_ALLOWED_BLOCKS)}, got {blocksize}",
    )
    torch._check(quant_storage == torch.uint8, lambda: "Only uint8 storage is supported for NF4 on MPS.")

    A = A.contiguous()
    n = A.numel()
    # Ceiling division without floats: number of blocks covering all n elements.
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    # Two NF4 nibbles fit in each storage byte, hence (n + 1) // 2 bytes.
    out = torch.empty(((n + 1) // 2, 1), device=A.device, dtype=quant_storage)

    if A.dtype == torch.float16:
        lib.cquantize_blockwise_fp16_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
    elif A.dtype == torch.float32:
        lib.cquantize_blockwise_fp32_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
    else:
        torch._check(False, lambda: f"NF4 quantization on MPS supports {list(_SUPPORTED_DTYPES)}, got {A.dtype}")

    return out, absmax
| 45 | + |
| 46 | + |
def _dequantize_nf4(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """Blockwise-dequantize packed NF4 data into ``out`` (in place) on MPS.

    Args:
        A: Packed NF4 codes (two 4-bit values per byte).
        absmax: Per-block float32 scales produced at quantization time.
        blocksize: Elements per block; must be in ``_ALLOWED_BLOCKS``.
        dtype: Target dtype; must be float16 or float32 (checked below).
        out: Preallocated, contiguous destination tensor; written in place.
    """
    # Fix: this check previously had no error message, unlike every other
    # torch._check in this module — supply one for a diagnosable failure.
    torch._check(
        blocksize in _ALLOWED_BLOCKS,
        lambda: f"blocksize must be one of {list(_ALLOWED_BLOCKS)}, got {blocksize}",
    )

    A = A.contiguous()
    absmax = absmax.contiguous()
    # `out` is written through a raw pointer; it must already be contiguous.
    torch._check(out.is_contiguous(), lambda: "Output tensor must be contiguous for NF4 dequantization on MPS.")

    if dtype == torch.float16:
        lib.cdequantize_blockwise_fp16_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
    elif dtype == torch.float32:
        lib.cdequantize_blockwise_fp32_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
    else:
        torch._check(False, lambda: f"NF4 dequantization on MPS supports {list(_SUPPORTED_DTYPES)}, got {dtype}")
| 66 | + |
| 67 | + |
@register_kernel("bitsandbytes::quantize_4bit", "mps")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Route NF4 float16/float32 inputs to the native MPS kernel; otherwise defer to the default op."""
    native = quant_type == "nf4" and A.dtype in _SUPPORTED_DTYPES
    if native:
        return _quantize_nf4(A, blocksize, quant_storage)
    return torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, quant_storage)
| 75 | + |
| 76 | + |
@register_kernel("bitsandbytes::dequantize_4bit", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize NF4 data via the native MPS kernel when supported; otherwise defer to the default op."""
    native = quant_type == "nf4" and dtype in _SUPPORTED_DTYPES
    if not native:
        return torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)
    # Allocate the destination here; the helper fills it in place.
    result = torch.empty(shape, dtype=dtype, device=A.device)
    _dequantize_nf4(A, absmax, blocksize, dtype, result)
    return result
| 91 | + |
| 92 | + |
@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    """Out-variant NF4 dequantize: write into caller-provided ``out`` on MPS, deferring to the default op when unsupported."""
    native = quant_type == "nf4" and dtype in _SUPPORTED_DTYPES
    if not native:
        torch.ops.bitsandbytes.dequantize_4bit.out.default(
            A,
            absmax,
            blocksize,
            quant_type,
            shape,
            dtype,
            out,
        )
        return

    # Validate the caller-provided buffer before the in-place kernel writes it.
    expected_shape = tuple(shape)
    torch._check(out.shape == expected_shape, lambda: f"Expected out.shape == {expected_shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    _dequantize_nf4(A, absmax, blocksize, dtype, out)
0 commit comments