# limitations under the License.

import math
+
import torch
import torch.nn as nn
+
from auto_round.utils import convert_dtype_torch2str, logger

try:
    import auto_round_kernel as ark
+
    ARK_INSTALLED = True
except:
    ARK_INSTALLED = False

AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

+
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
    shifts = torch.arange(0, 32, bits, device="cpu")

@@ -51,6 +55,7 @@ def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):

    return iweights, izeros

+
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
    reverse_order_tensor = torch.arange(
        iweights.shape[-1],
@@ -66,14 +71,13 @@ def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
    iweights = iweights[:, reverse_order_tensor]
    return iweights, izeros

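# Illustration (a minimal sketch with hypothetical values, not part of the
# module): AWQ stores the eight sub-byte columns of each packed int32 word in
# the interleaved order [0, 2, 4, 6, 1, 3, 5, 7]; indexing with
# AWQ_REVERSE_ORDER, as reverse_awq_order does above, restores sequential order.
import torch

packed_order = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7])  # on-disk column layout
print(packed_order[AWQ_REVERSE_ORDER])                  # tensor([0, 1, 2, 3, 4, 5, 6, 7])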
+
class QuantLinearAWQ(nn.Module):
    QUANT_TYPE = "ark_awq"

    def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
        super().__init__()
-        assert (
-            ARK_INSTALLED
-        ), "Please install auto_round_kernel package."
+        assert ARK_INSTALLED, "Please install auto_round_kernel package."

        self.use_bf16 = ark.check_isa_supported("AMX")

@@ -162,9 +166,7 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze

    @torch.no_grad()
    def forward(self, x):
-        assert ARK_INSTALLED, (
-            "ARK kernels could not be loaded. "
-        )
+        assert ARK_INSTALLED, "ARK kernels could not be loaded. "

        input_dtype = x.dtype
        out_shape = x.shape[:-1] + (self.out_features,)
@@ -200,6 +202,7 @@ def extra_repr(self) -> str:
            self.group_size,
        )

+
class QuantLinear(nn.Module):
    QUANT_TYPE = "ark_gptq_nozp"
    ZP_BIAS = 0
@@ -220,9 +223,7 @@ def __init__(

        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2, 4, 8 bits are supported for ARK.")
-        assert (
-            ARK_INSTALLED
-        ), "Please install auto_round_kernel."
+        assert ARK_INSTALLED, "Please install auto_round_kernel."

        self.infeatures = infeatures
        self.outfeatures = outfeatures
@@ -261,9 +262,8 @@ def __init__(
        self.kernel_switch_threshold = kernel_switch_threshold
        self.trainable = trainable

-
    def post_init(self):
-        assert self.qweight.device.type in ["cpu", 'xpu']
+        assert self.qweight.device.type in ["cpu", "xpu"]
        # intweight: k x n, zeros: k / group_size x n
        intweight, zeros = unpack_to_8bit_signed(self.qweight, self.qzeros, self.bits, self.ZP_BIAS)
        if zeros is None:
@@ -275,7 +275,7 @@ def post_init(self):
            zeros = (zeros.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8)
        else:
            zeros -= 2 ** (self.bits - 1)
-            if self.qweight.device.type != 'cpu':
+            if self.qweight.device.type != "cpu":
                assert not self.asym
            if not self.asym:
                intweight -= 2 ** (self.bits - 1)
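# Worked instance of the re-centering above (a sketch, bits value hypothetical):
# for bits=4, 2 ** (bits - 1) == 8, so unsigned codes 0..15 become the signed
# range -8..7 that the int8 kernels consume.
import torch

bits = 4
codes = torch.arange(2**bits, dtype=torch.int32)   # unpacked unsigned codes
signed = (codes - 2 ** (bits - 1)).to(torch.int8)
print(signed.min().item(), signed.max().item())    # -8 7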
@@ -285,21 +285,20 @@ def post_init(self):
        if self.asym:
            intweight = (intweight.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8)

-
        logger.debug(
            f"ARK repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}"
        )

-        if self.qweight.device.type == 'xpu':
-            self.sdt = 'fp16'
-            self.cdt = 'fp16'
+        if self.qweight.device.type == "xpu":
+            self.sdt = "fp16"
+            self.cdt = "fp16"
            scales = self.scales.to(torch.float16).contiguous()
        else:
-            self.sdt = 'fp32'
-            self.cdt = 'fp32'
+            self.sdt = "fp32"
+            self.cdt = "fp32"
            scales = self.scales.float().contiguous()
        self.wdt = BITS_DTYPE_MAPPING[self.bits]
-
+
        self.qweight = ark.repack_quantized_weight(
            intweight.contiguous(),
            scales,
@@ -311,11 +310,10 @@ def post_init(self):
            self.wdt,
            # scale_dtype
            self.sdt,
-
            self.asym,
            self.group_size,
        )
-
+
        # self.revert_wei = torch.zeros(self.infeatures, self.outfeatures, dtype=scales.dtype, device=self.qweight.device)
        # # print(packw, packw.device, packw.dtype)
        # ark.dequantize_packed_weight(
@@ -324,20 +322,20 @@ def post_init(self):
        self.qzeros = torch.empty(0)
        self.scales = torch.empty(0)
        if self.bias is not None:
-            if self.bias.device.type == 'cpu':
+            if self.bias.device.type == "cpu":
                self.bias = self.bias.to(torch.float32)
            else:
                self.bias = self.bias.to(torch.float16)

    def forward(self, x: torch.Tensor):
        raw_input_dtype = x.dtype
-        if x.device.type == 'cpu':
+        if x.device.type == "cpu":
            odt = torch.float32
            if raw_input_dtype != torch.float32:
                x = x.to(torch.float32)
        else:
            odt = x.dtype
-
+
        out_shape = x.shape[:-1] + (self.outfeatures,)
        x = x.view(-1, x.shape[-1])  # convert xd to 2d
        out_2d_shape = x.shape[:-1] + (self.outfeatures,)
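# A minimal sketch of the shape bookkeeping above (sizes hypothetical): the
# leading batch/sequence dims survive untouched and only the last dim changes,
# because the kernel consumes a 2-D GEMM input.
import torch

x = torch.randn(2, 7, 16)                  # (batch, seq, in_features)
outfeatures = 32
out_shape = x.shape[:-1] + (outfeatures,)
x2d = x.view(-1, x.shape[-1])              # (14, 16)
y = x2d @ torch.randn(16, outfeatures)     # stand-in for the quantized matmul
print(y.view(out_shape).shape)             # torch.Size([2, 7, 32])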
@@ -353,16 +351,18 @@ def forward(self, x: torch.Tensor):
            self.wdt,  # weight_dtype
            self.sdt,  # scale_dtype
            self.asym,
-            self.group_size
+            self.group_size,
        )
-        if x.device.type == 'xpu':
+        if x.device.type == "xpu":
            outputs = outputs + bias
        return outputs.to(raw_input_dtype).view(out_shape)

+
class QuantLinearGPTQ(QuantLinear):
    QUANT_TYPE = "ark_gptq"
    ZP_BIAS = 1

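# Illustration of the ZP_BIAS convention (an assumption inferred from the
# gptq_bias=1 default in unpack_to_8bit_signed below): GPTQ checkpoints store
# zero points with an off-by-one bias, so unpacking adds ZP_BIAS back, while
# the "nozp" variant (ZP_BIAS = 0) uses the stored value as-is.
stored_zero = 7              # hypothetical raw field read from qzeros
true_zero = stored_zero + 1  # ZP_BIAS = 1 restores the actual zero point, 8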
+
@torch.no_grad()
def unpack_to_8bit_signed(qweight, qzeros, bits, gptq_bias=1):
    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qweight.device).unsqueeze(0)
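# Sketch of the shift-and-mask unpacking that wf sets up (the packed word is
# hypothetical): for bits=4 the shifts are [0, 4, ..., 28], so each int32
# yields eight 4-bit fields, least-significant field first.
import torch

bits = 4
wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32).unsqueeze(0)
packed = torch.tensor([[0x76543210]], dtype=torch.int32)
fields = torch.bitwise_right_shift(packed.unsqueeze(-1), wf) & (2**bits - 1)
print(fields.flatten().tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]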
@@ -393,6 +393,7 @@ def unpack_to_8bit_signed(qweight, qzeros, bits, gptq_bias=1):

    return weight, zeros

+
# Copied from qlinear_marlin.py
@torch.no_grad()
def dequantize_weight(qweight, qzeros, scales, bits):
@@ -402,16 +403,18 @@ def dequantize_weight(qweight, qzeros, scales, bits):
    if unpacked_qzeros is not None:
        unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
    else:
-        unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32, device=qweight.device)
+        unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32, device=qweight.device)
    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales

    return unpacked_qweight, unpacked_qzeros

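# Worked instance of the dequantization above, W = (q - z) * s, using
# hypothetical 4-bit codes, the default zero point 8, and a made-up scale.
import torch

q = torch.tensor([0, 8, 15], dtype=torch.int32)  # unpacked 4-bit codes
z, s = 8, 0.05                                   # zero point and scale
print(((q - z) * s).tolist())                    # ≈ [-0.4, 0.0, 0.35]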
+
def ark_post_init(model):
    for _, submodule in model.named_modules():
        if isinstance(submodule, QuantLinear):
            submodule.post_init()

    return model

-__all__ = ["QuantLinear", 'QuantLinearGPTQ', 'QuantLinearAWQ']
+
+__all__ = ["QuantLinear", "QuantLinearGPTQ", "QuantLinearAWQ"]
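# Hypothetical end-to-end wiring (the loader name is a placeholder, not an API
# of this module): once a model's Linear layers have been replaced by
# QuantLinear and the packed checkpoint is loaded, repack once for ARK:
#
#     model = load_packed_model("my-int4-checkpoint")  # user-provided loader
#     model = ark_post_init(model)  # each QuantLinear repacks its qweight
#     outputs = model(inputs)       # forward now dispatches to ark ops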