Commit 5d675fc

add tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 8053b51 commit 5d675fc

File tree

2 files changed: +585 -0 lines changed

tests/observer.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import Tuple
from weakref import ref

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from compressed_tensors.quantization.utils import (
    calculate_qparams,
    generate_gparam,
    strategy_cdiv,
)
from compressed_tensors.utils import getattr_chain


base_name_to_scheme_field = {
    "q": "input_activations",
    "k": "input_activations",
    "v": "input_activations",
    "input": "input_activations",
    "weight": "weights",
    "output": "output_activations",
}


class ObserverBase(torch.nn.Module):
    def __init__(self, module: torch.nn.Module, base_name: str):
        super().__init__()
        self.parent = ref(module)
        self.base_name = base_name

        self.scheme_field = base_name_to_scheme_field[base_name]
        self.args: QuantizationArgs = getattr_chain(
            module, f"quantization_scheme.{self.scheme_field}"
        )

        # used for moving averages and testing
        self.min_vals = None
        self.max_vals = None

    @abstractmethod
    def get_min_max(self, observed: torch.Tensor):
        ...

    def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        observed = flatten_for_quantization(observed, self.base_name, self.args)

        self.min_vals, self.max_vals = self.get_min_max(observed)

        scales, zero_points = calculate_qparams(
            min_vals=self.min_vals,
            max_vals=self.max_vals,
            quantization_args=self.args,
            global_scale=getattr(self.parent(), f"{self.base_name}_global_scale", None),
        )

        return scales, zero_points

    def get_global_scale(self, observed: torch.Tensor):
        observed = observed.reshape((1, 1, -1))  # per tensor reshape

        min_vals, max_vals = self.get_min_max(observed)

        global_scale = generate_gparam(min_vals, max_vals)

        return global_scale


class MockMinMaxObserver(ObserverBase):
    def __init__(self, module: torch.nn.Module, base_name: str):
        super().__init__(module, base_name)

    def get_min_max(self, observed: torch.Tensor):
        min_vals = torch.amin(observed, dim=(0, -1))
        max_vals = torch.amax(observed, dim=(0, -1))

        return min_vals, max_vals
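
As a usage sketch (not part of the commit): an observer is constructed against a
module that already carries a `quantization_scheme` attribute, then called on the
tensor to observe. The `QuantizationScheme` constructor fields below are assumptions
about the compressed-tensors config API; everything else follows the code above.

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
)

linear = torch.nn.Linear(64, 32)
# attach a scheme so getattr_chain(module, "quantization_scheme.weights") resolves
linear.quantization_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, strategy=QuantizationStrategy.TENSOR),
)

observer = MockMinMaxObserver(linear, "weight")
scales, zero_points = observer(linear.weight)  # runs ObserverBase.forward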


class MockMovingMinMaxObserver(ObserverBase):
    def __init__(self, module: torch.nn.Module, base_name: str):
        super().__init__(module, base_name)

        self.averaging_constant = self.args.observer_kwargs.get(
            "averaging_constant", 0.01
        )

    def get_min_max(self, observed: torch.Tensor):
        min_vals = torch.amin(observed, dim=(0, -1))
        max_vals = torch.amax(observed, dim=(0, -1))

        if self.min_vals is not None:
            # FUTURE: consider scaling by num observations (first dim)
            # rather than reducing by first dim
            min_vals = torch.lerp(self.min_vals, min_vals, self.averaging_constant)
            max_vals = torch.lerp(self.max_vals, max_vals, self.averaging_constant)

        return min_vals, max_vals
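
The moving-average update above is a standard exponential moving average:
torch.lerp(prev, new, c) computes prev + c * (new - prev), so with the default
averaging_constant of 0.01 each new batch nudges the running bounds by 1%.
A quick arithmetic check:

import torch

prev_min = torch.tensor([-1.0, -2.0])
batch_min = torch.tensor([-3.0, -1.0])
# prev + 0.01 * (batch - prev)
print(torch.lerp(prev_min, batch_min, 0.01))  # tensor([-1.0200, -1.9900])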


def flatten_for_quantization(
    value: torch.Tensor, base_name: str, args: QuantizationArgs
) -> torch.Tensor:
    if base_name == "weight":
        return flatten_weight_for_quantization(value, args)
    elif base_name in ("input", "output"):
        return flatten_activation_for_quantization(value, args)
    elif base_name in ("q", "k", "v"):
        return flatten_attention_for_quantization(value, args)
    else:
        raise ValueError(f"Unknown quantization base name: {base_name}")


def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
    if args.strategy == QuantizationStrategy.TENSOR:
        # (1, 1, num_weight_elems)
        return value.reshape((1, 1, -1))

    if args.strategy == QuantizationStrategy.TOKEN:
        raise ValueError("Token quantization cannot be applied to weights")

    if args.strategy == QuantizationStrategy.CHANNEL:
        # (1, num_rows, 1, num_cols)
        return value.unsqueeze(-2).unsqueeze(0)

    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
        # (1, num_rows, num_groups, group_size)
        return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0)

    if args.strategy == QuantizationStrategy.BLOCK:
        # (1, num_block_rows, num_block_cols, block_height * block_width)
        block_height, block_width = args.block_structure
        num_rows, num_cols = value.shape
        num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy)
        num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy)
        return (
            value.reshape(
                num_block_rows,
                block_height,
                num_block_cols,
                block_width,
            )
            .transpose(1, 2)
            .flatten(-2, -1)
            .unsqueeze(0)
        )

    if args.strategy == QuantizationStrategy.ATTN_HEAD:
        raise ValueError("Attention head quantization cannot be applied to weights")

    assert False, f"Unknown strategy {args.strategy}"
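
Every weight strategy above lands in a layout where torch.amin/amax over dims
(0, -1) leave exactly one value per quantization group. A shape sketch; the
QuantizationArgs constructor calls are assumptions about its accepted fields:

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

weight = torch.randn(8, 16)

tensor_args = QuantizationArgs(num_bits=8, strategy=QuantizationStrategy.TENSOR)
group_args = QuantizationArgs(
    num_bits=4, strategy=QuantizationStrategy.GROUP, group_size=4
)

print(flatten_weight_for_quantization(weight, tensor_args).shape)  # (1, 1, 128)
print(flatten_weight_for_quantization(weight, group_args).shape)   # (1, 8, 4, 4)
# reducing over dims (0, -1) then gives shape (8, 4): one min/max per row per group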


def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
    if args.strategy == QuantizationStrategy.TENSOR:
        # (batch_size * seq_len, 1, hidden_dim)
        return value.reshape((-1, 1, value.size(-1)))

    if args.strategy == QuantizationStrategy.TOKEN:
        # (batch_size, seq_len, hidden_dim)
        # warning: token quantization uses `compute_dynamic_scales_and_zp`
        return value.flatten(2, -1)

    if args.strategy == QuantizationStrategy.CHANNEL:
        raise ValueError("Channel quantization cannot be applied to activations")

    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
        # (batch_size * seq_len, num_groups, group_size)
        # warning: group activation quantization uses `compute_dynamic_scales_and_zp`
        return value.flatten(0, 1).unflatten(-1, (-1, args.group_size))

    if args.strategy == QuantizationStrategy.BLOCK:
        raise ValueError("Block quantization cannot be applied to activations")

    if args.strategy == QuantizationStrategy.ATTN_HEAD:
        raise ValueError(
            "Attention head quantization cannot be applied to linear activations"
        )

    assert False, f"Unknown strategy {args.strategy}"
194+
def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
195+
if args.strategy == QuantizationStrategy.TENSOR:
196+
# (batch_size, seq_len, num_heads, head_dim)
197+
# (batch_size * seq_len, 1, num_heads * head_dim)
198+
return value.flatten(0, 1).flatten(-2, -1).unsqueeze(-2)
199+
200+
if args.strategy == QuantizationStrategy.TOKEN:
201+
raise ValueError("Token quantization cannot be applied to attention")
202+
203+
if args.strategy == QuantizationStrategy.CHANNEL:
204+
raise ValueError("Channel quantization cannot be applied to attention")
205+
206+
if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
207+
raise ValueError("Group quantization cannot be applied to attention")
208+
209+
if args.strategy == QuantizationStrategy.BLOCK:
210+
raise ValueError("Block quantization cannot be applied to attention")
211+
212+
if args.strategy == QuantizationStrategy.ATTN_HEAD:
213+
# (batch_size * seq_len, num_heads, 1, head_dim)
214+
return value.flatten(0, 1).unsqueeze(-2)
215+
216+
assert False, f"Unknown strategy {args.strategy}"
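
For attention, the ATTN_HEAD strategy keeps heads in their own dim so the
observer produces one min/max pair per head. A final shape sketch; whether
QuantizationArgs accepts ATTN_HEAD directly in its constructor is an assumption:

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

qkv = torch.randn(2, 5, 4, 16)  # (batch_size, seq_len, num_heads, head_dim)

head_args = QuantizationArgs(num_bits=8, strategy=QuantizationStrategy.ATTN_HEAD)
print(flatten_attention_for_quantization(qkv, head_args).shape)  # (10, 4, 1, 16)
# amin/amax over dims (0, -1) leave shape (4, 1): one min/max per attention head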
