|
| 1 | +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, |
| 10 | +# software distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from abc import abstractmethod |
| 16 | +from typing import Tuple |
| 17 | +from weakref import ref |
| 18 | + |
| 19 | +import torch |
| 20 | +from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy |
| 21 | +from compressed_tensors.quantization.utils import ( |
| 22 | + calculate_qparams, |
| 23 | + generate_gparam, |
| 24 | + strategy_cdiv, |
| 25 | +) |
| 26 | +from compressed_tensors.utils import getattr_chain |
| 27 | + |
| 28 | + |
| 29 | +base_name_to_scheme_field = { |
| 30 | + "q": "input_activations", |
| 31 | + "k": "input_activations", |
| 32 | + "v": "input_activations", |
| 33 | + "input": "input_activations", |
| 34 | + "weight": "weights", |
| 35 | + "output": "output_activations", |
| 36 | +} |
| 37 | + |
| 38 | + |
| 39 | +class ObserverBase(torch.nn.Module): |
| 40 | + def __init__(self, module: torch.nn.Module, base_name: str): |
| 41 | + super().__init__() |
| 42 | + self.parent = ref(module) |
| 43 | + self.base_name = base_name |
| 44 | + |
| 45 | + self.scheme_field = base_name_to_scheme_field[base_name] |
| 46 | + self.args: QuantizationArgs = getattr_chain( |
| 47 | + module, f"quantization_scheme.{self.scheme_field}" |
| 48 | + ) |
| 49 | + |
| 50 | + # used for moving averages and testing |
| 51 | + self.min_vals = None |
| 52 | + self.max_vals = None |
| 53 | + |
| 54 | + @abstractmethod |
| 55 | + def get_min_max(self, observed: torch.Tensor): |
| 56 | + ... |
| 57 | + |
| 58 | + def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| 59 | + observed = flatten_for_quantization(observed, self.base_name, self.args) |
| 60 | + |
| 61 | + self.min_vals, self.max_vals = self.get_min_max(observed) |
| 62 | + |
| 63 | + scales, zero_points = calculate_qparams( |
| 64 | + min_vals=self.min_vals, |
| 65 | + max_vals=self.max_vals, |
| 66 | + quantization_args=self.args, |
| 67 | + global_scale=getattr(self.parent(), f"{self.base_name}_global_scale", None), |
| 68 | + ) |
| 69 | + |
| 70 | + return scales, zero_points |
| 71 | + |
| 72 | + def get_global_scale(self, observed: torch.Tensor): |
| 73 | + observed = observed.reshape((1, 1, -1)) # per tensor reshape |
| 74 | + |
| 75 | + min_vals, max_vals = self.get_min_max(observed) |
| 76 | + |
| 77 | + global_scale = generate_gparam(min_vals, max_vals) |
| 78 | + |
| 79 | + return global_scale |
| 80 | + |
| 81 | + |
| 82 | +class MockMinMaxObserver(ObserverBase): |
| 83 | + def __init__(self, module: torch.nn.Module, base_name: str): |
| 84 | + super().__init__(module, base_name) |
| 85 | + |
| 86 | + def get_min_max(self, observed: torch.Tensor): |
| 87 | + min_vals = torch.amin(observed, dim=(0, -1)) |
| 88 | + max_vals = torch.amax(observed, dim=(0, -1)) |
| 89 | + |
| 90 | + return min_vals, max_vals |
| 91 | + |
| 92 | + |
| 93 | +class MockMovingMinMaxObserver(ObserverBase): |
| 94 | + def __init__(self, module: torch.nn.Module, base_name: str): |
| 95 | + super().__init__(module, base_name) |
| 96 | + |
| 97 | + self.averaging_constant = self.args.observer_kwargs.get( |
| 98 | + "averaging_constant", 0.01 |
| 99 | + ) |
| 100 | + |
| 101 | + def get_min_max(self, observed: torch.Tensor): |
| 102 | + min_vals = torch.amin(observed, dim=(0, -1)) |
| 103 | + max_vals = torch.amax(observed, dim=(0, -1)) |
| 104 | + |
| 105 | + if self.min_vals is not None: |
| 106 | + # FUTURE: consider scaling by num observations (first dim) |
| 107 | + # rather than reducing by first dim |
| 108 | + min_vals = torch.lerp(self.min_vals, min_vals, self.averaging_constant) |
| 109 | + max_vals = torch.lerp(self.max_vals, max_vals, self.averaging_constant) |
| 110 | + |
| 111 | + return min_vals, max_vals |
| 112 | + |
| 113 | + |
| 114 | +def flatten_for_quantization( |
| 115 | + value: torch.Tensor, base_name: str, args: QuantizationArgs |
| 116 | +) -> torch.Tensor: |
| 117 | + if base_name == "weight": |
| 118 | + return flatten_weight_for_quantization(value, args) |
| 119 | + elif base_name in ("input", "output"): |
| 120 | + return flatten_activation_for_quantization(value, args) |
| 121 | + elif base_name in ("q", "k", "v"): |
| 122 | + return flatten_attention_for_quantization(value, args) |
| 123 | + else: |
| 124 | + raise ValueError(f"Unknown quantization base name: {base_name}") |
| 125 | + |
| 126 | + |
| 127 | +def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs): |
| 128 | + if args.strategy == QuantizationStrategy.TENSOR: |
| 129 | + # (1, 1, num_weight_elems) |
| 130 | + return value.reshape((1, 1, -1)) |
| 131 | + |
| 132 | + if args.strategy == QuantizationStrategy.TOKEN: |
| 133 | + raise ValueError("Token quantization cannot be applied to weights") |
| 134 | + |
| 135 | + if args.strategy == QuantizationStrategy.CHANNEL: |
| 136 | + # (1, num_rows, 1, num_cols) |
| 137 | + return value.unsqueeze(-2).unsqueeze(0) |
| 138 | + |
| 139 | + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): |
| 140 | + # (1, num_rows, num_groups, group_size) |
| 141 | + return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0) |
| 142 | + |
| 143 | + if args.strategy == QuantizationStrategy.BLOCK: |
| 144 | + # (1, num_block_rows, num_block_cols, block_width * block_height) |
| 145 | + block_height, block_width = args.block_structure |
| 146 | + num_rows, num_cols = value.shape |
| 147 | + num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy) |
| 148 | + num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy) |
| 149 | + return ( |
| 150 | + value.reshape( |
| 151 | + num_block_rows, |
| 152 | + block_height, |
| 153 | + num_block_cols, |
| 154 | + block_width, |
| 155 | + ) |
| 156 | + .transpose(1, 2) |
| 157 | + .flatten(-2, -1) |
| 158 | + .unsqueeze(0) |
| 159 | + ) |
| 160 | + |
| 161 | + if args.strategy == QuantizationStrategy.ATTN_HEAD: |
| 162 | + raise ValueError("attention head quantization cannot be applied to weights") |
| 163 | + |
| 164 | + assert False, f"Unknown strategy {args.strategy}" |
| 165 | + |
| 166 | + |
| 167 | +def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs): |
| 168 | + if args.strategy == QuantizationStrategy.TENSOR: |
| 169 | + # (batch_size * seq_len, 1, hidden_dim) |
| 170 | + return value.reshape((-1, 1, value.size(-1))) |
| 171 | + |
| 172 | + if args.strategy == QuantizationStrategy.TOKEN: |
| 173 | + # (batch_size, seq_len, hidden_dim) |
| 174 | + # warning: token quantization uses `compute_dynamic_scales_and_zp` |
| 175 | + return value.flatten(2, -1) |
| 176 | + |
| 177 | + if args.strategy == QuantizationStrategy.CHANNEL: |
| 178 | + raise ValueError("Channel quantization cannot be applied to activations") |
| 179 | + |
| 180 | + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): |
| 181 | + # (batch_size * seq_len, num_groups, group_size) |
| 182 | + # warning: group activation quantization uses compute_dynamic_scales_and_zp |
| 183 | + return value.flatten(0, 1).unflatten(-1, (-1, args.group_size)) |
| 184 | + |
| 185 | + if args.strategy == QuantizationStrategy.BLOCK: |
| 186 | + raise ValueError("Block quantization cannot be applied to activations") |
| 187 | + |
| 188 | + if args.strategy == QuantizationStrategy.ATTN_HEAD: |
| 189 | + raise ValueError("attention head quantization cannot be applied to linear acts") |
| 190 | + |
| 191 | + assert False, f"Unknown strategy {args.strategy}" |
| 192 | + |
| 193 | + |
| 194 | +def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs): |
| 195 | + if args.strategy == QuantizationStrategy.TENSOR: |
| 196 | + # (batch_size, seq_len, num_heads, head_dim) |
| 197 | + # (batch_size * seq_len, 1, num_heads * head_dim) |
| 198 | + return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2) |
| 199 | + |
| 200 | + if args.strategy == QuantizationStrategy.TOKEN: |
| 201 | + raise ValueError("Token quantization cannot be applied to attention") |
| 202 | + |
| 203 | + if args.strategy == QuantizationStrategy.CHANNEL: |
| 204 | + raise ValueError("Channel quantization cannot be applied to attention") |
| 205 | + |
| 206 | + if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): |
| 207 | + raise ValueError("Group quantization cannot be applied to attention") |
| 208 | + |
| 209 | + if args.strategy == QuantizationStrategy.BLOCK: |
| 210 | + raise ValueError("Block quantization cannot be applied to attention") |
| 211 | + |
| 212 | + if args.strategy == QuantizationStrategy.ATTN_HEAD: |
| 213 | + # (batch_size * seq_len, num_heads, 1, head_dim) |
| 214 | + return value.flatten(0, 1).unsqueeze(-2) |
| 215 | + |
| 216 | + assert False, f"Unknown strategy {args.strategy}" |
0 commit comments