Commit 212b855

modified qbits to ark
1 parent e50f850 commit 212b855

File tree

auto_round/inference/backend.py
auto_round/inference/convert_model.py
test/test_cpu/test_torch_backend.py

3 files changed (+47, -56 lines)

auto_round/inference/backend.py

Lines changed: 35 additions & 39 deletions
@@ -47,11 +47,11 @@ class BackendInfo:
 
     Attributes:
         device: A list of strings representing the devices the backend supports
-            (e.g., 'cuda', 'cpu').
+            (e.g., 'cpu', 'xpu', 'cuda').
         sym: A list of booleans indicating whether the backend supports symmetric
             quantization for weights (True if symmetric, False if not).
         packing_format: A list of strings representing the packing formats used by the backend
-            (e.g., 'triton', 'qbits').
+            (e.g., 'ark', 'triton').
         bits: A list of integers specifying the bit-widths supported by the backend
             for weight quantization (e.g., [2, 4, 8]).
         group_size: An optional list of integers specifying the group sizes supported
@@ -430,51 +430,51 @@ def fp8_static_scheme_checker(
     requirements=["autoawq", "transformers<4.57.0"],
 )
 
-BackendInfos["qbits"] = BackendInfo(
-    device=["cpu"],
-    sym=[True, False],
+BackendInfos["auto_round_kernel"] = BackendInfo(
+    device=["cpu", "xpu"],
+    sym=[True],
     packing_format=GPTQ_FORMAT_NO_ZP,
     bits=[2, 4, 8],
     group_size=None,
-    priority=1,
+    priority=0,
     checkers=[],
-    alias=["itrex", "qbits"],
-    compute_dtype=["float16", "bfloat16"],
+    alias=["ark"],
+    compute_dtype=["float32", "float16"],
     data_type=["int"],
     act_bits=WOQ_DEFAULT_ACT_BITS,
-    requirements=["torch<2.7.0", "intel-extension-for-transformers"],
+    requirements=["torch>=2.9.0"],
 )
 
-BackendInfos["qbits_zp"] = BackendInfo(
-    device=["cpu"],
-    sym=[True, False],
+BackendInfos["auto_round_kernel_zp"] = BackendInfo(
+    device=["cpu", "xpu"],
+    sym=[True],
     packing_format=GPTQ_FORMAT,
     bits=[2, 4, 8],
     group_size=None,
-    compute_dtype=["float16", "bfloat16"],
+    priority=0,
+    checkers=[],
+    alias=["ark"],
+    compute_dtype=["float32", "float16"],
     data_type=["int"],
     act_bits=WOQ_DEFAULT_ACT_BITS,
-    priority=1,
-    checkers=[],
-    alias=["itrex", "qbits"],
-    requirements=["torch<2.7.0", "intel-extension-for-transformers"],
+    requirements=["torch>=2.9.0"],
 )
 
-
-BackendInfos["qbits_awq"] = BackendInfo(
+BackendInfos["auto_round_kernel_awq"] = BackendInfo(
     device=["cpu"],
     sym=[True, False],
     packing_format=AWQ_FORMAT,
     bits=[2, 4, 8],
     group_size=None,
-    compute_dtype=["float16", "bfloat16"],
+    priority=0,
+    checkers=[],
+    alias=["ark"],
+    compute_dtype=["float32", "float16"],
     data_type=["int"],
     act_bits=WOQ_DEFAULT_ACT_BITS,
-    priority=1,
-    checkers=[],
-    alias=["itrex", "qbits"],
-    requirements=["torch<2.7.0", "intel-extension-for-transformers"],
+    requirements=["torch>=2.9.0"],
 )
+
 BackendInfos["ipex_gptq"] = BackendInfo(
     device=["cpu", "xpu"],
     sym=[True, False],
@@ -601,12 +601,12 @@ def dynamic_import_inference_linear(backend, config):
     """Dynamically imports and returns the appropriate QuantLinear class based on the given backend.
 
     This function dynamically loads the correct `QuantLinear` class based on the backend and quantization
-    configuration (e.g., qbits, marlin, hpu, gptq, awq, auto_round). It imports specific modules or raises
+    configuration (e.g., ark, marlin, hpu, gptq, awq). It imports specific modules or raises
     errors if the required packages are not installed or the environment is not set up.
 
     Args:
         backend (str):
-            The backend to be used for quantization (e.g., 'qbits', 'marlin', 'hpu', 'gptq', 'awq', 'auto_round').
+            The backend to be used for quantization (e.g., 'ark', 'marlin', 'hpu', 'gptq', 'awq').
         config (QuantizationScheme):
             The quantization configuration containing parameters like bits, group_size, and sym.
 
@@ -616,7 +616,7 @@ def dynamic_import_inference_linear(backend, config):
 
     Raises:
         ImportError:
-            If required modules are missing for a backend (e.g., Intel Extension, GPTQ, auto_awq).
+            If required modules are missing for a backend (e.g., ark, GPTQ, auto_awq).
     """
     bits, group_size, sym = config["bits"], config["group_size"], config["sym"]
 
@@ -629,26 +629,22 @@ def dynamic_import_inference_linear(backend, config):
     if "torch_nvfp4" in backend:
         return ar_qmodules.NVFP4QuantLinear
 
-    if "qbits" in backend:
+    if "auto_round_kernel" in backend or 'ark' in backend:
         try:
-            from intel_extension_for_transformers import qbits  # pylint: disable=E0401
+            import auto_round_kernel as ark  # pylint: disable=E0401
         except Exception as e:
             raise ImportError(
-                "Please install Intel Extension for Transformers via 'pip install "
-                "intel-extension-for-transformers' to inference on X86 CPU"
+                "Please install auto_round_kernel version for CPU/XPU"
             )
+        import auto_round_extension.kernel.qlinear as qlinear
+
         if "zp" in backend:
-            import auto_round_extension.qbits.qlinear_qbits_gptq as qlinear_qbits_gptq
-
-            return qlinear_qbits_gptq.QuantLinear
+            return qlinear.QuantLinearGPTQ
         elif "awq" in backend:
-            import auto_round_extension.qbits.qbits_awq as qlinear_qbits_awq
-
-            return qlinear_qbits_awq.QuantLinear
+            return qlinear.QuantLinearAWQ
         else:  # auto_round must be at the end
-            import auto_round_extension.qbits.qlinear_qbits as qlinear_qbits_autoround
+            return qlinear.QuantLinear
 
-            return qlinear_qbits_autoround.QuantLinear
     if "ipex_gptq" in backend:
         from auto_round_extension.ipex.qlinear_ipex_gptq import QuantLinear

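For context on how the renamed backend is reached at inference time: the registration above exposes the alias "ark", and dynamic_import_inference_linear resolves it to the kernel QuantLinear classes. Below is a minimal sketch of loading a quantized checkpoint through this path; it assumes the "ark" alias is accepted by AutoRoundConfig's backend argument, that auto_round_kernel is installed for CPU/XPU, and the checkpoint path is hypothetical.

# Minimal sketch (assumptions: the "ark" alias added above is wired into
# AutoRoundConfig's backend selection and auto_round_kernel is installed).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRoundConfig

quantized_model_path = "./saved"  # hypothetical AutoRound-quantized checkpoint

# Request the CPU/XPU kernel backend by its alias; backend.py matches it against
# the BackendInfos registered above and imports the matching QuantLinear class.
quantization_config = AutoRoundConfig(backend="ark")

model = AutoModelForCausalLM.from_pretrained(
    quantized_model_path,
    dtype=torch.float16,
    device_map="cpu",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
inputs = tokenizer("There is a girl who likes adventure,", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
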
auto_round/inference/convert_model.py

Lines changed: 4 additions & 9 deletions
@@ -31,7 +31,6 @@
 from auto_round.inference.utils import _expand_regex_config
 from auto_round.logger import logger
 from auto_round.schemes import QuantizationScheme
-from auto_round.special_model_handler import _handle_moe_model
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
     check_start_with_block_name,
@@ -395,9 +394,9 @@ def _create_quant_layer(layer, layer_backend, config, in_features, out_features)
     bias = layer.bias is not None
 
     # Special handling for AWQ layers
-    from auto_round_extension.qbits.qbits_awq import QuantLinear as QBitsAWQQuantLinear
+    from auto_round_extension.ark.qlinear import QuantLinearAWQ
 
-    if "awq" in layer_backend and isinstance(QuantLinear, QBitsAWQQuantLinear):
+    if "awq" in layer_backend and isinstance(QuantLinear, QuantLinearAWQ):
         return QuantLinear.from_linear(
             layer, config["bits"], config["group_size"], init_only=True, has_zero_points=not config["sym"]
         )
@@ -474,7 +473,6 @@ def post_init(model: torch.nn.Module, used_backends: list[str]) -> None:
     need_gptqmodel_init = False
     need_ipex_itrex_init = False
     used_gptq_exllamav2 = False
-
     # Determine which backends require post-init
     for backend in used_backends:
         if backend.startswith("auto_gptq"):
@@ -483,7 +481,7 @@ def post_init(model: torch.nn.Module, used_backends: list[str]) -> None:
             used_gptq_exllamav2 = True
         elif backend.startswith("gptqmodel"):
             need_gptqmodel_init = True
-        elif backend.startswith(("ipex", "qbit")):
+        elif backend.startswith(("ipex", "auto_round_kernel")):
             need_ipex_itrex_init = True
 
     # AutoGPTQ post-init
@@ -503,7 +501,7 @@ def post_init(model: torch.nn.Module, used_backends: list[str]) -> None:
         message = "repacking to CPU/XPU format"
         layers = []  ## ipex post_init will add one more layer
         for n, m in model.named_modules():
-            if hasattr(m, "QUANT_TYPE") and ("qbits" in m.QUANT_TYPE or "ipex" in m.QUANT_TYPE):
+            if hasattr(m, "QUANT_TYPE") and ("ark" in m.QUANT_TYPE or "ipex" in m.QUANT_TYPE):
                 layers.append(m)
 
         for layer in tqdm(layers, desc=message, total=len(layers), leave=True):
@@ -583,9 +581,6 @@ def convert_hf_model(model: nn.Module, target_device: str = "cpu") -> tuple[nn.M
     elif packing_format == "auto_round:gptq":
         packing_format = "auto_round:auto_gptq"
 
-    # Preprocess model before replace layers
-    model = _handle_moe_model(model)
-
     # Replace layers with quantized versions
     layer_configs = get_layer_config(model, quantization_config)
     used_backends = _replace_by_quant_layers(model, layer_configs, backend, target_device, packing_format)

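For reference, a compact illustration of the QUANT_TYPE-based selection that post_init now applies to the CPU/XPU kernels ("ark" or "ipex"). This is a standalone sketch only: DummyQuantLinear is a hypothetical stand-in for a real quantized layer class, and the actual repacking step is elided.

# Sketch of the post_init selection above: collect modules whose QUANT_TYPE marks
# them as CPU/XPU kernel layers, then walk them for repacking.
import torch
from tqdm import tqdm


class DummyQuantLinear(torch.nn.Module):
    QUANT_TYPE = "ark"  # tag checked by post_init after the rename from "qbits"


model = torch.nn.Sequential(DummyQuantLinear(), torch.nn.ReLU(), DummyQuantLinear())

message = "repacking to CPU/XPU format"
layers = [
    m
    for _, m in model.named_modules()
    if hasattr(m, "QUANT_TYPE") and ("ark" in m.QUANT_TYPE or "ipex" in m.QUANT_TYPE)
]

for layer in tqdm(layers, desc=message, total=len(layers), leave=True):
    pass  # the real code repacks each collected layer's weights here
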
test/test_cpu/test_torch_backend.py

Lines changed: 8 additions & 8 deletions
@@ -27,7 +27,7 @@ class TestAutoRoundTorchBackend(unittest.TestCase):
 
     @classmethod
     def setUpClass(self):
-        self.model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        self.model_name = "facebook/opt-125m"
         self.save_folder = "./saved"
         self.llm_dataloader = LLMDataLoader()
 
@@ -63,7 +63,7 @@ def tearDownClass(self):
         shutil.rmtree("runs", ignore_errors=True)
 
     def test_torch_4bits_asym(self):
-        model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype="auto", trust_remote_code=True)
         tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
         bits, group_size, sym = 4, 128, False
         autoround = AutoRound(
@@ -81,7 +81,7 @@ def test_torch_4bits_asym(self):
 
         quantization_config = AutoRoundConfig(backend="torch")
         model = AutoModelForCausalLM.from_pretrained(
-            quantized_model_path, torch_dtype=torch.float16, device_map="cpu", quantization_config=quantization_config
+            quantized_model_path, dtype=torch.float16, device_map="cpu", quantization_config=quantization_config
         )
 
         tokenizer = AutoTokenizer.from_pretrained(self.save_folder)
@@ -92,7 +92,7 @@ def test_torch_4bits_asym(self):
         torch.cuda.empty_cache()
 
         model = AutoModelForCausalLM.from_pretrained(
-            self.save_folder, torch_dtype=torch.bfloat16, device_map="cpu", quantization_config=quantization_config
+            self.save_folder, dtype=torch.bfloat16, device_map="cpu", quantization_config=quantization_config
         )
 
         tokenizer = AutoTokenizer.from_pretrained(self.save_folder)
@@ -104,9 +104,9 @@ def test_torch_4bits_asym(self):
         shutil.rmtree("./saved", ignore_errors=True)
 
     def test_torch_4bits_sym(self):
-        model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(self.model_name, dtype="auto", trust_remote_code=True)
         tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
-        bits, group_size, sym = 4, 128, True
+        bits, group_size, sym = 4, 32, True
         autoround = AutoRound(
             model,
             tokenizer,
@@ -122,12 +122,12 @@ def test_torch_4bits_sym(self):
 
         quantization_config = AutoRoundConfig(backend="torch")
         model = AutoModelForCausalLM.from_pretrained(
-            quantized_model_path, torch_dtype=torch.float16, device_map="cpu", quantization_config=quantization_config
+            quantized_model_path, dtype=torch.float16, device_map="auto", quantization_config=quantization_config
         )
 
         tokenizer = AutoTokenizer.from_pretrained(self.save_folder)
         self.model_infer(model, tokenizer)
-        result = simple_evaluate_user_model(model, tokenizer, batch_size=16, tasks="lambada_openai", limit=10)
+        result = simple_evaluate_user_model(model, tokenizer, batch_size=32, tasks="lambada_openai", limit=1000)
         print(result["results"]["lambada_openai"]["acc,none"])
         self.assertGreater(result["results"]["lambada_openai"]["acc,none"], 0.28)
         torch.cuda.empty_cache()

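To exercise the updated CPU backend tests locally, one option is the sketch below; it assumes pytest is installed and the command is run from the repository root.

# Run only the updated CPU torch-backend tests (assumption: pytest is available
# and the working directory is the repository root).
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pytest", "-q", "test/test_cpu/test_torch_backend.py"],
    check=True,
)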