@@ -147,6 +147,188 @@ def make_quant(module, names, bits, groupsize, name=''):
     for name1, child in module.named_children():
         make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
 
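+# Variant of make_quant that scans a module's attributes (via dir()) instead of its
+# named_children(), deriving the bias and quantized-layer attribute names from the
+# weight attribute name.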
+def make_quant_custom(module, names, bits, groupsize, name=''):
+    if isinstance(module, QuantLinear):
+        return
+    for attr in dir(module):
+        tmp = getattr(module, attr)
+        name1 = name + '.' + attr if name != '' else attr
+        if name1 in names:
+
+            bias_name = attr.replace('w', 'b')
+            layer_name = attr.replace('w', 'quant')
+            setattr(module, layer_name, QuantLinear_custom(bits, groupsize, tmp.shape[0], tmp.shape[1], module.w[bias_name] is not None))
+
+
+class QuantLinear_custom(nn.Module):
+    def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
+        super().__init__()
+        if bits not in [2, 3, 4, 8]:
+            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.bits = bits
+        self.groupsize = groupsize if groupsize != -1 else infeatures
+        self.maxq = 2 ** self.bits - 1
+
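+        # Packed storage: qweight holds the `bits`-bit quantized weights packed into
+        # int32 words, qzeros the packed per-group zero points, scales the per-group
+        # fp16 scales, and g_idx maps each input feature to its quantization group.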
+        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
+        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+        if bias:
+            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+        else:
+            self.bias = None
+
+        # wf holds the per-slot bit offsets used by the fallback path in forward(),
+        # which dequantizes by unpacking the weights and using torch.matmul.
+        if self.bits in [2, 4, 8]:
+            self.register_buffer('wf', torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0), persistent=False)
+        elif self.bits == 3:
+            self.register_buffer('wf', torch.tensor([[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0],
+                                                     [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31],
+                                                     [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0]], dtype=torch.int32).reshape(1, 3, 12), persistent=False)
+
+        self.kernel_switch_threshold = kernel_switch_threshold
+        self.is_cuda = is_cuda
+
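+    # pack() quantizes a float weight matrix (plus optional bias) into the packed
+    # int32 buffers registered above, given per-group scales and zero points.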
+    def pack(self, weight, bias, scales, zeros, g_idx=None):
+        self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
+
+        scales = scales.t().contiguous()
+        zeros = zeros.t().contiguous()
+        scale_zeros = zeros * scales
+        self.scales = scales.clone().half()
+        if bias is not None:
+            self.bias = bias.clone().half()
+
+        intweight = []
+        for idx in range(self.infeatures):
+            intweight.append(torch.round((weight[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.t().contiguous()
+        intweight = intweight.numpy().astype(np.uint32)
+        qweight = np.zeros(
+            (intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
+        )
+        i = 0
+        row = 0
+        while row < qweight.shape[0]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qweight[row] |= intweight[j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                row += 1
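+            # 3-bit packing: every 32 values span three int32 words; ten values fit
+            # fully in each word and the 11th/22nd values straddle a word boundary.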
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i))
+                i += 10
+                qweight[row] |= intweight[i] << 30
+                row += 1
+                qweight[row] |= (intweight[i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
+                i += 10
+                qweight[row] |= intweight[i] << 31
+                row += 1
+                qweight[row] |= (intweight[i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
+                i += 10
+                row += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qweight = qweight.astype(np.int32)
+        self.qweight = torch.from_numpy(qweight)
+
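+        # Zero points are stored offset by one; forward() adds the one back when unpacking.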
+        zeros -= 1
+        zeros = zeros.numpy().astype(np.uint32)
+        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
+        i = 0
+        col = 0
+        while col < qzeros.shape[1]:
+            if self.bits in [2, 4, 8]:
+                for j in range(i, i + (32 // self.bits)):
+                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+                i += 32 // self.bits
+                col += 1
+            elif self.bits == 3:
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 30
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
+                i += 10
+                qzeros[:, col] |= zeros[:, i] << 31
+                col += 1
+                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
+                i += 1
+                for j in range(i, i + 10):
+                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
+                i += 10
+                col += 1
+            else:
+                raise NotImplementedError("Only 2,3,4,8 bits are supported.")
+
+        qzeros = qzeros.astype(np.int32)
+        self.qzeros = torch.from_numpy(qzeros)
+
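+    # forward() uses the quant_cuda kernels for small batches and falls back to
+    # dequantize-then-matmul once the batch exceeds kernel_switch_threshold.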
+    def forward(self, x):
+        out_shape = x.shape[:-1] + (self.outfeatures, )
+        x = x.reshape(-1, x.shape[-1])
+        if self.is_cuda is True and (self.kernel_switch_threshold is False or x.shape[0] < self.kernel_switch_threshold):
+            out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
+            if self.bits == 2:
+                quant_cuda.vecquant2matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            elif self.bits == 3:
+                quant_cuda.vecquant3matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            elif self.bits == 4:
+                quant_cuda.vecquant4matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            elif self.bits == 8:
+                quant_cuda.vecquant8matmul(x.float(), self.qweight, out, self.scales.float(), self.qzeros, self.g_idx)
+            out = out.half()
+        else:
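+            # Fallback: unpack the int32 buffers back into integer weights, dequantize
+            # with the per-group scales and zeros, and use a plain fp16 matmul.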
+            if self.bits in [2, 4, 8]:
+                zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
+                torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
+
+                zeros = zeros + 1
+                zeros = zeros.reshape(self.scales.shape)
+
+                weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
+                torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
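+            # 3-bit values are reassembled from the three-word groups, stitching the
+            # two values that straddle word boundaries back together.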
+            elif self.bits == 3:
+                zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand(-1, -1, -1, 12)
+                zeros = (zeros >> self.wf.unsqueeze(0))
+                zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4)
+                zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6)
+                zeros = zeros & 0x7
+                zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2)
+
+                zeros = zeros + 1
+                zeros = zeros.reshape(self.scales.shape)
+
+                weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand(-1, -1, 12, -1)
+                weight = (weight >> self.wf.unsqueeze(-1)) & 0x7
+                weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4)
+                weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6)
+                weight = weight & 0x7
+                weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1)
+
+            weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
+
+            weights = (self.scales[self.g_idx] * (weight - zeros[self.g_idx]))
+            out = torch.matmul(x.half(), weights)
+        out = out.reshape(out_shape)
+        out = out + self.bias if self.bias is not None else out
+        return out
+
 class QuantLinear(nn.Module):
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
         super().__init__()