Changes from all commits (44 commits)
6ffcf60  try to enable auto_scheme API (wenhuach21, Sep 25, 2025)
5d80825  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 25, 2025)
a4ef495  update a little (wenhuach21, Sep 25, 2025)
4173c3e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 25, 2025)
87e9454  update a little (wenhuach21, Sep 25, 2025)
f86eedb  Merge branch 'main' into auto_scheme (wenhuach21, Sep 25, 2025)
242d1ee  try to refine parse layer config code (wenhuach21, Sep 25, 2025)
4fc6b64  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 25, 2025)
63de904  Merge branch 'main' into auto_scheme (wenhuach21, Sep 26, 2025)
bb4d4ca  Merge branch 'main' into auto_scheme (wenhuach21, Sep 26, 2025)
7f76db2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 26, 2025)
ae8837b  fix (wenhuach21, Sep 26, 2025)
44ca92d  Merge branch 'auto_scheme' of https://github.com/intel/auto-round int… (wenhuach21, Sep 26, 2025)
531224d  fix (wenhuach21, Sep 26, 2025)
c9fa408  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 26, 2025)
6453200  fix (wenhuach21, Sep 26, 2025)
5b2dd60  Merge branch 'auto_scheme' of https://github.com/intel/auto-round int… (wenhuach21, Sep 26, 2025)
3811010  tmp_change (wenhuach21, Sep 26, 2025)
4de7b08  commit (wenhuach21, Sep 26, 2025)
a9f0e44  commit (wenhuach21, Sep 26, 2025)
59a9f5d  update a little (wenhuach21, Sep 26, 2025)
1b7e911  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 26, 2025)
e068049  fix (wenhuach21, Sep 26, 2025)
1b84bf2  Merge branch 'auto_scheme' of https://github.com/intel/auto-round int… (wenhuach21, Sep 26, 2025)
0357c0b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 26, 2025)
7c034bd  Merge branch 'main' into auto_scheme (wenhuach21, Sep 26, 2025)
602421c  merge autoscheme to scheme (wenhuach21, Sep 26, 2025)
091c5ad  refine layer_config code (wenhuach21, Sep 29, 2025)
90b6fa1  Merge branch 'main' into auto_scheme (wenhuach21, Sep 29, 2025)
f027801  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 29, 2025)
c6b78c6  tiny change (wenhuach21, Sep 29, 2025)
1b9f24e  tiny fix (wenhuach21, Sep 29, 2025)
2c0075a  tmp change (wenhuach21, Sep 29, 2025)
97198f0  tmp change (wenhuach21, Sep 29, 2025)
27b4b4d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 29, 2025)
2d3095a  update (wenhuach21, Sep 29, 2025)
35a298b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 29, 2025)
4a594cd  fix (wenhuach21, Sep 29, 2025)
dcd08d6  fix uts, still one left (wenhuach21, Sep 30, 2025)
9172264  fix gguf issue (wenhuach21, Sep 30, 2025)
1d9e593  Merge branch 'main' into auto_scheme (wenhuach21, Sep 30, 2025)
f98092c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 30, 2025)
033d1f6  update a little (wenhuach21, Sep 30, 2025)
8ae1dfa  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 30, 2025)
auto_round/__init__.py (3 changes: 1 addition & 2 deletions)

@@ -13,9 +13,8 @@
 # limitations under the License.
 from auto_round.autoround import AutoRound

-# support for old api
 from auto_round.autoround import AutoRoundLLM, AutoRoundMLLM, AutoRoundAdam
-from auto_round.schemes import QuantizationScheme
+from auto_round.schemes import QuantizationScheme, AutoScheme
 from auto_round.utils import LazyImport

auto_round/__main__.py (11 changes: 10 additions & 1 deletion)

@@ -110,7 +110,7 @@ def __init__(self, *args, **kwargs):

         self.add_argument(
             "--scale_dtype",
-            default="fp16",
+            default=None,
             choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"],
             help="scale data type to use for quantization",
         )
@@ -470,6 +470,14 @@ def tune(args):
     extra_config.scheme_config = scheme_config
     extra_config.mllm_config = mllm_config

+    layer_config = {}
+    # from auto_round.auto_schemes.haha import get_mixed_config_layer_config
+    # layer_config = {}
+    # best_path = get_mixed_config_layer_config(model_name, target_bits=3)
+    # for item in best_path:
+    #     layer_config[item[0]] = {}
+    #     layer_config[item[0]]["bits"] = item[1]
+
     autoround: BaseCompressor = AutoRound(
         model=model_name,
         scheme=scheme,
@@ -486,6 +494,7 @@ def tune(args):
         not_use_best_mse=args.not_use_best_mse,
         enable_adam=args.adam,
         extra_config=extra_config,
+        layer_config=layer_config,
     )

     model_name = args.model.rstrip("/")
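The commented-out scaffolding in tune() hints at the intended flow: an auto-scheme search produces per-layer bit assignments, which are folded into layer_config and handed to AutoRound. A minimal sketch of that shape, assuming get_mixed_config_layer_config returns (layer_name, bits) pairs as the comments suggest; the helper, layer names, and bit widths below are illustrative, not part of the merged code:

    # Hypothetical: mirrors the commented-out scaffolding above.
    best_path = [
        ("model.layers.0.self_attn.q_proj", 4),
        ("model.layers.0.mlp.gate_proj", 2),
    ]
    layer_config = {name: {"bits": bits} for name, bits in best_path}
    # autoround = AutoRound(model=model_name, scheme=scheme, layer_config=layer_config, ...)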
auto_round/auto_schemes/__init__.py (new file, 39 additions)

# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

AUTO_SCHEMES_ALGS = {}


def register_dtype(names):
    """Decorator that registers a mixed-precision algorithm in AUTO_SCHEMES_ALGS.

    Args:
        names: A string, or a tuple/list of strings, to register the
            algorithm under.

    Returns:
        The inner decorator, which stores the algorithm in the registry and
        returns it unchanged.
    """

    def register(alg):
        if isinstance(names, (tuple, list)):
            for name in names:
                AUTO_SCHEMES_ALGS[name] = alg
        else:
            AUTO_SCHEMES_ALGS[names] = alg

        return alg

    return register
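The registry itself ships empty: no algorithm is registered in this PR. A minimal sketch of how register_dtype is evidently meant to be used, with a hypothetical algorithm name and signature:

    from auto_round.auto_schemes import AUTO_SCHEMES_ALGS, register_dtype

    # Hypothetical algorithm; only the registration mechanics are real here.
    @register_dtype(("default", "mixed_bits"))
    def assign_bits(model, target_bits):
        ...

    assert AUTO_SCHEMES_ALGS["default"] is assign_bits
    assert AUTO_SCHEMES_ALGS["mixed_bits"] is assign_bits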
auto_round/auto_schemes/gen_scheme.py (new file, 84 additions)

# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable

import torch

from auto_round import AutoScheme
from auto_round.utils import get_layer_features


class GenScheme:
    def __init__(
        self,
        auto_scheme: AutoScheme,
        model: torch.nn.Module,
        quant_layer_names: Iterable[str],
        fixed_layer_scheme: dict[str, dict],
        scale_dtype: str = "fp16",
        dataset: str = "pile-10k",
    ):
        self.auto_scheme = auto_scheme
        self.model = model
        self.quant_layer_names = quant_layer_names
        self.fixed_layer_scheme = fixed_layer_scheme
        self.scale_dtype = scale_dtype
        self.dataset = dataset

    def _get_min_max_avg_bits(self) -> tuple[float, float]:
        pass

    # not yet validated
    def get_layer_bits(self, layer):
        """Return (total_bits, avg_bits) for a layer, counting the quantized
        weights plus metadata (scales, zero points, double-quant parameters)."""
        weight = layer.weight
        n_param = weight.numel()
        weight_bits = getattr(layer, "bits", 16)
        group_size = getattr(layer, "group_size", 128)
        super_group_size = getattr(layer, "super_group_size", None)
        super_weight_bits = getattr(layer, "super_bits", None)

        # Main quantization cost
        weight_total_bits = weight_bits * n_param
        if weight_bits >= 16:  # unquantized layer
            return weight_total_bits, 16

        in_features, output_features = get_layer_features(layer)
        # Determine the number of quantization groups
        if group_size > 0:  # group-wise: ceil(in_features / group_size) groups per output channel
            n_group = output_features * ((in_features + group_size - 1) // group_size)
        elif group_size == 0:  # per-tensor
            n_group = 1
        elif group_size == -1:  # per-channel
            n_group = output_features  # out_channels
        else:
            raise ValueError(f"Invalid group_size {group_size}")

        aux_total_bits = 0
        if not super_group_size:
            # Scale and zero-point bit widths
            scale_bits = 16
            zp_bits = weight_bits  # default: zero point stored at the weight bit width
            # Overhead from scales and zero points
            aux_total_bits = n_group * (scale_bits + zp_bits)

        # Double quantization case
        if super_group_size:
            aux_total_bits += n_group * super_weight_bits * 2  # quantized scale and min per group
            # Number of super-groups
            n_super_group = (n_group + super_group_size - 1) // super_group_size
            aux_total_bits += n_super_group * 32 * 2  # fp32 double-quant scale and min_v

        total_bits = weight_total_bits + aux_total_bits
        avg_bits = total_bits / n_param
        return total_bits, avg_bits
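As a sanity check on this accounting, the same arithmetic applied by hand to a hypothetical 4096x4096 linear layer quantized to 4 bits with group_size=128 and no double quantization:

    # Standalone recomputation of get_layer_bits for an illustrative layer.
    in_features, out_features = 4096, 4096
    weight_bits, group_size = 4, 128

    n_param = in_features * out_features                 # 16,777,216 weights
    weight_total_bits = weight_bits * n_param            # 67,108,864 bits
    # One fp16 scale plus one int4 zero point per group of 128 weights:
    n_group = out_features * ((in_features + group_size - 1) // group_size)  # 131,072 groups
    aux_total_bits = n_group * (16 + weight_bits)        # 2,621,440 bits

    avg_bits = (weight_total_bits + aux_total_bits) / n_param
    print(avg_bits)  # 4.15625: scales and zero points add ~0.16 bits per weight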
auto_round/auto_schemes/utils.py (new file, 21 additions)

# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Placeholder stubs for the bit-accounting helpers; not yet implemented.
def get_total_bits(model, layer_config):
    pass


def get_bits(layer):
    pass
auto_round/autoround.py (5 changes: 2 additions & 3 deletions)

@@ -25,7 +25,7 @@
     MLLMCompressor,
 )
 from auto_round.logger import deprecated, logger
-from auto_round.schemes import QuantizationScheme
+from auto_round.schemes import AutoScheme, QuantizationScheme
 from auto_round.utils import is_mllm_model

@@ -63,7 +63,7 @@ def __new__(
         cls,
         model: Union[torch.nn.Module, str],
         tokenizer=None,
-        scheme: Union[str, dict, QuantizationScheme] = "W4A16",
+        scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16",
         layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None,
         dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
         iters: int = 200,
@@ -77,7 +77,6 @@ def __new__(
         seed: int = 42,
         # for adam
         enable_adam: bool = False,
-        # for MLLM
         extra_config: ExtraConfig = None,
         **kwargs,
     ) -> BaseCompressor:
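With this change, scheme keeps accepting the existing string/dict/QuantizationScheme forms and now also accepts an AutoScheme instance. A sketch under that assumption; the AutoScheme constructor is not shown in this diff, so its argument below is a placeholder:

    from auto_round import AutoRound, AutoScheme

    # Strings keep working as before ("W4A16" is the default shown above).
    ar = AutoRound(model="facebook/opt-125m", scheme="W4A16")  # illustrative model name

    # Hypothetical: AutoScheme's fields are not visible in this diff.
    # ar = AutoRound(model="facebook/opt-125m", scheme=AutoScheme(avg_bits=3.0))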