@@ -1,24 +1,21 @@
 import math
 from typing import Literal, Optional, Tuple
-import warnings
+
 import torch
 
+from bitsandbytes.functional import get_4bit_type
 from bitsandbytes.utils import QuantState
 
 from .base import Backend
 from .cpu_xpu_common import (
-    double_quant_impl,
-    dequant_8bit,
-    NF4_QUANT_TABLE,
     INT8_QUANT_TABLE,
-)
-from bitsandbytes.functional import (
-    QuantState,
-    get_4bit_type,
+    NF4_QUANT_TABLE,
+    dequant_8bit,
 )
 
 Tensor = torch.Tensor
 
+
 def assert_on_hpu(tensors):
     on_hpu = True
     for t in tensors:
@@ -32,8 +29,8 @@ def assert_on_hpu(tensors):
         )
     return on_hpu
 
-class HPUBackend(Backend):
 
+class HPUBackend(Backend):
     def int8_double_quant(
         self,
         A: torch.Tensor,
@@ -43,8 +40,7 @@ def int8_double_quant(
         out_row: Optional[torch.Tensor] = None,
         threshold=0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        assert_on_hpu([A, col_stats, row_stats, out_col, out_row])
-        return double_quant_impl(A, col_stats, row_stats, out_col, out_row, threshold)
+        raise NotImplementedError("Not yet implemented for HPU backend")
 
     def transform(
         self,
@@ -100,7 +96,7 @@ def quantize_4bit(
         assert_on_hpu([A, absmax, out])
         assert quant_storage == torch.uint8, "HPU backend only supports uint8 quant_storage"
         return self.quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)
-
+
     def quantize_4bit_impl(
         self,
         A: Tensor,
@@ -159,10 +155,9 @@ def quantize_4bit_impl(
         code = get_4bit_type(quant_type, device=A.device)
 
         if compress_statistics:
-            raise AssertionError("Double quantization is not supported for HPU backend")
             offset = absmax.mean()
             absmax -= offset
-            qabsmax, state2 = self.hpu_quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
+            qabsmax, state2 = self.quantize_4bit_impl(absmax, blocksize=256, quant_type="int8")
             del absmax
             state = QuantState(
                 absmax=qabsmax,
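
The hunk above enables double quantization on the HPU path: the per-block absmax statistics are shifted by their mean and re-quantized blockwise (blocksize 256, "int8" code), with the offset and second-level state carried in the QuantState. A minimal standalone sketch of that idea follows; it uses plain symmetric int8 blockwise quantization in place of the backend's INT8_QUANT_TABLE codebook, and the helper names are illustrative, not bitsandbytes API.

import torch

def quantize_absmax_blockwise(absmax: torch.Tensor, blocksize: int = 256):
    # Center the statistics by their mean, then quantize each block to int8.
    # Assumes absmax.numel() is a multiple of blocksize (real code would pad).
    offset = absmax.float().mean()
    shifted = (absmax.float() - offset).reshape(-1, blocksize)
    state2 = shifted.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)  # second-level absmax
    q = (shifted / state2 * 127).round().clamp(-127, 127).to(torch.int8)
    return q.flatten(), offset, state2

def dequantize_absmax_blockwise(q: torch.Tensor, offset, state2, blocksize: int = 256):
    # Inverse of the above: rescale each int8 block and add the offset back.
    return (q.reshape(-1, blocksize).float() / 127 * state2).flatten() + offset
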
@@ -196,10 +191,10 @@ def dequantize_nf4_impl(
         HPU dequantization function for NF4 quantized tensors.
         """
         assert_on_hpu([input, absmax])
-        out_shape = (math.prod(quant_state.shape), )
-        out_dq = torch.ops.hpu.dequantize_nf4(input, absmax, blocksize,
-                                              out_shape=out_shape,
-                                              out_dtype=quant_state.dtype)
+        out_shape = (math.prod(quant_state.shape),)
+        out_dq = torch.ops.hpu.dequantize_nf4(
+            input, absmax, blocksize, out_shape=out_shape, out_dtype=quant_state.dtype
+        )
         output = out_dq.reshape(quant_state.shape).T
         return output
 
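
The hunk above only reformats the dequantize_nf4 call; the shape handling is unchanged: the HPU kernel is asked for a flat buffer of math.prod(quant_state.shape) elements, which is then reshaped to quant_state.shape and transposed. A tiny illustration of that flow with made-up dimensions (no HPU op involved):

import math
import torch

quant_shape = (4096, 1024)                  # stand-in for quant_state.shape
flat = torch.empty(math.prod(quant_shape))  # flat buffer the HPU kernel is asked to fill
weight = flat.reshape(quant_shape).T        # layout handed back to the caller: (1024, 4096)
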
@@ -214,10 +209,9 @@ def dequantize_4bit(
     ) -> torch.Tensor:
         if blocksize is None:
             blocksize = 64
-
+
         assert_on_hpu([A, absmax, out])
         if quant_state.nested:
-            raise AssertionError("Double quantization is not supported for HPU backend")
             absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2)
         return self.dequantize_nf4_impl(A, absmax, blocksize, quant_state)
 
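
With the assertion on quant_state.nested removed, the nested path works end to end: dequant_8bit first restores absmax from (state.absmax, state.offset, state.state2), and the NF4 kernel above does the rest. A hedged usage sketch of the round trip through the public functional API, assuming an HPU-enabled PyTorch build (device string "hpu"); it is not taken from this diff.

import torch
import bitsandbytes.functional as F

W = torch.randn(4096, 1024, dtype=torch.bfloat16, device="hpu")
# compress_statistics=True produces a nested QuantState (int8-quantized absmax).
qW, state = F.quantize_4bit(W, blocksize=64, compress_statistics=True, quant_type="nf4")
# dequantize_4bit restores absmax from the nested state, then runs NF4 dequantization.
W_dq = F.dequantize_4bit(qW, state)
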
@@ -230,18 +224,7 @@ def gemv_4bit(
230224 transposed_B = False ,
231225 state : QuantState = None ,
232226 ) -> torch .Tensor :
233- assert_on_hpu ([A , B , out ])
234- if state is None :
235- raise ValueError (
236- "state cannot be None. gemv_4bit() requires the state from quantize_4bit()"
237- )
238- dqB = self .dequantize_nf4_impl (B , state .absmax , state .blocksize , state )
239- output = torch .matmul (A , dqB .to (A .dtype ))
240- if out is not None :
241- out .copy_ (output )
242- else :
243- out = output
244- return out
227+ raise NotImplementedError ("Not yet implemented for HPU backend" )
245228
246229 def int8_vectorwise_dequant (self , A : torch .Tensor , stats : torch .Tensor ):
247230 raise NotImplementedError ("Not yet implemented for HPU backend" )
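
The gemv_4bit body removed above computed the product by dequantizing B and falling back to torch.matmul. For reference, here is that deleted logic as a standalone helper, a sketch reconstructed from the removed lines rather than a supported API, in case a caller needs equivalent behavior while the backend raises NotImplementedError.

import torch

def gemv_4bit_fallback(backend, A, B, state, out=None):
    # Mirrors the deleted HPUBackend.gemv_4bit body: dequantize B, then matmul.
    if state is None:
        raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
    dqB = backend.dequantize_nf4_impl(B, state.absmax, state.blocksize, state)
    output = torch.matmul(A, dqB.to(A.dtype))
    if out is not None:
        out.copy_(output)
        return out
    return output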