Commit e18f9b6

update alg_ext and add ut (#1064)
* update alg_ext and add ut

Signed-off-by: n1ck-guo <[email protected]>

1 parent b1b60e4 · commit e18f9b6

File tree

5 files changed · +51 −21 lines changed


auto_round/alg_ext.abi3.so

-16.1 KB
Binary file not shown.

auto_round/compressors/base.py

Lines changed: 8 additions & 6 deletions
@@ -311,18 +311,12 @@ def __init__(
         if device_map is None:
             device_map = 0

-        self.enable_torch_compile = enable_torch_compile
-        self._adjust_torch_compile(enable_torch_compile)
-
         self.device_map = device_map
         if isinstance(self.device_map, str):
             self.device_map = self.device_map.replace(" ", "")

         self.device_list = parse_available_devices(device_map)

-        if isinstance(scheme, AutoScheme):
-            self.layer_config = self._gen_auto_scheme(model, scheme, dataset, self.device_map)
-
         # Set device, must place after model loading
         self.device = get_major_device(device_map)
         set_non_auto_device_map(self.model, self.device_map)
@@ -387,10 +381,17 @@ def __init__(
         self.batch_dim = None
         self.infer_bs_coeff = 1

+        # after setting iters
+        self.enable_torch_compile = enable_torch_compile
+        self._adjust_torch_compile(enable_torch_compile)
+
         self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
         self._check_configs()
         torch.set_printoptions(precision=3, sci_mode=True)

+        if isinstance(scheme, AutoScheme):
+            self.layer_config = self._gen_auto_scheme(model, scheme, dataset, self.device_map)
+
         if is_hpex_available():
             logger.info("habana_frameworks is available, import htcore explicitly.")
             import habana_frameworks.torch.core as htcore  # pylint: disable=E0401
@@ -632,6 +633,7 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
             and not is_debug_mode()
             and "fp8" not in self.data_type
             and "fp8" not in self.act_data_type
+            and self.iters > 0
         ):
             logger.info(
                 "'enable_torch_compile' is set to `False` by default. "

auto_round/data_type/utils.py

Lines changed: 11 additions & 3 deletions
@@ -264,12 +264,20 @@ def _is_mlp_module(module: Module):
     # already fused/treated as one layer
     if hasattr(submodule, "qkv_proj"):
         return
+
+    q_global_scale = getattr(submodule.q_proj, global_scale_name, max_value_tensor)
+    q_global_scale = max_value_tensor if q_global_scale is None else q_global_scale
+    k_global_scale = getattr(submodule.k_proj, global_scale_name, max_value_tensor)
+    k_global_scale = max_value_tensor if k_global_scale is None else k_global_scale
+    v_global_scale = getattr(submodule.v_proj, global_scale_name, max_value_tensor)
+    v_global_scale = max_value_tensor if v_global_scale is None else v_global_scale
+
     global_scale = torch.min(
         torch.cat(
             (
-                getattr(submodule.q_proj, global_scale_name, max_value_tensor).reshape(1),
-                getattr(submodule.k_proj, global_scale_name, max_value_tensor).reshape(1),
-                getattr(submodule.v_proj, global_scale_name, max_value_tensor).reshape(1),
+                q_global_scale.reshape(1),
+                k_global_scale.reshape(1),
+                v_global_scale.reshape(1),
             )
         )
     ).reshape([1])
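
The new temporaries fix a None-handling gap: `getattr(obj, name, default)` returns the default only when the attribute is missing, not when it exists with the value `None`, so the old one-liners could call `.reshape(1)` on `None` and raise `AttributeError`. A minimal sketch of the failure mode and the fix; `Proj`, the attribute name, and the sentinel tensor are illustrative stand-ins for `submodule.q_proj`, `global_scale_name`, and `max_value_tensor`:

```python
import torch

class Proj:
    weight_global_scale = None  # attribute exists, but holds None

# Illustrative stand-in for the max_value_tensor sentinel in the diff.
max_value_tensor = torch.tensor([torch.finfo(torch.float32).max])

scale = getattr(Proj(), "weight_global_scale", max_value_tensor)
print(scale)  # None: the getattr default did NOT apply, so .reshape(1) would fail

# The fix adds an explicit None fallback before reshaping:
scale = max_value_tensor if scale is None else scale
print(scale.reshape(1))  # tensor([3.4028e+38])
```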

test/test_cpu/test_alg_ext.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import copy
+import shutil
+import sys
+import unittest
+
+from parameterized import parameterized
+
+sys.path.insert(0, "../..")
+
+from auto_round import AutoRound
+
+
+class TestAlgExt(unittest.TestCase):
+    def test_alg_ext(self):
+        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        ar = AutoRound(model_name, scheme="W2A16", iters=1, nsamples=1, enable_alg_ext=True)
+        ar.quantize()
+
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B"
+        ar = AutoRound(model_name, scheme="gguf:q4_k_s", iters=1, nsamples=1, enable_alg_ext=True)
+        ar.quantize()
+
+    def test_alg_ext_import(self):
+        from auto_round.alg_ext import wrapper_autoround
+
+    def test_all_support_dtype(self):
+        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        for scheme in ["MXFP4", "NVFP4", "W2A16G64"]:
+            ar = AutoRound(
+                model_name, scheme=scheme, iters=1, nsamples=1, enable_alg_ext=True, enable_torch_compile=True
+            )
+            ar.quantize()

test/test_cpu/test_autoround.py

Lines changed: 0 additions & 12 deletions
@@ -689,18 +689,6 @@ def test_mixed_bit_setting(self):
         ):
             raise ValueError("mixed bits is not correct")

-    def test_alg_ext(self):
-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
-        ar = AutoRound(model_name, scheme="W2A16", iters=1, nsamples=1, enable_alg_ext=True)
-        ar.quantize()
-
-        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B"
-        ar = AutoRound(model_name, scheme="gguf:q4_k_s", iters=1, nsamples=1, enable_alg_ext=True)
-        ar.quantize()
-
-    def test_alg_ext_import(self):
-        from auto_round.alg_ext import wrapper_autoround
-
     def test_invalid_layer_config(self):
         with self.assertRaises(ValueError):
             layer_config = {"model.decoder.layers.2.self_attnx": {"bits": 2}}
