diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 1671d157d..55cc052da 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -40,6 +40,8 @@ def __init__(self, *args, **kwargs): self.add_argument("--eval", action="store_true", help="whether to use eval only mode") + self.add_argument("--sq", action="store_true", help="whether to use smoothquant") + self.add_argument( "--scheme", default="W4A16", @@ -470,6 +472,7 @@ def tune(args): autoround: BaseCompressor = AutoRound( model=model_name, scheme=scheme, + sq=args.sq, dataset=args.dataset, iters=args.iters, seqlen=args.seqlen, diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 4074213a9..86875cdd0 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -64,6 +64,7 @@ def __new__( model: Union[torch.nn.Module, str], tokenizer=None, scheme: Union[str, dict, QuantizationScheme] = "W4A16", + sq: bool = False, layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", iters: int = 200, @@ -159,6 +160,7 @@ def __new__( model=model, tokenizer=tokenizer, scheme=scheme, + sq=sq, layer_config=layer_config, dataset=dataset, iters=iters, diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index abfc47366..04456cc5e 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -130,6 +130,7 @@ def __init__( model: Union[torch.nn.Module, str], tokenizer=None, scheme: Union[str, dict, QuantizationScheme] = "W4A16", + sq: bool = False, layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", iters: int = 200, @@ -385,6 +386,33 @@ def __init__( import habana_frameworks.torch.core as htcore # pylint: disable=E0401 import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] + # sq, for test + if sq: + from auto_round.calib_dataset import get_dataloader + + dataloader = get_dataloader(tokenizer, seqlen, bs=batch_size, nsamples=nsamples) + auto_alpha_args = { + "init_alpha": 0.5, + "alpha_min": 0.1, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "n_samples": 512, ##512 for cuda, 128 for cpu? + # "do_blockwise": True + } + from auto_round.smooth_quant import SmoothQuant + + model = model.to(self.device) + sq = SmoothQuant(model, dataloader, device=model.device, group_size=-1) + model = sq.transform_model( + alpha=0.5, + # alpha="auto", + auto_alpha_args=auto_alpha_args, + folding=True, + op_types=[torch.nn.Linear, torch.nn.Conv2d], + calib_iter=100, + ) + def _set_device(self, device_map): if hasattr(self, "device") and self.device is not None: return diff --git a/auto_round/smooth_quant/__init__.py b/auto_round/smooth_quant/__init__.py new file mode 100644 index 000000000..7af3654f7 --- /dev/null +++ b/auto_round/smooth_quant/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_round.smooth_quant.sq import SmoothQuant diff --git a/auto_round/smooth_quant/absorb_utils.py b/auto_round/smooth_quant/absorb_utils.py new file mode 100644 index 000000000..baf2f283b --- /dev/null +++ b/auto_round/smooth_quant/absorb_utils.py @@ -0,0 +1,449 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from auto_round.smooth_quant.utils import get_module + +SUPPORTED_TORCH_MODULE = [ + "Linear", + "Conv2d", + "ConvTranspose2d", + "LayerNorm", + "BatchNorm2d", + "GroupNorm", + "InstanceNorm2d", + "LlamaRMSNorm", + "T5LayerNorm", + "LPLayerNorm", + "RMSNorm", + "Qwen2RMSNorm", + "WrapperWALayer", +] + +GET_ABSORB_LAYERS = {} + + +def register_absorb_func(model_type): + def register(func): + if isinstance(model_type, list): + model_types = model_type + else: + model_types = [model_type] + for name in model_types: + GET_ABSORB_LAYERS[name] = func + return func + + return register + + +def _check_valid_conv(module): + """Remove group conv except depthwise conv + :param module: + + :return: + """ + if not isinstance(module, torch.nn.Conv2d): + return True + if module.groups > 1: + if module.in_channels == module.out_channels and module.groups == module.in_channels: + return True + else: + return False + return True + + +def remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers): + res = {} + for key in absorb_to_layer.keys(): + absorb_layer = get_module(model, key) + layer_type = absorb_layer.__class__.__name__ + if layer_type not in SUPPORTED_TORCH_MODULE: + no_absorb_layers.extend(absorb_to_layer[key]) + continue + supported = True + for layer_name in absorb_to_layer[key]: + layer = get_module(model, layer_name) + layer_type = layer.__class__.__name__ + if (layer_type not in SUPPORTED_TORCH_MODULE) or not _check_valid_conv(layer): + supported = False + no_absorb_layers.extend(absorb_to_layer[key]) + break + if supported: + res[key] = absorb_to_layer[key] + return res + + +@register_absorb_func("opt") +def get_opt_absorb_layers(model): + model_layer_name = "model.decoder.layers" + absorb_to_layer = {} + for idx in range(len(model.model.decoder.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.self_attn_layer_norm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + # no_absorb_layers.append(f"{model_layer_name}.{idx}.self_attn.out_proj") + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.out_proj", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.final_layer_norm"] = [ + f"{model_layer_name}.{idx}.fc1", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.fc1"] = [ + f"{model_layer_name}.{idx}.fc2", + ] + + # final layer + # absorb_to_layer["model.decoder.final_layer_norm"] = ['lm_head'] + + return absorb_to_layer + + +# @register_absorb_func('llama') +# def get_llama_absorb_layers(model): +# model_layer_name = "model.layers" +# absorb_to_layer = {} + +# for idx in range(len(model.model.layers)): +# # attention input +# absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ +# f"{model_layer_name}.{idx}.self_attn.q_proj", +# f"{model_layer_name}.{idx}.self_attn.k_proj", +# f"{model_layer_name}.{idx}.self_attn.v_proj", +# ] + +# # attention out +# module = model.model.layers[idx] +# if hasattr(module.self_attn.v_proj, "orig_layer"): +# v_proj_shape = module.self_attn.v_proj.orig_layer.weight.shape +# o_proj_shape = module.self_attn.o_proj.orig_layer.weight.shape +# else: +# v_proj_shape = module.self_attn.v_proj.weight.shape +# o_proj_shape = module.self_attn.o_proj.weight.shape +# if v_proj_shape == o_proj_shape: +# absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ +# f"{model_layer_name}.{idx}.self_attn.o_proj", +# ] + +# # linear 1 +# absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ +# f"{model_layer_name}.{idx}.mlp.gate_proj", +# f"{model_layer_name}.{idx}.mlp.up_proj", +# ] + +# # linear 2 +# absorb_to_layer[f"{model_layer_name}.{idx}.mlp.up_proj"] = [ +# f"{model_layer_name}.{idx}.mlp.down_proj", +# ] + +# # final layer +# # absorb_to_layer["model.norm"] = ['lm_head'] + +# return absorb_to_layer + + +@register_absorb_func("mistral") +def get_mistral_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + module = model.model.layers[idx] + if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.gate_proj", + f"{model_layer_name}.{idx}.mlp.up_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.up_proj"] = [ + f"{model_layer_name}.{idx}.mlp.down_proj", + ] + + # final layer + # absorb_to_layer["model.norm"] = ['lm_head'] + + return absorb_to_layer + + +@register_absorb_func("mixtral") +def get_mixtral_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + module = model.model.layers[idx] + if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + # linear in + module = get_module(model, f"{model_layer_name}.{idx}.block_sparse_moe.experts") + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [] + for i in range(len(module)): + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"].extend( + [ + f"{model_layer_name}.{idx}.block_sparse_moe.experts.{i}.w1", + f"{model_layer_name}.{idx}.block_sparse_moe.experts.{i}.w3", + ] + ) + + # linear out + for i in range(len(module)): + absorb_to_layer[f"{model_layer_name}.{idx}.block_sparse_moe.experts.{i}.w3"] = [ + f"{model_layer_name}.{idx}.block_sparse_moe.experts.{i}.w2" + ] + + # final layer + # absorb_to_layer["model.norm"] = ['lm_head'] + return absorb_to_layer + + +@register_absorb_func("bloom") +def get_bloom_absorb_layers(model): + model_layer_name = "transformer.h" + absorb_to_layer = {} + for idx in range(len(model.transformer.h)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attention.query_key_value", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.dense_h_to_4h", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.gelu_impl"] = [ + f"{model_layer_name}.{idx}.mlp.dense_4h_to_h", + ] + + # final layer + # absorb_to_layer["transformer.ln_f"] = ['lm_head'] + + return absorb_to_layer + + +@register_absorb_func("gptj") +def get_gptj_absorb_layers(model): + model_layer_name = "transformer.h" + absorb_to_layer = {} + for idx in range(len(model.transformer.h)): + # attention input + linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.ln_1"] = [ + f"{model_layer_name}.{idx}.attn.q_proj", + f"{model_layer_name}.{idx}.attn.k_proj", + f"{model_layer_name}.{idx}.attn.v_proj", + f"{model_layer_name}.{idx}.mlp.fc_in", + ] + + # attention out + absorb_to_layer[f"{model_layer_name}.{idx}.attn.v_proj"] = [ + f"{model_layer_name}.{idx}.attn.out_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.act"] = [ + f"{model_layer_name}.{idx}.mlp.fc_out", + ] + + # final layer + # absorb_to_layer["transformer.ln_f"] = ['lm_head'] + + return absorb_to_layer + + +@register_absorb_func("phi3") +def get_phi3_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.qkv_proj", + ] + + # attention out + absorb_to_layer[f"{model_layer_name}.{idx}.self_attn.qkv_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.gate_up_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.gate_up_proj"] = [ + f"{model_layer_name}.{idx}.mlp.down_proj", + ] + + # final layer + # absorb_to_layer["model.norm"] = ['lm_head'] + + return absorb_to_layer + + +@register_absorb_func("qwen") +def get_qwen_absorb_layers(model): + model_layer_name = "transformer.h" + absorb_to_layer = {} + for idx in range(len(model.transformer.h)): + # attention + absorb_to_layer[f"{model_layer_name}.{idx}.ln_1"] = [f"{model_layer_name}.{idx}.attn.c_attn"] + + # mlp + absorb_to_layer[f"{model_layer_name}.{idx}.ln_2"] = [ + f"{model_layer_name}.{idx}.mlp.w2", + f"{model_layer_name}.{idx}.mlp.w1", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.w1"] = [ + f"{model_layer_name}.{idx}.mlp.c_proj", + ] + + # final layer + # absorb_to_layer["transformer.ln_f"] = ['lm_head'] + + return absorb_to_layer + + +@register_absorb_func(["qwen2", "qwen3"]) +@register_absorb_func("llama") +def get_defualt_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + module = model.model.layers[idx] + if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.gate_proj", + f"{model_layer_name}.{idx}.mlp.up_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.up_proj"] = [ + f"{model_layer_name}.{idx}.mlp.down_proj", + ] + + # final layer + # absorb_to_layer["model.norm"] = ['lm_head'] + + return absorb_to_layer + + +@register_absorb_func("qwen3_moe") +def get_qwen3_moe_absorb_layers(model): + model_layer_name = "model.layers" + absorb_to_layer = {} + for idx in range(len(model.model.layers)): + # attention input + absorb_to_layer[f"{model_layer_name}.{idx}.input_layernorm"] = [ + f"{model_layer_name}.{idx}.self_attn.q_proj", + f"{model_layer_name}.{idx}.self_attn.k_proj", + f"{model_layer_name}.{idx}.self_attn.v_proj", + ] + + # attention out + module = model.model.layers[idx] + if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: + absorb_to_layer[f"{model_layer_name}.{idx}.v_proj"] = [ + f"{model_layer_name}.{idx}.self_attn.o_proj", + ] + + if hasattr(module.mlp, "gate"): + # linear in + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.experts.{i}.gate_proj" for i in range(len(module.mlp.experts)) + ] + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"].extend( + [f"{model_layer_name}.{idx}.mlp.experts.{i}.up_proj" for i in range(len(module.mlp.experts))] + ) + breakpoint() + + # linear out + for i in range(len(module.mlp.experts)): + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.experts.{i}.up_proj"] = [ + f"{model_layer_name}.{idx}.mlp.experts.{i}.down_proj", + ] + else: + # linear 1 + absorb_to_layer[f"{model_layer_name}.{idx}.post_attention_layernorm"] = [ + f"{model_layer_name}.{idx}.mlp.gate_proj", + f"{model_layer_name}.{idx}.mlp.up_proj", + ] + + # linear 2 + absorb_to_layer[f"{model_layer_name}.{idx}.mlp.up_proj"] = [f"{model_layer_name}.{idx}.mlp.down_proj"] + + # final layer + # absorb_to_layer["model.norm"] = ['lm_head'] + return absorb_to_layer + + +def get_absorb_layers(model, skip_unsupported_layers=False): + model_type = model.config.model_type + assert model_type in GET_ABSORB_LAYERS, f"Unsupported model type: {model_type}" + absorb_to_layer = GET_ABSORB_LAYERS[model_type](model) + no_absorb_layers = [] + # if skip_unsupported_layers: + # absorb_to_layer = remove_unsupported_layers(model, absorb_to_layer, no_absorb_layers) + return absorb_to_layer, no_absorb_layers diff --git a/auto_round/smooth_quant/auto_alpha.py b/auto_round/smooth_quant/auto_alpha.py new file mode 100644 index 000000000..6859cce44 --- /dev/null +++ b/auto_round/smooth_quant/auto_alpha.py @@ -0,0 +1,736 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import copy +import json + +import numpy +import torch +from tqdm import tqdm + +from auto_round.smooth_quant.calibration import Calibration +from auto_round.smooth_quant.utils import ( + WrapperLayer, + cal_scale, + enough_memo_store_scale, + forward_wrapper, + get_module, + mul_scale, + quant_dequant, + reshape_in_channel_to_last, + reshape_scale_as_input, + reshape_scale_as_weight, + set_module, +) + +from .utils import logger + +TUNERS = {} + + +def register_autotune(name): + """Class decorator to register a smoothquant auto-tune subclass. + + :return: the class of register + """ + + def register(auto_tune): + TUNERS[name] = auto_tune + return auto_tune + + return register + + +@register_autotune("version1") +class AutoAlpha: + def __init__( + self, + model, + dataloader, + absorb_to_layer, + op_types, + device, + q_func, + example_inputs, + weight_clip=True, + alpha_min=0.3, + alpha_max=0.7, + alpha_step=0.1, + shared_criterion="mean", + init_alpha=0.5, + folding=False, + do_blockwise=False, + n_samples=32, + calib_iter=100, + group_size=-1, + ): + """Initialize the AutoAlpha tuner with necessary parameters and components.""" + + self.model = model.to("cpu") + self.model.eval() + self.dataloader = dataloader + self.alpha_min = alpha_min + self.alpha_max = alpha_max + self.alpha_step = alpha_step + self.shared_criterion = shared_criterion + self.init_alpha = init_alpha + self.loss_type = "blockwise" if do_blockwise else "model_wise" + self.calib_sample_num = n_samples if n_samples else 32 + self.op_types = op_types + self.absorb_to_layer = absorb_to_layer + self.weight_scale_dict = {} + self.q_func = q_func + self.folding = folding + self.example_inputs = example_inputs + self.max_value_info = {} # to record max values for alpha tune + self.weight_clip = weight_clip[0] if isinstance(weight_clip, tuple) else weight_clip + self.input_maxes = {} + self.input_mins = {} + self.input_maxes_abs = {} + self.device = device + self.calib_iter = calib_iter + self.group_size = group_size + + def tune(self): + """The main entry of auto_alpha + :return: Optimal alpha values and scales based on user-defined recipes.""" + calib = Calibration(self.model, self.dataloader, self.q_func, self.device, self.group_size) + self.input_mins, self.input_maxes = calib.calibrate(self.calib_iter, self.op_types) + for key in self.input_mins.keys(): + self.input_maxes_abs[key] = torch.max(torch.abs(self.input_mins[key]), torch.abs(self.input_maxes[key])) + + if not self.folding: + diff_modules = set(self.absorb_to_layer.keys()).difference(self.input_mins.keys()) + for d in diff_modules: + del self.absorb_to_layer[d] + + scale_memo_use = 0 + for key in self.absorb_to_layer: + layer_name = self.absorb_to_layer[key][0] + input_max = self.input_maxes_abs[layer_name] + scale_memo_use += 4 * input_max.shape[0] * len(self.absorb_to_layer[key]) + alpha_space_len = (self.alpha_max - self.alpha_min) / self.alpha_step + 1 + scale_memo_use *= alpha_space_len + self._save_scale = enough_memo_store_scale(self.device, scale_memo_use) + + if self.loss_type == "blockwise": + self.block_names = self.get_blocks() + logger.info("Blockwise auto-tuning will be performed") + module_names = self._get_sq_layer_names() + block_names, self.block_to_module = self.block_names, {} + for block in block_names: + self.block_to_module[block] = [] + for module in module_names: + checked = False + for block in block_names: + if block + "." in module: + self.block_to_module[block].append(module) + checked = True + if not checked: + self.block_to_module[module] = [module] + self.block_names = list(self.block_to_module.keys()) + logger.info(f"Blockwise auto-tuning: {len(self.block_names)} blocks found") + logger.debug(f"Blockwise auto-tuning blocks info: {self.block_to_module}") + return self._auto_tune_alpha_blockwise() + else: + return self._auto_tune_alpha() + + def get_blocks(self): + """Obtain a list of blocks in block-wise tuning mode.""" + block_names = [] + for n, m in self.model.named_modules(): + if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__: + for nn, mm in m.named_children(): + block_name = n + "." + nn + block_names.append(block_name) + break + return block_names + + def _add_blockwise_observer(self, block_modules): + """ + :param block_modules: the block modules which the observer will insert to + :return: + """ + self.blockwise_hook_handles = [] + for key in block_modules.keys(): + hook_func = self._save_blockwise_hook(key) + hook_handle = block_modules[key].register_forward_hook(hook_func) + self.blockwise_hook_handles.append(hook_handle) + + def _save_blockwise_hook(self, name): + """A forward hook to save inputs/outputs of a block + :param name: the block name + :return: A hook function.""" + + def save_blockwise_hook(module, inputs, outputs): + self.block_inputs[name] = inputs[0] + self.block_outputs[name] = outputs[0] + + return save_blockwise_hook + + def _get_all_hook_module_names(self): + """Obtain all the modules that could be hooked based on given op_types.""" + module_names = [] + for n, module in self.model.named_modules(): + if isinstance(module, tuple(self.op_types)): + module_names.append(n) + return module_names + + # def _update_scales_for_auto(self, absorb_scales, weight_scales): + # """Apply activation and weight scales to the model.""" + # for key in self.absorb_to_layer.keys(): + # layer_names = self.absorb_to_layer[key] + # for layer_name in layer_names: + # layer = get_module(self.model, layer_name) + # input_scale = absorb_scales[key] + # weight_scale = weight_scales[layer_name] + # input_scale = reshape_scale_as_input(layer, input_scale) + # weight_scale = reshape_scale_as_weight(layer, weight_scale) + # #layer.update_scale(input_scale, weight_scale) ##FIXME + # layer.update_scale(None, weight_scale) ##FIXME + + def _update_scales_for_auto(self, absorb_scales, weight_scales): + """Apply activation and weight scales to the model.""" + for key in self.absorb_to_layer.keys(): + absorb_layer = get_module(self.model, key) + layer_names = self.absorb_to_layer[key] + if isinstance(absorb_layer, WrapperLayer): + absorb_scale = absorb_scales[key] + absorb_scale = absorb_scale.view(-1, 1) + absorb_layer.update_scale(None, None, absorb_scale) + + for layer_name in layer_names: + layer = get_module(self.model, layer_name) + weight_scale = weight_scales[layer_name] + weight_scale = reshape_scale_as_weight(layer, weight_scale) + layer.update_scale(None, weight_scale) ##FIXME + else: + for layer_name in layer_names: + layer = get_module(self.model, layer_name) + input_scale = absorb_scales[key] + weight_scale = weight_scales[layer_name] + input_scale = reshape_scale_as_input(layer, input_scale) + weight_scale = reshape_scale_as_weight(layer, weight_scale) + layer.update_scale(input_scale, weight_scale) ##FIXME + + def _change_qdq_for_auto(self, enable=True): + """Change the option for qdq.""" + module_names = self._get_all_hook_module_names() + for name in module_names: + name = name.split(".orig_layer")[0] + module = get_module(self.model, name) + if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration + continue + if enable: + module.enable_quant() + else: + module.disable_quant() + + def _qdq_model_wrapper_for_auto(self, save_q_input=False): + """Wrapper all the module with qdq + :return:""" + module_names = self._get_all_hook_module_names() + self.to_unwrap_module_names = module_names + for name in module_names: + if name not in self.input_mins: # skip module if it's not used in calibration + continue + module = get_module(self.model, name) + new_module = WrapperLayer( + module, + self.input_mins[name], + self.input_maxes[name], + save_q_input=save_q_input, + group_size=self.group_size, + ) + set_module(self.model, name, new_module) + + def _qdq_model_unwrapper_for_auto(self): + """Unwrapper all the module with qdq + :return:""" + module_names = self.to_unwrap_module_names + for name in module_names: + module = get_module(self.model, name) + if not hasattr(module, "orig_layer"): # skip module if it's not used in calibration + continue + set_module(self.model, name, module.orig_layer) + + def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5): + """Cal the adjust scales + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_to_input_maxes = {} + for key in absorb_to_layer.keys(): + layer_name = absorb_to_layer[key][0] + absorb_to_input_maxes[key] = input_maxes[layer_name] + + weight_scales_info = {} + absorb_scales_info = {} + for index, key in enumerate(absorb_to_layer.keys()): + alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha + if alpha_tmp < 0: + scale = torch.ones((1), device=self.device) + else: + input_max = absorb_to_input_maxes[key] + layer_names = absorb_to_layer[key] + weights = [] + for layer_name in layer_names: + weight = reshape_in_channel_to_last(layer_name, self.model) + weights.append(weight) + + weight_max_per_channel = torch.max(torch.abs(torch.cat(weights, dim=0)), dim=0)[0] + if self.weight_clip: + weight_max_per_channel = weight_max_per_channel.clamp(min=1e-5) + + if self._save_scale: + if key in self.weight_scale_dict and alpha_tmp in self.weight_scale_dict[key]: + scale = self.weight_scale_dict[key][alpha_tmp] + else: + scale = cal_scale(input_max, weights, alpha_tmp, group_size=self.group_size) + else: + scale = cal_scale(input_max, weights, alpha_tmp, group_size=self.group_size) + + absorb_scales_info[key] = 1.0 / scale + absorb_scales_info[key][scale == 0] = 0 + layer_names = absorb_to_layer[key] + if self._save_scale: + if key not in self.weight_scale_dict: + self.weight_scale_dict[key] = {} + self.weight_scale_dict[key][alpha_tmp] = scale + for layer_name in layer_names: + ##self._scale_layer_weight(layer_name, scale) + weight_scales_info[layer_name] = scale + return absorb_scales_info, weight_scales_info + + def _get_auto_loss(self, output, output_q, loss_type="abs", loss_alpha=1.0): + """Get the loss for auto tuning + :param output: Fp32 output for one layer + :param output_q: Quant output for one layer + :param loss_type: The type of loss + :param loss_alpha: Loss alpha i for mean scale error + :return: A tensor of the loss.""" + if len(output.shape) <= 2: + max_value = torch.max(torch.abs(output)) + else: + output = output.reshape(output.shape[0], -1) + output_q = output_q.reshape(output_q.shape[0], -1) + max_value = torch.max(torch.abs(output), dim=-1).values.unsqueeze(-1) + max_value = torch.clip(max_value, 1e-5) + + # return torch.sum(torch.nn.functional.cosine_similarity(output, output_q, dim=-1)) + output = output / max_value ##FIXME need copy not replace + output_q = output_q / max_value + if loss_type == "abs": + return torch.sum(torch.pow(torch.abs(output - output_q), 0.5)) + else: + return torch.sum((output - output_q) ** 2) + + def _get_sq_layer_names(self): + """Get all the layers that could be smooth quanted + :return: All the sq layer names.""" + ##TODO this may not fit for folding=False + module_names = [] + for key in self.absorb_to_layer: + module_names += self.absorb_to_layer[key] + return module_names + + def _get_best_alpha(self, absorb_to_layer, loss_alphas, shared_criterion): + """Obtain the optimal alpha values based on shared criterion and loss values recorded in auto-tuning step. + + :return: A dict of layerwise alpha values. + """ + + def dict_to_list(dic): + res = [] + for key in dic.keys(): + res.append((key, dic[key])) + return res + + best_alpha = {} + for ln_name in absorb_to_layer.keys(): + layer_names = absorb_to_layer[ln_name] + cur_shared_criterion = shared_criterion + if len(layer_names) == 1: + cur_shared_criterion = "min" + if cur_shared_criterion == "mean": + loss_tmp = {} + for alpha in loss_alphas[layer_names[0]].keys(): + if alpha not in loss_tmp.keys(): + loss_tmp[alpha] = 0 + for layer_name in layer_names: + loss_tmp[alpha] += loss_alphas[layer_name][alpha] + res = dict_to_list(loss_tmp) + res.sort(key=lambda x: x[1]) + + best_alpha[ln_name] = float(res[0][0]) + + elif cur_shared_criterion == "min" or cur_shared_criterion == "max": + tmp_best_alpha = [] + for layer_name in layer_names: + res = dict_to_list(loss_alphas[layer_name]) + res.sort(key=lambda x: x[1]) + tmp_best_alpha.append(float(res[0][0])) + if cur_shared_criterion == "min": + best_alpha[ln_name] = min(tmp_best_alpha) + else: + best_alpha[ln_name] = max(tmp_best_alpha) + + else: + raise NotImplementedError + return best_alpha + + def _get_one_batch_auto_loss(self, input, alpha_space, orig_best_alpha, input_maxes): + """Calculate the losses for all alpha values given an input. + + :return: A dict of op-wise loss values with respect to alpha values. + """ + self._change_qdq_for_auto(enable=False) + module_names = self._get_sq_layer_names() + forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output + + fp32_output = {} + for name in module_names: + module = get_module(self.model, name) + fp32_output[name] = module.output + module.output = None + self._change_qdq_for_auto(enable=True) + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, orig_best_alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + forward_wrapper(self.model, input, self.device) ##save quant_input + for mod_name in module_names: # save fp32 values + mod = get_module(self.model, mod_name) + if mod_name in self.fp32_output_val: + self.fp32_output_val[mod_name].append(torch.norm(mod.output)) + else: + self.fp32_output_val[mod_name] = [torch.norm(mod.output)] + del mod + + loss_alphas = {} + for name in module_names: + module = get_module(self.model, name) + loss = self._get_auto_loss(fp32_output[name], module.output) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[name] + key_name = str(cur_alpha) + loss_alphas[name] = {key_name: loss} + # for name in module_names: + # loss_alphas[name]={} + for alpha in alpha_space: + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + for name in module_names: + losses = loss_alphas[name] + if str(alpha) in losses.keys(): + continue + module = get_module(self.model, name) + output = module.q_dq_forward( + module.q_input, module.input_scale, module.weight_scale, module.absorb_scale + ) + loss = self._get_auto_loss(fp32_output[name], output) + loss_alphas[name][str(alpha)] = loss + return loss_alphas + + def _get_one_batch_auto_loss_blockwise(self, input, alpha_space, orig_best_alpha, input_maxes): + """Calculate the losses for all alpha values given an input in blockwise tuning mode. + + :return: A dict of blockwise-wise loss values with respect to alpha values. + """ + self._change_qdq_for_auto(enable=False) + module_names = self._get_sq_layer_names() + + block_modules = {} + for key in self.block_names: + block_modules[key] = get_module(self.model, key) + self._add_blockwise_observer(block_modules) + + forward_wrapper(self.model, input, self.device) ##disable quant and get fp32 output + + fp32_output = {} + for block_name in self.block_names: + fp32_output[block_name] = self.block_outputs[block_name] + self._change_qdq_for_auto(enable=True) + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, orig_best_alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + forward_wrapper(self.model, input, self.device) ##save quant_input + for mod_name in module_names: # save fp32 values + mod = get_module(self.model, mod_name) + if mod_name in self.fp32_output_val: + self.fp32_output_val[mod_name].append(torch.norm(mod.output)) + else: + self.fp32_output_val[mod_name] = [torch.norm(mod.output)] + del mod + + loss_alphas = {} + + for block_name in self.block_names: + block = get_module(self.model, block_name) + loss = self._get_auto_loss(fp32_output[block_name], self.block_outputs[block_name]) + cur_alpha = orig_best_alpha + if isinstance(orig_best_alpha, dict): + cur_alpha = orig_best_alpha[self.block_to_module[block_name][0]] + key_name = str(cur_alpha) + loss_alphas[block_name] = {key_name: loss} + # for name in module_names: + # loss_alphas[name]={} + for alpha in alpha_space: + absorb_input_scales, weight_scales = self._cal_scales(self.absorb_to_layer, input_maxes, alpha) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + + for block_name in self.block_names: + losses = loss_alphas[block_name] + if str(alpha) in losses.keys(): + continue + block = get_module(self.model, block_name) + block_copy = copy.deepcopy(block) + for name in self.block_to_module[block_name]: + if name == block_name and len(self.block_to_module[block_name]) == 1: + module, module_copy = block, block_copy + else: + module = get_module(block, name) + module_copy = copy.deepcopy(module) + if module.weight_scale is not None: + # module_copy.orig_layer.weight *= module.weight_scale + module_copy.orig_layer.weight.data = mul_scale( + module_copy.orig_layer.weight, module.weight_scale, self.group_size + ) + # q_dq_weight = quant_dequant_w_v1(module_copy.orig_layer) + q_dq_weight = quant_dequant(module_copy.orig_layer) + module_copy.orig_layer.weight.data.copy_(q_dq_weight) + module_copy.do_blockwise = True + if not (name == block_name and len(self.block_to_module[block_name]) == 1): + set_module(block_copy, name, module_copy) + try: + output = block_copy(self.block_inputs[block_name])[0] + except: # Llama model decoder_layer forward requires position_id + position_ids = torch.arange(self.block_inputs[block_name].size()[1]) + position_ids = position_ids.view(self.block_inputs[block_name].size()[0], -1) + position_ids = position_ids.to(self.device) + if hasattr(self.model, "rotary_emb"): + position_embeddings = self.model.rotary_emb(self.block_inputs[block_name], position_ids) + else: + position_embeddings = None + output = block_copy( + self.block_inputs[block_name], + position_ids=position_ids, + position_embeddings=position_embeddings, + )[0] + loss = self._get_auto_loss(fp32_output[block_name], output) + loss_alphas[block_name][str(alpha)] = loss + del block_copy # release memory + return loss_alphas + + def opwise_rank(self, loss_alphas, best_alphas): + """Rank the final losses of ops based on their ratio with respect to op output norm. + + :return: + """ + max_op, max_ratio, max_key = "", 0, "" + ratio_info = {} + for key in self.absorb_to_layer: + for op_name in self.absorb_to_layer[key]: + fp32_norm, loss_ = ( + torch.sum(torch.stack(self.fp32_output_val[op_name])), + loss_alphas[op_name][str(best_alphas[key])], + ) + ratio = loss_ / fp32_norm + max_op = op_name if ratio > max_ratio else max_op + max_key = key if ratio > max_ratio else max_key + max_ratio = max(ratio, max_ratio) + ratio_info[op_name] = ratio + logger.debug( + f"final loss: {op_name}: {loss_}; @alpha {best_alphas[key]}; \ + fp32_output norm: {fp32_norm}; ratio: {ratio}" + ) + import operator + + ratio_info = dict(sorted(ratio_info.items(), key=operator.itemgetter(1), reverse=True)) + for key in list(ratio_info.keys()): + logger.debug(f"sorted opname-ratio: {key}: {ratio_info[key]}") + if max_op != "": + logger.debug( + f"max loss: {max_op}: {loss_alphas[max_op][str(best_alphas[max_key])]} @alpha {best_alphas[max_key]}\ + fp32_output norm: {torch.sum(torch.stack(self.fp32_output_val[max_op]))}; ratio: {max_ratio}" + ) + return None + + def default_tune_setup(self): + """Setup default auto-tune settings. + + :return: A dict of op-wise loss values with respect to alpha values. + """ + round_num = max( # Initialize the alpha search space + len(str(self.alpha_min).split(".")[1]), + len(str(self.alpha_max).split(".")[1]), + len(str(self.alpha_step).split(".")[1]), + ) + self.alpha_space = numpy.round( + numpy.arange(self.alpha_min, self.alpha_max + self.alpha_step, self.alpha_step), round_num + ).tolist() + ##wrapper new module + self._qdq_model_wrapper_for_auto(save_q_input=True) + + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, self.input_maxes_abs, self.init_alpha + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + return absorb_input_scales, weight_scales + + def _auto_tune_alpha(self): + """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.""" + logger.info("Start alpha tuning") + + absorb_input_scales, weight_scales = self.default_tune_setup() + + total_cnt, tmp_cnt = 0, 0 + alpha_update_iter, tune_cnt = 0, 4 + # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha + multiply_factor = ( + self.calib_sample_num // tune_cnt if self.calib_sample_num >= tune_cnt else self.calib_sample_num + ) + self.fp32_output_val = {} + best_alphas = self.init_alpha + + if not self.dataloader: + logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.") + self._qdq_model_unwrapper_for_auto() + return best_alphas + # bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha") + pbar = tqdm(range(self.calib_sample_num // self.dataloader.batch_size), desc="auto tune alpha") + for input in self.dataloader: + pbar.update(1) + if isinstance(input, tuple) or isinstance(input, list): + if len(input) == 2: + input, _ = input # Extract input when both input and label are yielded by dataloader. + loss_alphas = {} + best_alphas_per_module = best_alphas + if isinstance(best_alphas, dict): + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + best_alphas_per_module[layer_name] = best_alphas_per_module[key] + loss_tmp = self._get_one_batch_auto_loss( + input, self.alpha_space, best_alphas_per_module, self.input_maxes_abs + ) + if loss_alphas == {}: + loss_alphas = loss_tmp + else: + for key in loss_alphas.keys(): + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[key][alpha_key] + total_cnt += self.dataloader.batch_size + tmp_cnt += self.dataloader.batch_size + if tmp_cnt // multiply_factor >= 1: + alpha_update_iter += 1 + tmp_cnt = 0 + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, self.input_maxes_abs, best_alphas + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + # does not need to reset the weight_scale_dict, because use the weight of ori_layer, no change + # self.weight_scale_dict = {} + + if total_cnt >= self.calib_sample_num: + break + + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Final alpha {key}:{best_alphas[key]}") + + self.opwise_rank(loss_alphas, best_alphas) + self._qdq_model_unwrapper_for_auto() + logger.info("auto tuning done") + + return best_alphas + + def _auto_tune_alpha_blockwise(self): + """Perform blockwise-alpha-tuning to obtain optimal alpha values and adjust parameters accordingly.""" + logger.info("Start block-wise alpha tuning") + self.block_inputs, self.block_outputs = {}, {} + + absorb_input_scales, weight_scales = self.default_tune_setup() + + total_cnt, tmp_cnt = 0, 0 + alpha_update_iter, tune_cnt = 0, 4 + # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha + multiply_factor = ( + self.calib_sample_num // tune_cnt if self.calib_sample_num >= tune_cnt else self.calib_sample_num + ) + self.fp32_output_val = {} + best_alphas = self.init_alpha + + if not self.dataloader: + logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.") + self._qdq_model_unwrapper_for_auto() + return best_alphas + bar = tqdm(self.dataloader, total=self.calib_sample_num, desc="auto tune alpha") + for input in bar: + if isinstance(input, tuple): # Extract input when both input and label are yielded by dataloader. + input = input[0] + loss_alphas = {} + best_alphas_per_module = best_alphas + if isinstance(best_alphas, dict): + for key in self.absorb_to_layer.keys(): + layer_names = self.absorb_to_layer[key] + for layer_name in layer_names: + best_alphas_per_module[layer_name] = best_alphas_per_module[key] + loss_tmp = self._get_one_batch_auto_loss_blockwise( + input, self.alpha_space, best_alphas_per_module, self.input_maxes_abs + ) + if loss_alphas == {}: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + loss_alphas[key] = loss_tmp[block_name] + else: + for block_name in self.block_names: + for key in self.block_to_module[block_name]: + cur_loss = loss_alphas[key] + for alpha_key in cur_loss.keys(): + cur_loss[alpha_key] += loss_tmp[block_name][alpha_key] + + total_cnt += self.dataloader.batch_size + tmp_cnt += self.dataloader.batch_size + if tmp_cnt // multiply_factor >= 1: + alpha_update_iter += 1 + tmp_cnt = 0 + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}") + absorb_input_scales, weight_scales = self._cal_scales( + self.absorb_to_layer, self.input_maxes_abs, best_alphas + ) + self._update_scales_for_auto(absorb_input_scales, weight_scales) + # does not need to reset the weight_scale_dict, because use the weight of ori_layer, no change + # self.weight_scale_dict = {} + if total_cnt >= self.calib_sample_num: + break + + best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, self.shared_criterion) + for key in best_alphas.keys(): + logger.info(f"Final alpha {key}:{best_alphas[key]}") + + self.opwise_rank(loss_alphas, best_alphas) + self._qdq_model_unwrapper_for_auto() + logger.info("block-wise auto tuning done") + + return best_alphas diff --git a/auto_round/smooth_quant/calibration.py b/auto_round/smooth_quant/calibration.py new file mode 100644 index 000000000..38f12764e --- /dev/null +++ b/auto_round/smooth_quant/calibration.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json + +import torch + +from auto_round.data_type.utils import reshape_pad_tensor_by_group_size + +from .utils import * + + +class Calibration: + def __init__(self, model, dataloder=None, q_func=None, device="cpu", group_size=-1): + self.model = model + self.dataloader = dataloder + self.q_func = q_func + self.device = device + self.group_size = group_size + + @torch.no_grad() + def _save_input_pc_hook(self, name): + """A forward hook to save input max of a module + :param name: the module name + :return: A hook function.""" + + def save_input_hook(module, inputs, outputs): + if hasattr(module, "orig_layer"): + weight = module.orig_layer.weight + else: + weight = module.weight + input = inputs[0] + ##TODO check input channel is correct + if len(weight.shape) == 4: ##conv3d or conv1d not supported now, need better way + input = input.permute(0, 2, 3, 1) + input, orig_shape, pad_len = reshape_pad_tensor_by_group_size(input, self.group_size) + max_tensor = torch.max(input, dim=0)[0] + min_tensor = torch.min(input, dim=0)[0] + if name not in self.input_maxes.keys(): + self.input_mins[name], self.input_maxes[name] = min_tensor, max_tensor + else: + self.input_mins[name] = torch.min(self.input_mins[name], min_tensor) + self.input_maxes[name] = torch.max(self.input_maxes[name], max_tensor) + + return save_input_hook + + @torch.no_grad() + def _add_min_max_observer(self, modules): + """ + :param modules: the modules which the observer will insert to + :return: + """ + self.hook_handles = [] + for key in modules.keys(): + hook_func = self._save_input_pc_hook(key) + hook_handle = modules[key].register_forward_hook(hook_func) + self.hook_handles.append(hook_handle) + + @torch.no_grad() + def _remove_observer(self): + """Remove the observer from the model + :return:""" + for hook_handle in self.hook_handles: + hook_handle.remove() + + @torch.no_grad() + def _dump_min_max(self, calib_iter=100): + """Dump min max per channel information, the min max value will be saved in input_maxes attribute + :param calibration_method: only support min_max currently + :param calib_iter: Sample size for calibration + :return:""" + if self.q_func: + self.q_func(self.model) + else: + assert self.dataloader, "Please set dataloader for calibration." + model_forward(self.model, self.dataloader, calib_iter, self.device) + + @torch.no_grad() + def calibrate(self, calib_iter, op_types=[torch.nn.Conv2d, torch.nn.Linear]): ##TODO transformers.conv1d + """ + :param absorb_to_layer: A dict,key is the absorb layer, val is a list of the to be smoothed layer + :param calib_iter: Data size for calibration + :return: A dict that saved the layer name and the channel-wise max value info + """ + ##hook all the module + self.input_mins = {} + self.input_maxes = {} + + hook_modules = {} + for n, module in self.model.named_modules(): + if isinstance(module, tuple(op_types)): + hook_modules[n] = module + + self._add_min_max_observer(hook_modules) + + self._dump_min_max(calib_iter=calib_iter) + self._remove_observer() + return self.input_mins, self.input_maxes diff --git a/auto_round/smooth_quant/sq.py b/auto_round/smooth_quant/sq.py new file mode 100644 index 000000000..74a65a31d --- /dev/null +++ b/auto_round/smooth_quant/sq.py @@ -0,0 +1,524 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy +import torch + +from auto_round.smooth_quant.absorb_utils import get_absorb_layers +from auto_round.smooth_quant.calibration import Calibration +from auto_round.smooth_quant.utils import ( + cal_scale, + get_module, + model_forward_per_sample, + mul_scale, + reshape_in_channel_to_last, + reshape_scale_as_weight, + set_module, +) +from auto_round.utils import logger + + +class SmoothQuant: + def __init__( + self, + model, + dataloader=None, + device="cpu", + dtype=torch.bfloat16, + example_inputs=None, + q_func=None, + traced_model=None, + group_size=-1, + ): + """ + :param model: Torch model :param dataloader: Calibration dataloader :param traced_model: A specific model + shares the same architecture as the model and could be traced by torch.jit. If not supplied, we use model + instead. + """ + self.model = model + assert isinstance(self.model, torch.nn.Module) + self.model.eval() + self.device = device + self.dtype = dtype + self.dataloader = dataloader + self.example_inputs = example_inputs + self.q_func = q_func + self.input_maxes = {} + self.input_mins = {} + self.input_maxes_abs = {} + self.traced_model = traced_model + if self.traced_model is None: + self.traced_model = self.model + self.weight_scale_info = {} + self.absorb_scales_info = {} + self.insert_mul = False + self.allow_absorb = True + self.record_max_info = False + self.max_value_info = {} # to record max values for alpha tune + self.absorb_to_layer = {} + self.weight_max_lb = 1e-5 ##weight max low bound + self.sq_scale_info = {} + self.max_value_info = {} + self.need_calibration = False + self.group_size = group_size + + @torch.no_grad() + def transform_model( + self, + alpha=0.5, + folding=True, + percentile=100, + op_types=[torch.nn.Linear, torch.nn.Conv2d], + scales_per_op=False, + calib_iter=100, + weight_clip=True, + auto_alpha_args={ + "init_alpha": 0.5, + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "n_samples": 32, ##512 for cuda, 128 for cpu? + }, + ): + """The main entry of smooth quant + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, please refer + to the paper for more details + :param folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant + :param percentile: Not supported now + :param op_types: The op typed to be smooth quantized + :param scales_per_op: Not supported now + :param calib_iter: Data size for calibration + :param weight_clip: Whether to clip weight_max when calculating scales. + + :param auto_alpha_args: Hyperparameters used to set the alpha search space in SQ auto-tuning. + By default, the search space is 0.0-1.0 with step_size 0.1. + do_blockwise: Whether to do blockwise auto-tuning. + :param init_alpha: A hyperparameter that is used in SQ auto-tuning; by default it is 0.5. + :return: A FP32 model with the same architecture as the orig model but with different weight which will be + benefit to quantization. + """ + if not isinstance(self.model, torch.nn.Module): + logger.warning("smoothquant is ignored since the model is not a torch module") + return self.model + + if isinstance(alpha, float) and (alpha < 0): + logger.warning("reset alpha to >=0") + alpha = numpy.clip(alpha, 0.0) + + if folding: + self.insert_mul, self.allow_absorb = False, True + else: + self.insert_mul, self.allow_absorb = True, False + self.weight_clip = weight_clip + + self.revert() + self.need_calibration = self._check_need_calibration(alpha, percentile, op_types, scales_per_op, calib_iter) + if self.need_calibration: + self.input_mins, self.input_maxes = {}, {} + self.absorb_to_layer = self._parse_absorb_to_layers( + op_types, folding + ) ##need to forward to check modules not used in forward + if len(self.input_mins) != 0: ##this is from _parse_absorb_to_layers, ugly code to support q_func + input_maxes_abs = {} + for key in self.input_mins.keys(): + input_maxes_abs[key] = torch.max(torch.abs(self.input_mins[key]), torch.abs(self.input_maxes[key])) + if self.q_func: + self.need_calibration = False # Avoid double-calibration in fixed-value alpha SQ. + + if self.absorb_to_layer is None: + logger.warning("empty absorb_to_layer, smoothquant is ignored ") + return self.model + example_inputs = self._get_example_input() + if alpha == "auto": ##TODO need to polish later + from auto_round.smooth_quant.auto_alpha import TUNERS + + auto_alpha_version = "version1" + auto_alpha_tuner = TUNERS[auto_alpha_version]( + self.model, + self.dataloader, + self.absorb_to_layer, + op_types=op_types, + device=self.device, + q_func=self.q_func, + folding=folding, + example_inputs=self.example_inputs, + group_size=self.group_size, + calib_iter=self.calib_iter, + **auto_alpha_args, + ) + self.alpha = auto_alpha_tuner.tune() + input_maxes_abs = auto_alpha_tuner.input_maxes_abs + self.input_mins, self.input_maxes = auto_alpha_tuner.input_mins, auto_alpha_tuner.input_maxes + if auto_alpha_tuner.loss_type == "blockwise": + self.block_names = auto_alpha_tuner.block_names + + elif self.need_calibration: + calib = Calibration(self.model, self.dataloader, self.q_func, self.device, self.group_size) + self.input_mins, self.input_maxes = calib.calibrate(calib_iter, op_types) + input_maxes_abs = {} + for key in self.input_mins.keys(): + input_maxes_abs[key] = torch.max(torch.abs(self.input_mins[key]), torch.abs(self.input_maxes[key])) + + if example_inputs is not None: + out_pre_sq = model_forward_per_sample(self.model, example_inputs, self.device) + + if folding: + self._save_scale = False ##TODO remove it later + + if self.record_max_info: + self._export_sq_info(self.absorb_to_layer, input_maxes_abs, self.alpha) + # # max_info is recorded in self.max_value_info + # self._adjust_parameters(self.absorb_to_layer, input_maxes_abs, alpha) + self.model._smoothquant_optimized = False + return self.model + + self.weight_scale_info, self.absorb_scales_info = self._adjust_parameters( + self.absorb_to_layer, input_maxes_abs, self.alpha + ) + self.model._smoothquant_optimized = True + + if example_inputs is not None: + # Check mathematical equivalency + out_post_sq = model_forward_per_sample(self.model, example_inputs, self.device) + if not self.output_is_equal(out_post_sq[0], out_pre_sq[0]): + logger.warning( + "Mathematical equivelancy of Smoothquant is not preserved. " + "Please kindly report this issue to https://github.com/intel/neural-compressor." + ) + else: + logger.warning(" Could not get example input, equivelancy check is skipped") + + return self.model + + @torch.no_grad() + def revert(self): + """Revert the model weights + :return:""" + for key in self.weight_scale_info: + self._scale_layer_weight(key, 1.0 / self.weight_scale_info[key]) + for key in self.absorb_scales_info: + self._absorb_scales(key, 1.0 / self.absorb_scales_info[key]) + self.weight_scale_info = {} ##clear the data + self.absorb_scales_info = {} + + def output_is_equal(self, out1, out2, atol=1e-03): + try: + if isinstance(out1, tuple): + return all(torch.all(torch.isclose(out1[i], out2[i], atol=atol)) for i in range(len(out1))) + elif isinstance(out1, dict): + return all(torch.all(torch.isclose(out1[k], out2[k], atol=atol)) for k in out1.keys()) + elif isinstance(out1, torch.Tensor): + return torch.all(torch.isclose(out1, out2, atol=atol)) + return False + except: + logger.warning( + "Automatically check failed, Please check equivelancy manually " + "between out_pre_sq and out_post_sq if necessary." + ) + return True + + def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5): + """Cal the adjust scales + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_to_input_maxes = {} + for key in absorb_to_layer.keys(): + layer_name = absorb_to_layer[key][0] + absorb_to_input_maxes[key] = input_maxes[layer_name] + + weight_scales_info = {} + absorb_scales_info = {} + for index, key in enumerate(absorb_to_layer.keys()): + alpha_tmp = alpha[key] if isinstance(alpha, dict) else alpha + + input_max = absorb_to_input_maxes[key] + layer_names = absorb_to_layer[key] + weights = [] + for layer_name in layer_names: + weight = reshape_in_channel_to_last(layer_name, self.model) + weights.append(weight) + scale = cal_scale(input_max, weights, alpha_tmp, group_size=self.group_size) + absorb_scales_info[key] = 1.0 / scale + absorb_scales_info[key][scale == 0] = 0 + layer_names = absorb_to_layer[key] + for layer_name in layer_names: + ##self._scale_layer_weight(layer_name, scale) + weight_scales_info[layer_name] = scale + return absorb_scales_info, weight_scales_info + + def _scale_layer_weight(self, layer_name, scale, alpha=0.5, input_minmax=None): ##input channel + """Scale the layer weights at input channel, depthwise conv output channel + :param layer_name: The layer name + :param scale: The scale to be multiplied + :param alpha: alpha for SQLinearWrapper + :param input_minmax: input_minmax for SQLinearWrapper + :return:""" + layer = get_module(self.model, layer_name) + if self.insert_mul: + from .utils import SQLinearWrapper + + layer = get_module(self.model, layer_name) + if isinstance(layer, SQLinearWrapper): + layer._recover_sq_linear() + set_module(self.model, layer_name, layer.sq_linear) ##recover + else: + new_module = SQLinearWrapper(layer, 1.0 / scale, input_minmax, alpha) + set_module(self.model, layer_name, new_module) + elif self.allow_absorb: + scale = reshape_scale_as_weight(layer, scale) + layer.weight.data = mul_scale(layer.weight, scale) + # layer.weight = torch.nn.Parameter(layer.weight * scale) + layer.weight = torch.nn.Parameter(layer.weight) + return scale + + def _absorb_scales(self, layer_name, scale): ##output channel + """Absorb the scale to the layer at output channel + :param layer_name: The module name + :param scale: The scale to be absorbed + :param alpha_key: The alpha passed to SQLinearWrapper + :return:""" + if self.insert_mul or not self.allow_absorb: + return # absorb is updated in SQLinearWrapper in def _scale_layer_weight + + ##if self.allow absorb + layer = get_module(self.model, layer_name) + if layer.__class__.__name__ == "WrapperLayer": + layer = layer.orig_layer + if ( + isinstance(layer, torch.nn.BatchNorm2d) + or isinstance(layer, torch.nn.GroupNorm) + or isinstance(layer, torch.nn.InstanceNorm2d) + ): + if layer.affine: + layer.weight.data = mul_scale(layer.weight, scale) + layer.bias.data = mul_scale(layer.bias, scale) + else: + layer.affine = True + weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + elif isinstance(layer, torch.nn.LayerNorm): + if layer.elementwise_affine: + layer.weight.data = mul_scale(layer.weight, scale) + layer.bias.data = mul_scale(layer.bias, scale) + else: + layer.elementwise_affine = True + weight = torch.ones(layer.num_features, device=self.device, dtype=self.dtype) * scale + layer.weight = torch.nn.Parameter(torch.ones(weight, requires_grad=False)) + bias = torch.zeros(layer.num_features, device=self.device, dtype=self.dtype) + layer.bias = torch.nn.Parameter(bias, requires_grad=False) + + elif isinstance(layer, torch.nn.Conv2d): + ##the order could not be changed + if hasattr(layer, "bias") and (layer.bias is not None): + # layer.bias *= scale + layer.bias = mul_scale(layer.bias, scale) + scale = scale.view(scale.shape[0], 1, 1, 1) + # layer.weight *= scale + layer.weight.data = mul_scale(layer.weight, scale) + + elif isinstance(layer, torch.nn.Linear): + if hasattr(layer, "bias") and (layer.bias is not None): + # layer.bias *= scale + layer.bias.data = mul_scale(layer.bias, scale) + scale = scale.view(scale.shape[0], 1) + # layer.weight *= scale + layer.weight.data = mul_scale(layer.weight, scale) + + elif layer.__class__.__name__ in ["Qwen2RMSNorm", "Qwen3RMSNorm", "LlamaRMSNorm", "T5LayerNorm"]: + # layer.weight *= scale + layer.weight.data = mul_scale(layer.weight, scale) + + else: + logger.warning_once( + f"found unsupported layer {type(layer)}, try to multiply scale to " + f"weight and bias directly, this may introduce accuracy issue, please have a check " + ) + if hasattr(layer, "weight") and layer.weight is not None: + # layer.weight *= scale + layer.weight.data = mul_scale(layer.weight, scale) + if hasattr(layer, "bias") and layer.bias is not None: + # layer.bias *= scale + layer.bias = mul_scale(layer.bias, scale) + + def _adjust_parameters(self, absorb_to_layer, input_maxes, alpha=0.5): + """Adjust the weights and biases + :param absorb_to_layer: A dict mapping absorb layer to smooth quantized layer + :param input_maxes: The channel-wise input max info for layers + :param alpha: Alpha value to balance the quantization difficulty of activation and weight, a float of a dict + :return:""" + absorb_scales_info, weight_scales_info = self._cal_scales(absorb_to_layer, input_maxes, alpha) + if not absorb_scales_info or not weight_scales_info: + return weight_scales_info, absorb_scales_info + for index, key in enumerate(absorb_to_layer.keys()): + # layer = get_module(self.model, key) + # if 'norm' not in layer.__class__.__name__.lower(): + # continue + if isinstance(alpha, float): + alpha_tmp = alpha + elif isinstance(alpha, dict): + alpha_tmp = alpha[key] + absorb_scale = absorb_scales_info[key] + self._absorb_scales(key, absorb_scale) + layer_names = absorb_to_layer[key] + for layer_name in layer_names: + input_minmax = [self.input_mins[layer_names[0]], self.input_maxes[layer_names[0]]] + self._scale_layer_weight(layer_name, weight_scales_info[layer_name], alpha_tmp, input_minmax) + return weight_scales_info, absorb_scales_info + + def _check_need_calibration(self, alpha, percentile, op_types, scales_per_op, calib_iter): + """ + check need calibration or not + :param alpha: current alpha + :param percentile: current percentile + :param op_types: current op_types + :param scales_per_op: current scales_per_op + :param calib_iter:: current scales_per_op + :return: + """ + need_calib = True + from peft import PeftModel + + is_peft, is_auto = isinstance(self.model, PeftModel), alpha == "auto" + if len(self.input_maxes) == 0: ## the first time + need_calib = True + self.alpha = alpha + self.percentile = percentile + self.op_types = op_types + self.scales_per_op = scales_per_op + self.calib_iter = calib_iter + return False if (is_auto and not is_peft) else need_calib + + if ( + self.percentile == percentile + and self.op_types == op_types + and self.scales_per_op == scales_per_op + and self.calib_iter == calib_iter + ): + if isinstance(alpha, float) or self.alpha == "auto": + need_calib = False + + self.alpha, self.percentile, self.calib_iter = alpha, percentile, calib_iter + self.op_types, self.scales_per_op = op_types, scales_per_op + return need_calib + + def _get_all_layer_names(self, op_types=[torch.nn.Linear]): + """Try the model to find the layers which can be smooth quantized. + + :param op_types: The op types to be smooth quantized + :return: + self_absorb_layer: A dict, absorb layer name (itself): layers to be smooth quantized + """ + self_absorb_layer = {} + op_types = [torch.nn.Linear] # TODO: only support SQLinearWrapper + for name, module in self.model.named_modules(): + if isinstance(module, tuple(op_types)): + self_absorb_layer[name] = [name] + return self_absorb_layer + + def _get_example_input(self): + if self.dataloader is None and self.example_inputs is None: + return None + for idx, input in enumerate(self.dataloader): + self.example_inputs = input + break + return self.example_inputs + + @torch.no_grad() + def _parse_absorb_to_layers(self, op_types, folding): + self_absorb_layers = {} + if self.insert_mul: + self_absorb_layers = self._get_all_layer_names(op_types) # TODO: only support linear now. + # fetch modules with the same input + group_modules = self._trace(skip_unsupported_layers=False) + if group_modules is not None: + # use one input for qkv + for k, v in group_modules.items(): + for i in v: + if i in self_absorb_layers: + self_absorb_layers.pop(i) + self_absorb_layers[v[0]] = v + logger.debug(f"self_absorb_layers:{self_absorb_layers}") + if self.allow_absorb: + self.absorb_to_layer, no_absorb_layers = self._trace() + if self.absorb_to_layer is None and no_absorb_layers is None: + return None + + # remove self.self_absorb_layers if it exists in self.absorb_to_layer + for k, v in self.absorb_to_layer.items(): + for i in v: + if i in self_absorb_layers: + self_absorb_layers.pop(i) + self.absorb_to_layer.update(self_absorb_layers) + + if self.absorb_to_layer is None and no_absorb_layers is None: + logger.warning( + "sorry, could not trace the model, smooth quant is ignored." + "If you are using huggingface model," + "you could set torchscript to True " + ) + return None + + # Check if input_maxes match self.absorb_to_layer + # (due to self._get_all_layer_names use layer tree instead of forward_path) + if not folding and self.need_calibration: + if len(self.input_mins) == 0: ##there are some modules not used in forward + calib = Calibration( + self.model, self.dataloader, self.q_func, self.device, group_size=self.group_size + ) ## + input_mins, input_maxes = calib.calibrate( + 1, op_types + ) ##TODO if using qfunc for calibration, it will calibrate twice + # use qfunc to calibrate, the input min could be used for fixed alpha transformation + self.input_mins = input_mins + self.input_maxes = input_maxes + diff_modules = set(self.absorb_to_layer.keys()).difference(input_mins.keys()) + for d in diff_modules: + del self.absorb_to_layer[d] + return self.absorb_to_layer + + def _trace(self, skip_unsupported_layers=True): + """Try the model to find the layers which can be smooth quantized. + + :param op_types: The op types to be smooth quantized + :return: + absorb_to_layer: A dict, absorb layer name:layers to be smooth quantized + no_absorb_layers: A list saving the layers which could not find the absorb layer + """ + + absorb_to_layer, no_absorb_layers = get_absorb_layers(self.traced_model, skip_unsupported_layers) + if not skip_unsupported_layers: + return absorb_to_layer + if absorb_to_layer is None and no_absorb_layers is None: + logger.warning( + "sorry, could not trace the model, smooth quant is skipped." + "If you are using huggingface model," + "you could set torchscript to True " + "when loading the model or set the return_dict to False" + ) + elif absorb_to_layer == {}: + logger.warning("could not find any layer to be absorbed") + else: + to_absorb_cnt = 0 + for key, item in absorb_to_layer.items(): + to_absorb_cnt += len(item) + logger.info( + f" {to_absorb_cnt} out of {to_absorb_cnt + len(no_absorb_layers)} " + f"layers could be absorbed in smooth quant" + ) + return absorb_to_layer, no_absorb_layers diff --git a/auto_round/smooth_quant/utils.py b/auto_round/smooth_quant/utils.py new file mode 100644 index 000000000..93a5ac600 --- /dev/null +++ b/auto_round/smooth_quant/utils.py @@ -0,0 +1,318 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from collections import UserDict, defaultdict + +import torch +from tqdm import tqdm + +from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad +from auto_round.utils import logger + + +def get_module(model, key): + """Get module from model by key name. + + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + """ + module = model + name_list = key.split(".") + for name in name_list: + if hasattr(module, name): + module = getattr(module, name) + elif hasattr(module, "sq_linear"): # for peft models + module = getattr(module, "sq_linear") + module = getattr(module, name) + elif hasattr(module, "orig_layer"): # for peft models and auto alpha + module = getattr(module, "orig_layer") + module = getattr(module, name) + else: + module = module + return module + + +def set_module(model, key, new_module): + """Set new module into model by key name. + + Args: + model (torch.nn.Module): original model + key (str): module name to be replaced + new_module (torch.nn.Module): new module to be inserted + """ + module = model + name_list = key.split(".") + for name in name_list[:-1]: + if hasattr(module, name): + module = getattr(module, name) + elif hasattr(module, ("sq_linear")): # for peft models that Linears are contained in Linear + module = getattr(module, "sq_linear") + module = getattr(module, name) + elif hasattr(module, ("orig_layer")): # for peft models and auto alpha + module = getattr(module, "orig_layer") + module = getattr(module, name) + else: + module = module + + if hasattr(module, "sq_linear") and name_list[-1] != "sq_linear": # for peft models + module = getattr(module, "sq_linear") + if hasattr(module, "orig_layer") and name_list[-1] != "orig_layer": # for peft models and auto alpha + module = getattr(module, "orig_layer") + setattr(module, name_list[-1], new_module) + + +def mul_scale(tensor, scale, group_size=-1): + ori_shape = tensor.shape + if len(scale.shape) == 2 and scale.shape[1] == 1: + tensor = tensor.reshape(scale.shape[0], -1) + else: + tensor = tensor.reshape(-1, scale.shape[-1]) + + tensor *= scale + return tensor.reshape(ori_shape) + + +def reshape_scale_as_input(layer, scale): + """Reshape the scale for input feature in channel + :param layer: + + :param scale: + :return: + """ + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + if isinstance(layer, torch.nn.Conv2d): + scale = scale.view(1, scale.shape[0], 1, 1) + + elif isinstance(layer, torch.nn.Linear): + scale = scale.view(1, scale.shape[0]) + + return scale + + +def reshape_scale_as_weight(layer, scale): + """Reshape the scale for weight input channel, depthwise output channel + :param layer: torch module + :param scale: orig scale + :return: reshaped scale.""" + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + if isinstance(layer, torch.nn.Conv2d) and layer.groups > 1: ##only depthwise conv could hit here + scale = scale.view(scale.shape[0], 1, 1, 1) ##mount on output channel + + elif isinstance(layer, torch.nn.Conv2d): + scale = scale.view(1, scale.shape[0], 1, 1) + + elif isinstance(layer, torch.nn.Linear): + scale = scale.view(1, scale.shape[0]) + return scale + + +def move_input_to_device(input, device=torch.device("cpu")): + if isinstance(input, dict) or isinstance(input, UserDict): + tmp_input = {} + for k, inp in input.items(): + tmp_input[k] = move_input_to_device(inp, device) + input = tmp_input + elif isinstance(input, list) or isinstance(input, tuple): + is_tuple = isinstance(input, tuple) + tmp_input = [] + for inp in input: + tmp_input.append(move_input_to_device(inp, device)) + input = tuple(tmp_input) if is_tuple else tmp_input + elif isinstance(input, torch.Tensor): + input = input.to(device) # pylint: disable=no-member + return input + + +def forward_wrapper(model, input, device=torch.device("cpu")): + try: + model = model.to(device) + input = move_input_to_device(input, device) + except Exception as e: + logger.warning(e) + logger.warning("Please check the input device if the error raised.") + if isinstance(input, dict) or isinstance(input, UserDict): + output = model(**input) + elif isinstance(input, list) or isinstance(input, tuple): + try: + output = model(*input) + except: + output = model(input) + else: + output = model(input) + return output + + +def model_forward_per_sample(model, sample, device): + try: + output = forward_wrapper(model, sample, device) + return output + + except Exception as e: + output = forward_wrapper(model, sample[0], device) + return output + + +def model_forward(model, dataloader, iters, device): + cnt = 0 + pbar = tqdm(dataloader, total=iters) + pbar.set_description("SmoothQuant Calibrating") + for idx, input in enumerate(pbar): + output = forward_wrapper(model, input, device) + cnt += 1 + if iters != -1 and cnt > iters: + break + pbar.close() + + +def cal_scale(input_max_abs, weights, alpha, weight_max_lb=1e-5, group_size=-1): + weights = torch.cat(weights, dim=0) + weights, _, _ = reshape_pad_tensor_by_group_size(weights, group_size) + weight_max = torch.max(torch.abs(weights), dim=0)[0] + weight_max = torch.clip(weight_max, weight_max_lb) + input_power = torch.pow(input_max_abs, alpha) + # logger.debug(f"{max(input_max_abs)}, {min(input_max_abs)}") + weight_power = torch.pow(weight_max, 1 - alpha) + weight_scale = torch.clip(input_power / weight_power, min=1e-5) + weight_scale[input_power == 0] = 1.0 + return weight_scale + + +def reshape_in_channel_to_last(layer_name, model): + """Move the input channel to the last dim + :param layer_name: Layer name + :return: The reshaped weight.""" + layer = get_module(model, layer_name) + if layer.__class__.__name__ == "WrapperLayer": + layer = layer.orig_layer + + weight = layer.weight ##TODO oc*ic, support transposed conv + if len(weight.shape) == 4: + weight = weight.permute(0, 2, 3, 1) + weight = weight.reshape(-1, weight.shape[-1]) + return weight + + +def enough_memo_store_scale(device, need_space): + if device == "cuda": # pragma: no cover + current_gpu_index = torch.cuda.current_device() + total_memory = torch.cuda.get_device_properties(current_gpu_index).total_memory + used_memory = torch.cuda.memory_allocated(current_gpu_index) + free_space = total_memory - used_memory + else: + import psutil + + free_space = psutil.virtual_memory().free + return free_space >= need_space + + +def quant_dequant(m, num_bits=4, group_size=32, data_type="mx_fp4", sym=True): + from auto_round.data_type.utils import get_quant_func + + # data_type = 'int_asym' + data_type = "mx_fp4" + tensor = m.weight if hasattr(m, "weight") else m + quant_func, data_type = get_quant_func(data_type, num_bits, sym) + # print(quant_func, num_bits) + data_new, scale, zp = quant_func(tensor, bits=num_bits, group_size=group_size, v=0, max_scale=1.0) + return data_new.to(tensor.dtype) + + +class WrapperLayer(torch.nn.Module): + def __init__(self, layer, input_min, input_max, save_q_input=False, group_size=-1): + super(WrapperLayer, self).__init__() + if hasattr(layer, "orig_layer"): + layer = layer.orig_layer + self.add_module("orig_layer", layer) # set orig_layer in get/set_module + self.quant = False + self.q_input = None + self.fp32_output = None + self.input_max = input_max + self.input_min = input_min + self.weight_scale = None + self.input_scale = None + self.absorb_scale = None + self.save_q_input = save_q_input + self.do_blockwise = False + self.group_size = group_size + + def enable_quant(self): + self.quant = True + + def disable_quant(self): + self.quant = False + + def update_scale(self, input_scale, weight_scale, absorb_scale=None): + self.input_scale = input_scale + self.weight_scale = weight_scale + self.absorb_scale = absorb_scale + + ##TODO better tradeoff performance and memory, currently it's too slow + def q_dq_forward(self, x, input_scale, weight_scale, absorb_scale): + layer_copy = copy.deepcopy(self.orig_layer) + if absorb_scale is not None: + ori_shape = layer_copy.weight.shape + layer_copy.weight.data = mul_scale(layer_copy.weight, absorb_scale, group_size=self.group_size) + layer_copy.weight.data = layer_copy.weight.view(ori_shape) + if weight_scale is not None: + ori_shape = layer_copy.weight.shape + # layer_copy.weight *= weight_scale + layer_copy.weight.data = mul_scale(layer_copy.weight, weight_scale, group_size=self.group_size) + layer_copy.weight.data = layer_copy.weight.view(ori_shape) + # q_dq_weight = quant_dequant_w_v1(layer_copy) + q_dq_weight = quant_dequant(layer_copy) + layer_copy.weight.data.copy_(q_dq_weight) + if input_scale is None: + # x = quant_dequant_x_v1(x, self.input_min, self.input_max) + x = quant_dequant(x) + else: + ori_shape = x.shape + # x = input_scale * x + x = mul_scale(x, input_scale) + # x = quant_dequant_x_v1(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME + x = quant_dequant(x) ##FIXME + output = layer_copy(x) + return output + + def q_dq_forward_blockwise(self, x, input_scale): + layer_copy = copy.deepcopy(self.orig_layer) + if input_scale is None: + # x = quant_dequant_x_v1(x, self.input_min, self.input_max) + x = quant_dequant(x) + else: + x, orig_shape, pad_len = reshape_pad_tensor_by_group_size(x, self.group_size) + x = input_scale * x + x = revert_tensor_by_pad(x, orig_shape, pad_len) + # x = quant_dequant_x_v1(x, self.input_min * input_scale, self.input_max * input_scale) ##FIXME + x = quant_dequant(x) ##FIXME + output = layer_copy(x) + return output + + def forward(self, x): + if self.quant: + # self.q_input = x * scale ##save the q_input + if self.save_q_input: + self.q_input = x + if not self.do_blockwise: + output = self.q_dq_forward(x, self.input_scale, self.weight_scale, self.absorb_scale) + else: + output = self.q_dq_forward_blockwise(x, self.input_scale) + + else: + output = self.orig_layer(x) + self.output = output + return output