|
| 1 | +import ctypes as ct |
1 | 2 | from typing import Literal, Optional, Tuple, Union |
2 | 3 |
|
3 | 4 | import torch |
4 | 5 |
|
5 | | -from bitsandbytes.utils import QuantState |
6 | | - |
7 | | -from .base import Backend |
8 | | - |
9 | 6 | try: |
10 | 7 | # to support Ascend NPU backend |
11 | 8 | import torch_npu # noqa: F401 |
12 | 9 | except ImportError: |
13 | 10 | pass |
14 | 11 |
|
| 12 | +from bitsandbytes.cextension import lib |
| 13 | +from bitsandbytes.functional import ( |
| 14 | + get_4bit_type, |
| 15 | + get_ptr, |
| 16 | +) |
| 17 | +from bitsandbytes.utils import QuantState |
| 18 | + |
| 19 | +from .base import Backend |
| 20 | + |
| 21 | + |
def assert_on_npu(tensors):
    """Verify that every non-None tensor in *tensors* lives on an NPU device.

    Args:
        tensors: Iterable of tensors; ``None`` entries are skipped.

    Returns:
        True when all non-None tensors are on an NPU device.

    Raises:
        TypeError: If any non-None tensor is not on an NPU device. The
            message lists ``(shape, device)`` for each tensor to aid debugging.
    """
    if not all(t.device.type == "npu" for t in tensors if t is not None):
        # Fixed garbled wording of the original message ("tensors to be on
        # NPU, but found some tensors not be on NPU").
        raise TypeError(
            "All input tensors need to be on NPU, but found some tensors not on NPU:\n"
            f"{[(t.shape, t.device) if isinstance(t, torch.Tensor) else None for t in tensors]}"
        )
    return True
15 | 30 |
|
16 | 31 | class NPUBackend(Backend): |
17 | 32 | def double_quant( |
@@ -75,23 +90,140 @@ def quantize_4bit( |
75 | 90 | A: torch.Tensor, |
76 | 91 | absmax: Optional[torch.Tensor] = None, |
77 | 92 | out: Optional[torch.Tensor] = None, |
78 | | - blocksize=64, |
| 93 | + blocksize: Optional[int] = None, |
79 | 94 | compress_statistics=False, |
80 | | - quant_type: Literal["fp4", "nf4"] = "fp4", |
| 95 | + quant_type: Literal["fp4", "nf4"] = "nf4", |
81 | 96 | quant_storage=torch.uint8, |
82 | 97 | ) -> Tuple[torch.Tensor, QuantState]: |
83 | | - raise NotImplementedError |
| 98 | + if quant_type not in ["nf4"]: |
| 99 | + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") |
| 100 | + if compress_statistics: |
| 101 | + raise NotImplementedError("compress_statistics is not implemented.") |
| 102 | + if blocksize is None: |
| 103 | + blocksize = 128 |
| 104 | + |
| 105 | + prev_device = torch.npu.current_device() |
| 106 | + torch.npu.set_device(A.device) |
| 107 | + if A.dtype in [torch.float32, torch.float16, torch.bfloat16]: |
| 108 | + data = [ |
| 109 | + -1.0, |
| 110 | + -0.6961928009986877, |
| 111 | + -0.5250730514526367, |
| 112 | + -0.39491748809814453, |
| 113 | + -0.28444138169288635, |
| 114 | + -0.18477343022823334, |
| 115 | + -0.09105003625154495, |
| 116 | + 0.0, |
| 117 | + 0.07958029955625534, |
| 118 | + 0.16093020141124725, |
| 119 | + 0.24611230194568634, |
| 120 | + 0.33791524171829224, |
| 121 | + 0.44070982933044434, |
| 122 | + 0.5626170039176941, |
| 123 | + 0.7229568362236023, |
| 124 | + 1.0, |
| 125 | + ] |
| 126 | + data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1) |
| 127 | + absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values |
| 128 | + a = A.view(-1, blocksize) / absmax.float() |
| 129 | + diff = torch.abs(a.unsqueeze(-1) - data) |
| 130 | + out = (torch.argmin(diff, dim=-1) + 8) % 16 |
| 131 | + out = out.reshape(-1, 2) |
| 132 | + out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8) |
| 133 | + else: |
| 134 | + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") |
| 135 | + assert_on_npu([A, absmax, out]) |
| 136 | + torch.npu.set_device(prev_device) |
| 137 | + |
| 138 | + code = get_4bit_type(quant_type, device=A.device) |
| 139 | + state = QuantState( |
| 140 | + absmax=absmax, |
| 141 | + shape=A.shape, |
| 142 | + dtype=A.dtype, |
| 143 | + blocksize=blocksize, |
| 144 | + code=code, |
| 145 | + quant_type=quant_type, |
| 146 | + ) |
| 147 | + |
| 148 | + return out, state |
84 | 149 |
|
85 | 150 | def dequantize_4bit( |
86 | 151 | self, |
87 | 152 | A: torch.Tensor, |
88 | 153 | quant_state: Optional[QuantState] = None, |
89 | 154 | absmax: Optional[torch.Tensor] = None, |
90 | 155 | out: Optional[torch.Tensor] = None, |
91 | | - blocksize: int = 64, |
92 | | - quant_type: Literal["fp4", "nf4"] = "fp4", |
| 156 | + blocksize: Optional[int] = None, |
| 157 | + quant_type: Literal["fp4", "nf4"] = "nf4", |
93 | 158 | ) -> torch.Tensor: |
94 | | - raise NotImplementedError |
| 159 | + if blocksize is None: |
| 160 | + blocksize = 128 |
| 161 | + supported_blocksizes = [2048, 4096, 1024, 512, 256, 128, 64] |
| 162 | + if blocksize not in supported_blocksizes: |
| 163 | + raise ValueError( |
| 164 | + f"The blockwise of {blocksize} is not supported. Supported values: {supported_blocksizes}" |
| 165 | + ) |
| 166 | + |
| 167 | + if quant_state is None: |
| 168 | + assert absmax is not None and out is not None |
| 169 | + quant_state = QuantState( |
| 170 | + absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type |
| 171 | + ) |
| 172 | + else: |
| 173 | + absmax = quant_state.absmax |
| 174 | + |
| 175 | + if out is None: |
| 176 | + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) |
| 177 | + |
| 178 | + n = out.numel() |
| 179 | + |
| 180 | + prev_device = torch.npu.current_device() |
| 181 | + torch.npu.set_device(A.device) |
| 182 | + assert_on_npu([A, absmax, out]) |
| 183 | + |
| 184 | + if quant_state.quant_type not in ["nf4"]: |
| 185 | + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") |
| 186 | + |
| 187 | + if out.dtype == torch.float32: |
| 188 | + lib.cdequantize_blockwise_fp32_nf4( |
| 189 | + get_ptr(A), |
| 190 | + get_ptr(absmax), |
| 191 | + get_ptr(out), |
| 192 | + ct.c_int(quant_state.blocksize), |
| 193 | + ct.c_int(n), |
| 194 | + torch.npu.current_stream(), |
| 195 | + ) |
| 196 | + elif out.dtype == torch.float16: |
| 197 | + lib.cdequantize_blockwise_fp16_nf4( |
| 198 | + get_ptr(A), |
| 199 | + get_ptr(absmax), |
| 200 | + get_ptr(out), |
| 201 | + ct.c_int(quant_state.blocksize), |
| 202 | + ct.c_int(n), |
| 203 | + torch.npu.current_stream(), |
| 204 | + ) |
| 205 | + elif out.dtype == torch.bfloat16: |
| 206 | + # bf16: bf16 -> fp32 -> op -> fp32 -> bf16 |
| 207 | + absmax = absmax.to(torch.float32) |
| 208 | + out = out.to(torch.float32) |
| 209 | + lib.cdequantize_blockwise_fp32_nf4( |
| 210 | + get_ptr(A), |
| 211 | + get_ptr(absmax), |
| 212 | + get_ptr(out), |
| 213 | + ct.c_int(quant_state.blocksize), |
| 214 | + ct.c_int(n), |
| 215 | + torch.npu.current_stream(), |
| 216 | + ) |
| 217 | + out = out.to(torch.bfloat16) |
| 218 | + else: |
| 219 | + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") |
| 220 | + torch.npu.set_device(prev_device) |
| 221 | + is_transposed = True if A.shape[0] == 1 else False |
| 222 | + |
| 223 | + if is_transposed: |
| 224 | + return out.t() |
| 225 | + else: |
| 226 | + return out |
95 | 227 |
|
96 | 228 | def gemv_4bit( |
97 | 229 | self, |
|
0 commit comments