from typing import Literal, Optional, Tuple, Union

import torch

from bitsandbytes.utils import QuantState

from .base import Backend
from .cpu_xpu_common import (
    dequantize_4bit_impl,
    gemm_4bit_impl,
    igemmlt_impl,
    mm_dequant_impl,
    quantize_4bit_impl,
)

Tensor = torch.Tensor


class HPUBackend(Backend):
    mm_dequant_compute_dtype = torch.bfloat16
    mm_dequant_output_dtype = torch.bfloat16

    def transform(
        self,
        A: torch.Tensor,
        to_order: str,
        from_order="row",
        out: Optional[torch.Tensor] = None,
        transpose=False,
        state: Optional[Tuple[torch.Size, str]] = None,
        ld=None,
    ):
        """
        Transform tensor A to to_order. This op was originally designed for CUDA,
        where it rearranges the memory layout. On HPU it is a no-op aside from
        transposition: it returns A unchanged when transpose=False, and the
        transpose of A otherwise.
        """
        if transpose:
            if out is not None:
                out.copy_(A.T)
            else:
                out = A.T
        else:
            if out is not None:
                out.copy_(A)
            else:
                out = A
        return out, state

    def igemmlt(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        SA: Tuple[torch.Size, str],
        SB: Tuple[torch.Size, str],
        out: Optional[torch.Tensor] = None,
        Sout: Optional[Tuple[torch.Size, str]] = None,
        dtype=torch.int32,
    ) -> Union[torch.Tensor, Tuple[Optional[Tuple[torch.Tensor, Tuple[torch.Size, str]]]]]:
        # Int8 matmul with int32 accumulation; delegates to the shared
        # CPU/XPU implementation.
        return igemmlt_impl(A, B, SA, SB, out, Sout, dtype)

    def mm_dequant(
        self,
        A: torch.Tensor,
        quant_state: Tuple[torch.Size, str],
        row_stats: torch.Tensor,
        col_stats: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        new_row_stats: Optional[torch.Tensor] = None,
        new_col_stats: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Dequantize the int32 output of igemmlt back to floating point;
        # on HPU both the compute and output dtypes are bfloat16.
        return mm_dequant_impl(
            A,
            quant_state,
            row_stats,
            col_stats,
            out,
            new_row_stats,
            new_col_stats,
            bias,
            self.mm_dequant_compute_dtype,
            self.mm_dequant_output_dtype,
        )

    def extract_outliers(
        self,
        A: torch.Tensor,
        SA: Tuple[torch.Size, str],
        idx: torch.Tensor,
    ) -> torch.Tensor:
        """
        Extract the columns of A selected by idx.
        """
        return A[:, idx].contiguous()

    def quantize_4bit(
        self,
        A: torch.Tensor,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize=64,
        compress_statistics=False,
        quant_type: Literal["fp4", "nf4"] = "fp4",
        quant_storage=torch.uint8,
    ) -> Tuple[torch.Tensor, QuantState]:
        if blocksize is None:
            blocksize = 64
        assert quant_storage == torch.uint8, "quant_storage must be torch.uint8 on this backend"
        return quantize_4bit_impl(A, absmax, out, blocksize, compress_statistics, quant_type)

    def dequantize_4bit(
        self,
        A: torch.Tensor,
        quant_state: Optional[QuantState] = None,
        absmax: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        blocksize: int = 64,
        quant_type: Literal["fp4", "nf4"] = "fp4",
    ) -> torch.Tensor:
        if blocksize is None:
            blocksize = 64
        return dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type)

    def gemv_4bit(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        transposed_A=False,
        transposed_B=False,
        state: Optional[QuantState] = None,
    ) -> torch.Tensor:
        if state is None:
            raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")

        return gemm_4bit_impl(A, B, out, transposed_A, transposed_B, state)
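

# Usage sketch (illustrative, not part of the backend API): round-trips a
# random weight through quantize_4bit / dequantize_4bit. Running the backend
# standalone on plain tensors and the "nf4" quant_type are assumptions made
# for this example; in practice these methods are reached through the
# dispatch layer in bitsandbytes.functional.
if __name__ == "__main__":
    backend = HPUBackend()
    weight = torch.randn(64, 64)
    packed, quant_state = backend.quantize_4bit(weight, blocksize=64, quant_type="nf4")
    restored = backend.dequantize_4bit(packed, quant_state=quant_state, blocksize=64, quant_type="nf4")
    # 4-bit quantization is lossy, so expect a small nonzero reconstruction error.
    print("max abs error:", (weight - restored).abs().max().item())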