|
| 1 | +# Copyright (c) 2025 Red Hat AI, vLLM Project and Intel Corporation |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# NOTICE: The design adapted from: |
| 16 | +# https://github.com/vllm-project/compressed-tensors/pull/491 |
| 17 | + |
| 18 | + |
| 19 | +import contextlib |
| 20 | +import inspect |
| 21 | +from functools import partial |
| 22 | +from typing import Callable, Optional |
| 23 | +from weakref import ref |
| 24 | + |
| 25 | +import torch |
| 26 | +from torch import Tensor |
| 27 | +from torch.nn import Module |
| 28 | +from torch.utils.hooks import RemovableHandle |
| 29 | +from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel |
| 30 | +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| 31 | + |
| 32 | +from auto_round.experimental.kv_cache import kvcache_quant_context |
| 33 | +from auto_round.experimental.utils import ( |
| 34 | + is_attention_module, |
| 35 | + per_tensor_fp8_qdq, |
| 36 | + update_parameter_data, |
| 37 | +) |
| 38 | +from auto_round.utils import logger |
| 39 | + |
| 40 | +__all__ = [ |
| 41 | + "QuantizedAttentionImpl", |
| 42 | + "init_hooked_attention", |
| 43 | + "attention_quant_ctx", |
| 44 | +] |
| 45 | + |
| 46 | + |
| 47 | +ATTN_IMPL_ATTR_NAME = "impl" |
| 48 | +HOOKED_ATTENTION_NAME = "ct_hooked_attention" |
| 49 | +QUERY_SCALE_NAME = "q_scale" |
| 50 | +QUERY_MAX_NAME = "q_max" |
| 51 | + |
| 52 | + |
| 53 | +class QuantizedAttentionImpl(torch.nn.Module): |
| 54 | + """ |
| 55 | + QuantizedAttentionImpl module which wraps the functionality of the original |
| 56 | + attention implementation. Unlike the original attention function, this |
| 57 | + implementation is a `torch.nn.Module` which can be hooked to trigger |
| 58 | + transforms and calibration hooks. |
| 59 | +
|
| 60 | + This module works by being registered as a submodule to attention modules via |
| 61 | + `init_hooked_attention`, registering a new attention implementation function |
| 62 | + which calls this module, then setting the model attention implementation to the new |
| 63 | + function. After triggering hooks and quantization, this module calls the original |
| 64 | + attention implementation function. |
| 65 | +
|
| 66 | + :param attn_module: parent attention module |
| 67 | + """ |
| 68 | + |
| 69 | + _original_impl = "sdpa" |
| 70 | + |
| 71 | + def __init__(self, config: PretrainedConfig, attn_module: Module): |
| 72 | + super().__init__() |
| 73 | + self.config = config |
| 74 | + self.attn_module = ref(attn_module) # avoid circular references |
| 75 | + # register query max |
| 76 | + device = next(attn_module.parameters()).device |
| 77 | + initial_max = torch.tensor([float("-inf")], device=device) |
| 78 | + update_parameter_data(attn_module, initial_max, QUERY_MAX_NAME) |
| 79 | + initial_scale = torch.tensor([0.0], device=device) |
| 80 | + update_parameter_data(attn_module, initial_scale, QUERY_SCALE_NAME) |
| 81 | + |
| 82 | + def forward( |
| 83 | + self, |
| 84 | + module: Module, |
| 85 | + query: Tensor, |
| 86 | + key: Tensor, |
| 87 | + value: Tensor, |
| 88 | + *args, |
| 89 | + **kwargs, |
| 90 | + ): |
| 91 | + cur_query_max = query.abs().max() |
| 92 | + query_max = torch.max( |
| 93 | + getattr(module, QUERY_MAX_NAME).data, |
| 94 | + cur_query_max.detach().to(getattr(module, QUERY_MAX_NAME).data.device), |
| 95 | + ) |
| 96 | + update_parameter_data(module, query_max, QUERY_MAX_NAME) |
| 97 | + query, query_scale = per_tensor_fp8_qdq(query, tensor_max=query_max) |
| 98 | + update_parameter_data(module, query_scale.squeeze(0), QUERY_SCALE_NAME) |
| 99 | + # original attention |
| 100 | + return ALL_ATTENTION_FUNCTIONS[self._original_impl]( |
| 101 | + module, |
| 102 | + query, |
| 103 | + key, |
| 104 | + value, |
| 105 | + *args, |
| 106 | + **kwargs, |
| 107 | + ) |
| 108 | + |
| 109 | + |
| 110 | +# ----- initialize ----- # |
| 111 | + |
| 112 | + |
| 113 | +def _ct_hooked_attention(module: Module, *args, **kwargs): |
| 114 | + if hasattr(module, ATTN_IMPL_ATTR_NAME): |
| 115 | + return module.impl(module, *args, **kwargs) |
| 116 | + else: |
| 117 | + return ALL_ATTENTION_FUNCTIONS[_original_impl](module, *args, **kwargs) # pylint: disable=E0601 |
| 118 | + |
| 119 | + |
| 120 | +def init_hooked_attention(module: Module, config): |
| 121 | + """ |
| 122 | + Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances |
| 123 | + attached to attention |
| 124 | +
|
| 125 | + :param model: parent model of attention module |
| 126 | + :param module: attention module to initialize with |
| 127 | + """ |
| 128 | + if not hasattr(module, ATTN_IMPL_ATTR_NAME): |
| 129 | + module.register_module(ATTN_IMPL_ATTR_NAME, QuantizedAttentionImpl(config, module)) |
| 130 | + if config._attn_implementation != HOOKED_ATTENTION_NAME: |
| 131 | + # assumes only one model at a time |
| 132 | + global _original_impl |
| 133 | + _original_impl = config._attn_implementation |
| 134 | + # Add new implementation to AttentionInterface(mapping) |
| 135 | + AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention) |
| 136 | + config._attn_implementation = HOOKED_ATTENTION_NAME |
| 137 | + |
| 138 | + # initialize_hooked_kv_cache(model, module) |
| 139 | + |
| 140 | + |
| 141 | +def prep_attention_module_for_calibration(module: torch.nn.Module, config): |
| 142 | + if is_attention_module(module): |
| 143 | + logger.trace(f"Preparing attention module {module.__class__.__name__} for calibration") |
| 144 | + init_hooked_attention(module, config) |
| 145 | + |
| 146 | + |
| 147 | +def clean_up_hooked_attention(module, model): |
| 148 | + if is_attention_module(module): |
| 149 | + # Cleanup phase: Restore the original attention implementation |
| 150 | + if hasattr(model.config, "_attn_implementation") and hasattr(model, "_original_impl"): |
| 151 | + model.config._attn_implementation = model._original_impl |
| 152 | + del model._original_impl |
| 153 | + |
| 154 | + |
| 155 | +@contextlib.contextmanager |
| 156 | +def attention_quant_ctx(model: PreTrainedModel, static_attention_dtype=torch.float8_e4m3fn): |
| 157 | + try: |
| 158 | + # Setup phase: Initialize hooked attention |
| 159 | + prepare_fn = partial(prep_attention_module_for_calibration, config=model.config) |
| 160 | + model.apply(prepare_fn) |
| 161 | + with kvcache_quant_context(model, static_kv_dtype=static_attention_dtype): |
| 162 | + yield model |
| 163 | + finally: |
| 164 | + clean_fn = partial(clean_up_hooked_attention, model=model) |
| 165 | + model.apply(clean_fn) |
0 commit comments