Skip to content

Commit 2fabb8c

Browse files
committed
simplify
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 4ad5c49 commit 2fabb8c

File tree

2 files changed

+13
-60
lines changed

2 files changed

+13
-60
lines changed

tests/observer.py renamed to tests/mock_observer.py

Lines changed: 8 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from abc import abstractmethod
1615
from typing import Tuple
1716
from weakref import ref
1817

@@ -23,37 +22,24 @@
2322
generate_gparam,
2423
strategy_cdiv,
2524
)
26-
from compressed_tensors.utils import getattr_chain
2725

2826

29-
base_name_to_scheme_field = {
30-
"q": "input_activations",
31-
"k": "input_activations",
32-
"v": "input_activations",
33-
"input": "input_activations",
34-
"weight": "weights",
35-
"output": "output_activations",
36-
}
37-
38-
39-
class ObserverBase(torch.nn.Module):
40-
def __init__(self, module: torch.nn.Module, base_name: str):
27+
class MockMinMaxObserver(torch.nn.Module):
28+
def __init__(self, base_name: str, args: QuantizationArgs, module: torch.nn.Module):
4129
super().__init__()
4230
self.parent = ref(module)
4331
self.base_name = base_name
32+
self.args = args
4433

45-
self.scheme_field = base_name_to_scheme_field[base_name]
46-
self.args: QuantizationArgs = getattr_chain(
47-
module, f"quantization_scheme.{self.scheme_field}"
48-
)
49-
50-
# used for moving averages and testing
34+
# used for testing
5135
self.min_vals = None
5236
self.max_vals = None
5337

54-
@abstractmethod
5538
def get_min_max(self, observed: torch.Tensor):
56-
...
39+
min_vals = torch.amin(observed, dim=(0, -1))
40+
max_vals = torch.amax(observed, dim=(0, -1))
41+
42+
return min_vals, max_vals
5743

5844
def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
5945
observed = flatten_for_quantization(observed, self.base_name, self.args)
@@ -71,46 +57,12 @@ def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
7157

7258
def get_global_scale(self, observed: torch.Tensor):
7359
observed = observed.reshape((1, 1, -1)) # per tensor reshape
74-
7560
min_vals, max_vals = self.get_min_max(observed)
76-
7761
global_scale = generate_gparam(min_vals, max_vals)
7862

7963
return global_scale
8064

8165

82-
class MockMinMaxObserver(ObserverBase):
83-
def __init__(self, module: torch.nn.Module, base_name: str):
84-
super().__init__(module, base_name)
85-
86-
def get_min_max(self, observed: torch.Tensor):
87-
min_vals = torch.amin(observed, dim=(0, -1))
88-
max_vals = torch.amax(observed, dim=(0, -1))
89-
90-
return min_vals, max_vals
91-
92-
93-
class MockMovingMinMaxObserver(ObserverBase):
94-
def __init__(self, module: torch.nn.Module, base_name: str):
95-
super().__init__(module, base_name)
96-
97-
self.averaging_constant = self.args.observer_kwargs.get(
98-
"averaging_constant", 0.01
99-
)
100-
101-
def get_min_max(self, observed: torch.Tensor):
102-
min_vals = torch.amin(observed, dim=(0, -1))
103-
max_vals = torch.amax(observed, dim=(0, -1))
104-
105-
if self.min_vals is not None:
106-
# FUTURE: consider scaling by num observations (first dim)
107-
# rather than reducing by first dim
108-
min_vals = torch.lerp(self.min_vals, min_vals, self.averaging_constant)
109-
max_vals = torch.lerp(self.max_vals, max_vals, self.averaging_constant)
110-
111-
return min_vals, max_vals
112-
113-
11466
def flatten_for_quantization(
11567
value: torch.Tensor, base_name: str, args: QuantizationArgs
11668
) -> torch.Tensor:

tests/test_quantization/lifecycle/test_static_lifecycle.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
)
2323
from compressed_tensors.quantization.quant_args import QuantizationArgs
2424
from compressed_tensors.quantization.quant_config import QuantizationStatus
25-
from tests.observer import MockMinMaxObserver
25+
from tests.mock_observer import MockMinMaxObserver
2626

2727

2828
@pytest.mark.parametrize(
@@ -151,7 +151,7 @@ def test_static_weight_quantization(
151151
scheme = QuantizationScheme(targets=[], weights=args)
152152
initialize_module_for_quantization(linear, scheme)
153153
assert getattr(linear, "quantization_scheme") is scheme
154-
linear.weight_observer = MockMinMaxObserver(linear, base_name="weight")
154+
linear.weight_observer = MockMinMaxObserver("weight", args, linear)
155155

156156
# calibrate_global_scale
157157
if hasattr(linear, "weight_global_scale"):
@@ -242,7 +242,7 @@ def test_static_activation_quantization(
242242
scheme = QuantizationScheme(targets=[], input_activations=args)
243243
initialize_module_for_quantization(linear, scheme)
244244
assert getattr(linear, "quantization_scheme") is scheme
245-
linear.input_observer = MockMinMaxObserver(linear, base_name="input")
245+
linear.input_observer = MockMinMaxObserver("input", args, linear)
246246

247247
# calibrate quantization parameters
248248
def calibrate_input_hook(_, args):
@@ -275,6 +275,7 @@ class MockAttention(torch.nn.Module):
275275
pass
276276

277277

278+
@pytest.mark.filterwarnings("ignore::UserWarning")
278279
@pytest.mark.parametrize(
279280
"args,exp_min_val,exp_max_val,exp_quant,exp_loss",
280281
[
@@ -328,7 +329,7 @@ def test_static_attention_quantization(
328329
)
329330
attention.quantization_scheme = scheme
330331
attention.quantization_status = QuantizationStatus.INITIALIZED
331-
attention.k_observer = MockMinMaxObserver(attention, base_name="k")
332+
attention.k_observer = MockMinMaxObserver("k", args, attention)
332333

333334
# calibrate quantization parameters
334335
if scheme.input_activations.dynamic is False:

0 commit comments

Comments
 (0)