
 from bitsandbytes.cextension import lib
 from bitsandbytes.functional import (
+    COOSparseTensor,
     get_4bit_type,
     get_ptr,
 )
@@ -28,6 +29,43 @@ def assert_on_npu(tensors):
     return True


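+# Helper that packs outlier coordinates and values into the COOSparseTensor
+# container from bitsandbytes.functional (row/col indices cast to int32).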
+def coo_zeros(rows, cols, rowidx, colidx, values, nnz, device, dtype=torch.half):
+    rowidx = rowidx.to(torch.int32)
+    colidx = colidx.to(torch.int32)
+    values = values.to(device).to(dtype)
+    return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
+
+
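+# Computes the per-row and per-column absolute maxima of A, plus the count of
+# entries whose magnitude exceeds `threshold`, via the native
+# cget_col_row_stats kernel.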
+def row_col_stats(A, threshold):
+    cols = A.shape[-1]
+    if len(A.shape) == 3:
+        rows = A.shape[0] * A.shape[1]
+    else:
+        rows = A.shape[0]
+
+    row_max = torch.zeros(rows, dtype=torch.float32, device="npu")
+    col_max = torch.zeros(cols, dtype=torch.float32, device="npu")
+    outlier_num = torch.zeros(1, dtype=torch.int32, device="npu")
+    lib.cget_col_row_stats(
+        get_ptr(A),
+        get_ptr(row_max),
+        get_ptr(col_max),
+        get_ptr(outlier_num),
+        ct.c_float(threshold),
+        ct.c_int32(rows),
+        ct.c_int32(cols),
+        torch.npu.current_stream(),
+    )
+    return row_max, col_max, outlier_num
+
+
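+# Lightweight handle pairing two int8 operands. int8_linear_matmul returns it
+# instead of eagerly running the GEMM, so int8_mm_dequant can execute the
+# matmul and dequantization together in a single fused npu_quant_matmul call.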
+class Int8AB:
+    def __init__(self, A: torch.Tensor, B: torch.Tensor):
+        self.A = A
+        self.B = B
+        self.device = A.device
+
+
 class NPUBackend(Backend):
     def int8_double_quant(
         self,
@@ -38,7 +76,53 @@ def int8_double_quant(
         out_row: Optional[torch.Tensor] = None,
         threshold=0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        raise NotImplementedError
+        device = A.device
+        assert A.dtype == torch.half
+        assert device.type == "npu"
+
+        cols = A.shape[-1]
+        if len(A.shape) == 3:
+            rows = A.shape[0] * A.shape[1]
+        else:
+            rows = A.shape[0]
+
+        torch.npu.set_device(A.device)  # make sure the NPU context matches A's device
+
+        row_stats, col_stats, cnt_npu = row_col_stats(A, threshold)
+
+        quant_row = torch.empty((rows, cols), dtype=torch.int8, device=device)
+        quant_col = torch.empty((rows, cols), dtype=torch.int8, device=device)
+        outliers_row_idx = torch.zeros(rows, dtype=torch.int32, device=device)
+        outliers_col_idx = torch.zeros(40 * cols, dtype=torch.int32, device=device) - 1
+        outliers_value = torch.empty(0, dtype=torch.float16, device=device)
+
+        lib.cdouble_rowcol_quant(
+            get_ptr(A),
+            get_ptr(row_stats),
+            get_ptr(col_stats),
+            get_ptr(quant_row),
+            get_ptr(quant_col),
+            get_ptr(outliers_row_idx),
+            get_ptr(outliers_col_idx),
+            get_ptr(outliers_value),
+            ct.c_int(cols),
+            ct.c_float(threshold),
+            ct.c_int32(rows),
+            ct.c_int32(cols),
+            torch.npu.current_stream(),
+        )
+
+        colidx_tmp = torch.unique(outliers_col_idx)
+        colidx = colidx_tmp[colidx_tmp != -1]
+
+        coo_tensor = None
+        if threshold != 0.0:
+            coo_tensor = coo_zeros(rows, cols, outliers_row_idx, colidx, outliers_value, cnt_npu, device, dtype=torch.half)
+
+        return quant_row, quant_col, row_stats, col_stats, coo_tensor

     def int8_vectorwise_dequant(self, A, stats):
         return super().int8_vectorwise_dequant(A, stats)
@@ -48,7 +132,35 @@ def int8_vectorwise_quant(
         A: torch.Tensor,
         threshold=0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        raise NotImplementedError
+        device = A.device
+        assert A.dtype == torch.half
+        assert device.type == "npu"
+
+        cols = A.shape[-1]
+        if len(A.shape) == 3:
+            rows = A.shape[0] * A.shape[1]
+        else:
+            rows = A.shape[0]
+
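+        # When a threshold is set, zero out the outliers (|x| >= threshold)
+        # before taking the row-wise absmax, so a few large entries do not
+        # inflate the int8 scale; outlier columns are reported separately below.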
+        A_no_threshold = None
+        if threshold > 0.0:
+            zero = torch.tensor(0.0, dtype=torch.half, device=device)
+            A_no_threshold = torch.where(A.view(rows, cols).abs() < threshold, A.view(rows, cols), zero)
+            row_stats = torch.amax(A_no_threshold.abs(), dim=1, keepdim=True).to(device)
+            out_row = torch.round(A_no_threshold * 127.0 / row_stats).to(torch.int8)
+        else:
+            row_stats = torch.amax(A.view(rows, cols).abs(), dim=1, keepdim=True).to(device)
+            out_row = torch.round(A.view(rows, cols) * 127.0 / row_stats).to(torch.int8)  # view as 2-D so the (rows, 1) stats broadcast for 3-D inputs
+
+        outlier_cols = None
+        if threshold > 0.0:
+            # TODO: we could improve perf of this
+            outliers = A.abs() >= threshold
+
+            if outliers.any():
+                outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+
+        return out_row, row_stats, outlier_cols

     def transform(
         self,
@@ -69,7 +181,7 @@ def int8_linear_matmul(
         out: Optional[torch.Tensor] = None,
         dtype=torch.int32,
     ) -> torch.Tensor:
-        raise NotImplementedError
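+        # No eager int8 GEMM here: return the operand pair instead, so the
+        # matmul can be fused with dequantization in int8_mm_dequant.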
+        return Int8AB(A, B)

     def int8_mm_dequant(
         self,
@@ -79,7 +191,15 @@ def int8_mm_dequant(
         out: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        raise NotImplementedError
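+        # Unpack the deferred Int8AB handle and run the fused int8 matmul plus
+        # dequantization: `scale` is applied per output channel and
+        # `pertoken_scale` per row of A.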
+        A, B = A.A, A.B
+        out = torch_npu.npu_quant_matmul(
+            A,
+            B.t(),
+            scale=col_stats.float() / 127.0,
+            pertoken_scale=row_stats.float().view(-1) / 127.0,
+            output_dtype=torch.float16,
+        )
+        if bias is not None:
+            out += bias
+        return out

     def extract_outliers(
         self,
@@ -106,6 +226,10 @@ def quantize_4bit(
         if blocksize is None:
             blocksize = 128

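+        # Quantize large tensors in chunks: the nearest-codebook search below
+        # materializes a (blocks, blocksize, 16) distance tensor, so chunking
+        # (8 pieces for tensors above 2048 * 2048 elements) bounds peak memory.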
+        total_blocks = A.numel() // blocksize
+        chunks = 8 if A.numel() > 2048 * 2048 else 1
+        chunksize = (total_blocks + chunks - 1) // chunks
+
         prev_device = torch.npu.current_device()
         torch.npu.set_device(A.device)
         if A.dtype in [torch.float32, torch.float16, torch.bfloat16]:
@@ -128,12 +252,27 @@ def quantize_4bit(
                 1.0,
             ]
             data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1)
-            absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
-            a = A.view(-1, blocksize) / absmax.float()
-            diff = torch.abs(a.unsqueeze(-1) - data)
-            out = (torch.argmin(diff, dim=-1) + 8) % 16
-            out = out.reshape(-1, 2)
-            out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
+            chunks_absmax = []
+            chunks_out = []
+
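+            # Per chunk: take the per-block absmax, normalize the block into
+            # [-1, 1], snap each value to the nearest of the 16 codebook
+            # entries ((argmin + 8) % 16 converts table indices into the
+            # expected 4-bit code layout), then pack two codes per uint8.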
+            for i in range(chunks):
+                start = i * chunksize * blocksize
+                end = min((i + 1) * chunksize * blocksize, A.numel())
+                chunk_data = A.view(-1)[start:end].view(-1, blocksize)
+
+                absmax = chunk_data.abs().max(dim=1, keepdim=True).values
+                chunks_absmax.append(absmax)
+
+                a = chunk_data / absmax.float()
+                diff = torch.abs(a.unsqueeze(-1) - data)
+                out = (torch.argmin(diff, dim=-1) + 8) % 16
+
+                out = out.reshape(-1, 2)
+                out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
+                chunks_out.append(out)
+
+            absmax = torch.cat(chunks_absmax, dim=0)
+            out = torch.cat(chunks_out, dim=0)
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
         assert_on_npu([A, absmax, out])
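
A minimal usage sketch of how these pieces compose (hypothetical tensor names;
assumes NPUBackend can be constructed directly and that the fp16 inputs already
live on an NPU device):

    import torch
    import torch_npu  # noqa: F401  (registers the "npu" device)

    backend = NPUBackend()
    A = torch.randn(64, 128, dtype=torch.half, device="npu")   # activations
    W = torch.randn(256, 128, dtype=torch.half, device="npu")  # weights

    # Row-wise int8 quantization of both operands.
    A8, A_stats, _ = backend.int8_vectorwise_quant(A)
    W8, W_stats, _ = backend.int8_vectorwise_quant(W)

    # int8_linear_matmul only defers: it returns an Int8AB handle, and
    # int8_mm_dequant performs the fused quantized matmul + dequantization.
    handle = backend.int8_linear_matmul(A8, W8)
    out = backend.int8_mm_dequant(handle, A_stats, W_stats)  # fp16, (64, 256)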