Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion auto_round/autoround.py
Original file line number Diff line number Diff line change
Expand Up @@ -1916,6 +1916,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
"""
# Get the names of layers in quantization blocks
supported_types = self.supported_types
dynamic_config = {}
Copy link
Contributor

@wenhuach21 wenhuach21 Sep 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest avoiding the term 'dynamic config': it is not used in the academic literature (it was coined by Unsloth), and I personally find it a little confusing, since static activation quantization and dynamic activation quantization already exist as distinct, established concepts.

layers_in_blocks = get_layer_names_in_block(
self.model, supported_types, self.quant_block_list, self.inner_supported_types
)
Expand Down Expand Up @@ -1944,6 +1945,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
matched_names.append(layer_name)
if len(matched_names) > 0:
val = layer_config[name]
dynamic_config[name] = val # keep regex config
layer_config.pop(name)
for match_name in matched_names:
layer_config[match_name] = val
Expand Down Expand Up @@ -2033,7 +2035,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
need_to_quantize_lm_head = self._check_need_to_quantize_lm_head_embedding()
if need_to_quantize_lm_head:
has_qlayer_outside_block = True

self.dynamic_config = dynamic_config
# Return whether there are quantized layers outside the blocks
return has_qlayer_outside_block

Expand Down Expand Up @@ -3125,6 +3127,7 @@ def save_quantized(
"act_data_type",
"super_bits",
"super_group_size",
"dynamic_config",
]
if isinstance(self.dataset, str):
serialization_keys.append("dataset")
Expand Down
28 changes: 27 additions & 1 deletion auto_round/export/export_to_autogptq/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict

import threadpoolctl as tctl

Expand Down Expand Up @@ -57,6 +58,7 @@
get_block_names,
get_module,
set_module,
to_standard_regex,
)

BLOCK_PATTERNS = [ ## copy from transformers optimum
Expand All @@ -67,6 +69,29 @@
]


def convert_to_autogptq_dynamic(dynamic_config: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """
    Convert AutoRound-style dynamic_config into AutoGPTQ-style QuantizeConfig.dynamic.

    Rules:
    - bits < 16  -> quantize the matched layers -> positive match key ``"+:regex"``
    - bits >= 16 -> skip quantization           -> negative match key ``"-:regex"`` (empty override)
    - entries without a ``"bits"`` field are ignored

    Args:
        dynamic_config: mapping of layer-name patterns to per-layer override dicts.

    Returns:
        Dict keyed by ``"+:regex"`` / ``"-:regex"`` patterns as consumed by
        AutoGPTQ/GPTQModel dynamic matching.
    """
    converted: Dict[str, Dict[str, Any]] = {}
    for name, cfg in dynamic_config.items():
        bits = cfg.get("bits")
        if bits is None:
            continue  # ignore invalid entries
        regex = to_standard_regex(name)
        if bits < 16:
            # NOTE(review): keys were previously emitted as "r'+:...'", embedding
            # Python raw-string syntax into the runtime key; AutoGPTQ matches on
            # plain "+:regex" / "-:regex" strings.
            converted[f"+:{regex}"] = dict(cfg)
        else:
            # bits >= 16 means "leave in full precision": negative match, no overrides
            converted[f"-:{regex}"] = {}
    return converted


def pack_layer(name, model, backend, device=None):
if name == "lm_head": ##dese not support lm-head
return
Expand Down Expand Up @@ -155,7 +180,8 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
logger.error("auto-gptq format may not support loading this quantized model")
quantization_config["block_name_to_quantize"] = common_prefix
quantization_config.pop("to_quant_block_names", None)

dynamic_config = quantization_config.pop("dynamic_config")
quantization_config["dynamic"] = convert_to_autogptq_dynamic(dynamic_config)
## as layers maybe already packed, we need to check in layer_config
layer_config = kwargs["layer_config"]
for n, m in model.named_modules():
Expand Down
11 changes: 11 additions & 0 deletions auto_round/export/export_to_autoround/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_nv_fp,
is_standard_fp,
set_module,
to_standard_regex,
)


Expand Down Expand Up @@ -316,6 +317,10 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
extra_config[layer_name]["act_bits"] = layer_config[layer_name]["act_bits"]
extra_config[layer_name]["act_data_type"] = layer_config[layer_name]["act_data_type"]
extra_config[layer_name]["act_group_size"] = layer_config[layer_name]["act_group_size"]
extra_config[layer_name]["act_sym"] = layer_config[layer_name]["act_sym"]
elif layer_config[layer_name]["in_blocks"] or (
block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
):
Expand All @@ -327,6 +332,12 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
for key in neq_keys:
if layer_config[layer_name][key] is not None:
extra_config[layer_name][key] = layer_config[layer_name][key]

dynamic_config = quantization_config.pop("dynamic_config")
if name in dynamic_config.keys():
regex_name = to_standard_regex(name)
extra_config[regex_name] = dynamic_config[name]

if len(extra_config) > 0:
quantization_config["extra_config"] = extra_config
names = list(layer_config.keys())
Expand Down
8 changes: 2 additions & 6 deletions auto_round/export/export_to_llmcompressor/export_to_fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
check_to_quantized,
copy_python_files_from_model_cache,
filter_quantization_config,
generate_ignore_regex_list,
get_block_names,
get_module,
is_mx_fp,
Expand Down Expand Up @@ -198,12 +199,7 @@ def wrapper(name):
for _ in executor.map(wrapper, names):
pass

# TODO fix the ignore re match issue, compile with fp8 & int8 config
ignore = ["lm_head"]
for layer_name in layer_config:
if layer_config[layer_name]["bits"] > 8: ## find ignore layers
ignore.append(layer_name)
ignore = list(set(ignore))
# ignore = generate_ignore_regex_list() ## check

# get llm-compressor format config
check_compressed_tensors_supported()
Expand Down
52 changes: 52 additions & 0 deletions auto_round/export/export_to_llmcompressor/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List

from auto_round.utils import matches_any_regex, to_standard_regex


def generate_ignore_regex_list(dynamic_config: Dict[str, Dict], layer_config: Dict[str, Dict]) -> List[str]:
    """
    Generate the ignore regex list for llm_compressor from dynamic_config and layer_config.

    Rules:
    1. Any dynamic_config entry with bits > 8 is ignored (left unquantized).
    2. Any layer_config entry with bits > 8 is ignored, unless it is already
       covered by a regex produced in step 1.
    3. Regex patterns are normalized to llm_compressor's ``'re:...'`` style.

    Args:
        dynamic_config (Dict[str, Dict]): dynamic quantization config
        layer_config (Dict[str, Dict]): layer-wise quantization config

    Returns:
        List[str]: regex patterns / layer names to ignore during quantization.
    """
    prefix = "re:"
    ignore_regex: List[str] = []

    # Step 1: pattern entries whose bits exceed 8 are not quantized -> ignore.
    # Guard against entries with no "bits" key (None is not comparable to int).
    for key, cfg in dynamic_config.items():
        bits = cfg.get("bits")
        if bits is not None and bits > 8:
            ignore_regex.append(prefix + to_standard_regex(key))

    # Step 2: explicit layers whose bits exceed 8, if not already matched by a
    # step-1 regex. (The bits check was previously computed but never applied,
    # which ignored every unmatched layer regardless of its precision.)
    for key, cfg in layer_config.items():
        bits = cfg.get("bits")
        if bits is not None and bits > 8 and not matches_any_regex(key, ignore_regex, prefix):
            ignore_regex.append(key)

    return ignore_regex
71 changes: 70 additions & 1 deletion auto_round/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Tuple, Union
from typing import Any, Callable, List, Tuple, Union

import cpuinfo
import torch
Expand Down Expand Up @@ -2687,3 +2687,72 @@ def copy_python_files_from_model_cache(model, save_path: str):
if file.endswith(".py") and os.path.isfile(full_file_name):
logger.debug(f"Transferring {full_file_name} to {save_path}")
shutil.copy(full_file_name, save_path)


def to_standard_regex(pattern: str) -> str:
    """
    Convert a user-specified string into a standardized regex for layer matching.

    Rules:
    - If the pattern already contains regex tokens ('.*', '^', '$', escapes, etc.),
      keep them as-is (bare dots are still escaped; existing '\\X' pairs are preserved).
    - Otherwise, wrap the pattern with `.*` on both sides to allow substring matching.
    - Always ensure the returned regex is valid (compilable by re).

    Raises:
        ValueError: if the generated regex does not compile.

    Examples:
        >>> to_standard_regex("model.embed_tokens")
        '.*model\\.embed_tokens.*'
        >>> to_standard_regex("mlp.gate")
        '.*mlp\\.gate.*'
        >>> to_standard_regex("mlp.gate$")
        '.*mlp\\.gate$'
        >>> to_standard_regex("mlp.*gate")
        '.*mlp.*gate.*'
    """
    # Heuristic: if the pattern contains regex meta characters (including a
    # backslash, i.e. an explicit escape such as "\."), assume partial regex.
    meta_chars = {".*", "^", "$", "|", "(", ")", "[", "]", "?", "+", "\\"}
    has_regex = any(tok in pattern for tok in meta_chars)
    if not has_regex:
        # Escape literal dots, etc., and wrap with .* for substring matching
        regex = f".*{re.escape(pattern)}.*"
    else:
        # Escape bare dots only; keep ".*" tokens and pre-escaped "\X" pairs
        # intact (previously "\." was re-escaped into a literal-backslash match).
        tmp = []
        i = 0
        while i < len(pattern):
            ch = pattern[i]
            if ch == "\\" and i + 1 < len(pattern):
                tmp.append(pattern[i : i + 2])  # preserve user escape pair verbatim
                i += 2
                continue
            if ch == ".":
                if i + 1 < len(pattern) and pattern[i + 1] == "*":
                    tmp.append(".*")  # keep regex token
                    i += 2
                    continue
                tmp.append("\\.")  # escape bare dot
            else:
                tmp.append(ch)
            i += 1
        regex = "".join(tmp)
        # If no anchors are provided, allow substring matching
        if not regex.startswith("^") and not regex.startswith(".*"):
            regex = ".*" + regex
        if not regex.endswith("$") and not regex.endswith(".*"):
            regex = regex + ".*"
    # Validate regex
    try:
        re.compile(regex)
    except re.error as e:
        raise ValueError(f"Invalid regex generated from pattern '{pattern}': {e}") from e
    return regex


def matches_any_regex(layer_name: str, regex_list: List[str], prefix="re:") -> bool:
    """Return True if *layer_name* fully matches any pattern in *regex_list*.

    Each pattern may carry an optional *prefix* (default ``"re:"``), which is
    stripped before matching with :func:`re.fullmatch`.
    """
    return any(
        re.fullmatch(entry.removeprefix(prefix), layer_name) is not None
        for entry in regex_list
    )
Loading