
Commit e50f850

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6345f86 commit e50f850

2 files changed: +32 additions, −28 deletions


auto_round_extension/ark/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 from auto_round_extension.ark.qlinear import QuantLinear, QuantLinearGPTQ, QuantLinearAWQ
+
 qlinear_classes = (QuantLinear, QuantLinearGPTQ)

 awq_classes = (QuantLinearAWQ,)

auto_round_extension/ark/qlinear.py

Lines changed: 31 additions & 28 deletions
@@ -13,12 +13,15 @@
 # limitations under the License.

 import math
+
 import torch
 import torch.nn as nn
+
 from auto_round.utils import convert_dtype_torch2str, logger

 try:
     import auto_round_kernel as ark
+
     ARK_INSTALLED = True
 except:
     ARK_INSTALLED = False
@@ -31,6 +34,7 @@

 AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

+
 def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
     shifts = torch.arange(0, 32, bits, device="cpu")

@@ -51,6 +55,7 @@ def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):

     return iweights, izeros

+
 def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
     reverse_order_tensor = torch.arange(
         iweights.shape[-1],
@@ -66,14 +71,13 @@ def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
     iweights = iweights[:, reverse_order_tensor]
     return iweights, izeros

+
 class QuantLinearAWQ(nn.Module):
     QUANT_TYPE = "ark_awq"

     def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
         super().__init__()
-        assert (
-            ARK_INSTALLED
-        ), "Please install auto_round_kernel package."
+        assert ARK_INSTALLED, "Please install auto_round_kernel package."

         self.use_bf16 = ark.check_isa_supported("AMX")

@@ -162,9 +166,7 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze

     @torch.no_grad()
     def forward(self, x):
-        assert ARK_INSTALLED, (
-            "ARK kernels could not be loaded. "
-        )
+        assert ARK_INSTALLED, "ARK kernels could not be loaded. "

         input_dtype = x.dtype
         out_shape = x.shape[:-1] + (self.out_features,)
@@ -200,6 +202,7 @@ def extra_repr(self) -> str:
            self.group_size,
        )

+
 class QuantLinear(nn.Module):
     QUANT_TYPE = "ark_gptq_nozp"
     ZP_BIAS = 0
@@ -220,9 +223,7 @@ def __init__(

         if bits not in [2, 4, 8]:
             raise NotImplementedError("Only 2, 4,8 bits are supported for ARK.")
-        assert (
-            ARK_INSTALLED
-        ), "Please install auto_round_kernel."
+        assert ARK_INSTALLED, "Please install auto_round_kernel."

         self.infeatures = infeatures
         self.outfeatures = outfeatures
@@ -261,9 +262,8 @@ def __init__(
         self.kernel_switch_threshold = kernel_switch_threshold
         self.trainable = trainable

-
     def post_init(self):
-        assert self.qweight.device.type in ["cpu", 'xpu']
+        assert self.qweight.device.type in ["cpu", "xpu"]
         # intweight: k x n, zeros: k / group_size x n
         intweight, zeros = unpack_to_8bit_signed(self.qweight, self.qzeros, self.bits, self.ZP_BIAS)
         if zeros is None:
@@ -275,7 +275,7 @@ def post_init(self):
             zeros = (zeros.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8)
         else:
             zeros -= 2 ** (self.bits - 1)
-        if self.qweight.device.type != 'cpu':
+        if self.qweight.device.type != "cpu":
             assert not self.asym
         if not self.asym:
             intweight -= 2 ** (self.bits - 1)
@@ -285,21 +285,20 @@ def post_init(self):
         if self.asym:
             intweight = (intweight.to(torch.int32) - (2 ** (self.bits - 1))).to(torch.int8)

-
         logger.debug(
             f"ARK repack quantized weight: K:{intweight.shape[0]}, N:{intweight.shape[1]}, weight_dtype:{BITS_DTYPE_MAPPING[self.bits]}, scale_dtype:fp32, compute_dtype:fp32, group_size:{self.group_size}"
         )

-        if self.qweight.device.type == 'xpu':
-            self.sdt = 'fp16'
-            self.cdt = 'fp16'
+        if self.qweight.device.type == "xpu":
+            self.sdt = "fp16"
+            self.cdt = "fp16"
             scales = self.scales.to(torch.float16).contiguous()
         else:
-            self.sdt = 'fp32'
-            self.cdt = 'fp32'
+            self.sdt = "fp32"
+            self.cdt = "fp32"
             scales = self.scales.float().contiguous()
         self.wdt = BITS_DTYPE_MAPPING[self.bits]
-
+
         self.qweight = ark.repack_quantized_weight(
             intweight.contiguous(),
             scales,
@@ -311,11 +310,10 @@ def post_init(self):
             self.wdt,
             # scale_dtype
             self.sdt,
-
             self.asym,
             self.group_size,
         )
-
+
         # self.revert_wei = torch.zeros(self.infeatures, self.outfeatures, dtype=scales.dtype, device=self.qweight.device)
         # # print(packw, packw.device, packw.dtype)
         # ark.dequantize_packed_weight(
@@ -324,20 +322,20 @@
         self.qzeros = torch.empty(0)
         self.scales = torch.empty(0)
         if self.bias is not None:
-            if self.bias.device.type == 'cpu':
+            if self.bias.device.type == "cpu":
                 self.bias = self.bias.to(torch.float32)
             else:
                 self.bias = self.bias.to(torch.float16)

     def forward(self, x: torch.Tensor):
         raw_input_dtype = x.dtype
-        if x.device.type == 'cpu':
+        if x.device.type == "cpu":
             odt = torch.float32
             if raw_input_dtype != torch.float32:
                 x = x.to(torch.float32)
         else:
             odt = x.dtype
-
+
         out_shape = x.shape[:-1] + (self.outfeatures,)
         x = x.view(-1, x.shape[-1])  # convert xd to 2d
         out_2d_shape = x.shape[:-1] + (self.outfeatures,)
@@ -353,16 +351,18 @@ def forward(self, x: torch.Tensor):
             self.wdt,  # weight_dtype
             self.sdt,  # scale_dtype
             self.asym,
-            self.group_size
+            self.group_size,
         )
-        if x.device.type == 'xpu':
+        if x.device.type == "xpu":
             outputs = outputs + bias
         return outputs.to(raw_input_dtype).view(out_shape)

+
 class QuantLinearGPTQ(QuantLinear):
     QUANT_TYPE = "ark_gptq"
     ZP_BIAS = 1

+
 @torch.no_grad()
 def unpack_to_8bit_signed(qweight, qzeros, bits, gptq_bias=1):
     wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qweight.device).unsqueeze(0)
@@ -393,6 +393,7 @@ def unpack_to_8bit_signed(qweight, qzeros, bits, gptq_bias=1):

     return weight, zeros

+
 # Copied from qlinear_marlin.py
 @torch.no_grad()
 def dequantize_weight(qweight, qzeros, scales, bits):
@@ -402,16 +403,18 @@ def dequantize_weight(qweight, qzeros, scales, bits):
     if unpacked_qzeros is not None:
         unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
     else:
-        unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32, device = qweight.device)
+        unpacked_qzeros = torch.full_like(scales, 8 if bits == 4 else 128, dtype=torch.int32, device=qweight.device)
     unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales

     return unpacked_qweight, unpacked_qzeros

+
 def ark_post_init(model):
     for _, submodule in model.named_modules():
         if isinstance(submodule, QuantLinear):
             submodule.post_init()

     return model

-__all__ = ["QuantLinear", 'QuantLinearGPTQ', 'QuantLinearAWQ']
+
+__all__ = ["QuantLinear", "QuantLinearGPTQ", "QuantLinearAWQ"]
