Skip to content

Commit 7907c32

Browse files
authored
improve accuracy for 2bit with auto-round-best (#1078)
1 parent 61bcfdb commit 7907c32

File tree

3 files changed

+23
-11
lines changed

3 files changed

+23
-11
lines changed

auto_round/autoround.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class AutoRound:
4141
the quantization of LLMs." arXiv:2309.05516 (2023).
4242
4343
Attributes:
44-
model (torch.nn.Module): The loaded PyTorch model in eval mode.
44+
model (torch.nn.Module | str): The loaded PyTorch model in eval mode, or a string identifying the model to load.
4545
tokenizer: Tokenizer used to prepare input text for calibration/tuning.
4646
platform (str): The platform to load the pretrained model from, options: ["hf", "model_scope"]
4747
bits (int): Weight quantization bits.
@@ -85,6 +85,8 @@ def __new__(
8585
enable_adam: bool = False,
8686
# for MLLM and Diffusion
8787
extra_config: ExtraConfig = None,
88+
enable_alg_ext: bool = False,
89+
disable_opt_rtn: bool = False,
8890
low_cpu_mem_usage: bool = False,
8991
**kwargs,
9092
) -> BaseCompressor:

auto_round/compressors/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,13 @@ def __init__(
341341
if self.iters == 0:
342342
self.lr = 5e-3
343343
else:
344-
self.lr = lr or (1.0 / self.iters) # must place after iter setting
344+
if not lr:
345+
# TODO: check whether the lr setting for 3-bit and 4-bit also needs adjusting for auto-round-best
346+
self.lr = 2.0 / self.iters if (self.iters >= 1000 and self.bits == 2) else 1.0 / self.iters
347+
if self.iters >= 1000 and self.bits == 2:
348+
logger.info("set the lr to 2.0/iters for better accuracy")
349+
else:
350+
self.lr = lr
345351
self.minmax_lr = minmax_lr or self.lr
346352
self.enable_alg_ext = enable_alg_ext
347353
self.sampler = sampler
@@ -510,7 +516,9 @@ def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None:
510516
else:
511517
raise TypeError(f"device_map should be [str, torch.device, int, dict], but got {type(device_map)}")
512518

513-
def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kwargs) -> QuantizationScheme:
519+
def _parse_and_set_scheme(
520+
self, scheme: Union[str, dict, QuantizationScheme], kwargs
521+
) -> tuple[QuantizationScheme, bool]:
514522
"""Parse and set the quantization scheme."""
515523

516524
def _parse_and_set(scheme, kwargs):

docs/alg_202508.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ We use **lm-eval** for evaluation. For LLaMA, we enabled `add_bos_token` and
33
in [modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L52C1-L52C40)
44
to stabilize accuracy during evaluation. All other settings follow the default configurations of AutoRound and lm-eval.
55

6-
| Qwen3-8B W2G64 | Avg. | arc_challenge | hellaswag | gsm8k | lambada_openai | mmlu | mmlupro | truthfulqa_mc1 | winogrande |
7-
|:-------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
8-
| AutoRound | 0.4373 | 0.4019 | 0.4437 | 0.4215 | 0.4826 | 0.5474 | 0.2630 | 0.3072 | 0.6314 |
9-
| AutoRound+alg_ext | 0.4787 | 0.4275 | 0.4516 | 0.5944 | 0.5181 | 0.5773 | 0.2807 | 0.3305 | 0.6496 |
6+
| Qwen3-8B W2G64 | Avg. | arc_challenge | hellaswag | gsm8k | lambada_openai | mmlu | mmlupro | truthfulqa_mc1 | winogrande |
7+
|:------------------------------|:------:|:-------------:|:---------:|:------:|:--------------:|:------:|:-------:|:--------------:|:----------:|
8+
| AutoRound | 0.4373 | 0.4019 | 0.4437 | 0.4215 | 0.4826 | 0.5474 | 0.2630 | 0.3072 | 0.6314 |
9+
| AutoRound+alg_ext | 0.4787 | 0.4275 | 0.4516 | 0.5944 | 0.5181 | 0.5773 | 0.2807 | 0.3305 | 0.6496 |
10+
| AutoRoundBest+alg_ext lr 2e-3 | 0.4937 | 0.4505 | 0.4740 | 0.5906 | 0.5556 | 0.6028 | 0.3127 | 0.3109 | 0.6527 |
1011

11-
| Llama3.1-8B-Instruct W2G64 | Avg. | arc_challenge | hellaswag | gsm8k | lambada_openai | mmlu | mmlupro | truthfulqa_mc1 | winogrande |
12-
|:---------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
13-
| AutoRound | 0.3820 | 0.3635 | 0.4562 | 0.1622 | 0.5069 | 0.4411 | 0.1661 | 0.3207 | 0.6393 |
14-
| AutoRound+alg_ext | 0.4166 | 0.3712 | 0.4729 | 0.2039 | 0.5946 | 0.4981 | 0.2163 | 0.3011 | 0.6748 |
12+
| Llama3.1-8B-Instruct W2G64 | Avg. | arc_challenge | hellaswag | gsm8k | lambada_openai | mmlu | mmlupro | truthfulqa_mc1 | winogrande |
13+
|:------------------------------|:------:|:-------------:|:---------:|:------:|:--------------:|:------:|:-------:|:--------------:|:----------:|
14+
| AutoRound | 0.3820 | 0.3635 | 0.4562 | 0.1622 | 0.5069 | 0.4411 | 0.1661 | 0.3207 | 0.6393 |
15+
| AutoRound+alg_ext | 0.4166 | 0.3712 | 0.4729 | 0.2039 | 0.5946 | 0.4981 | 0.2163 | 0.3011 | 0.6748 |
16+
| AutoRoundBest+alg_ext lr 2e-3 | 0.4539 | 0.4138 | 0.4999 | 0.3071 | 0.6233 | 0.5279 | 0.2364 | 0.3231 | 0.6993 |

0 commit comments

Comments
 (0)