
Commit 1d4a39f

quantization lifecycle - disable forward pass override + helper for weight quant param updates (#111)

1 parent 0c2d88b · commit 1d4a39f

4 files changed: +116 -0 lines changed

src/compressed_tensors/quantization/lifecycle/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -21,3 +21,4 @@
 from .initialize import *
 from .compressed import *
 from .apply import *
+from .helpers import *

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 5 additions & 0 deletions

@@ -245,6 +245,11 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):

     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
+        if not getattr(module, "quantization_enabled", True):
+            # quantization is disabled on forward passes, return baseline
+            # forward call
+            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
         input_ = args[0]

         if scheme.input_activations is not None:
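The early return uses forward_func_orig.__get__(module, module.__class__) to rebind the captured, unbound original forward to the instance before calling it. Below is a minimal, self-contained sketch of this enable/disable pattern; wrap_forward_with_flag and the "+ 1.0" stand-in for the quantized computation are illustrative only, not part of compressed-tensors.

from functools import wraps

import torch
from torch.nn import Linear, Module


def wrap_forward_with_flag(module: Module):
    # grab the unbound original forward so the wrapper can fall back to it
    forward_func_orig = module.forward.__func__

    @wraps(forward_func_orig)
    def wrapped_forward(self, *args, **kwargs):
        if not getattr(module, "quantization_enabled", True):
            # rebind the original forward to this instance and call it directly
            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
        # stand-in for the quantized computation the real wrapper performs
        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs) + 1.0

    # bind the wrapper as this instance's forward, shadowing the class method
    module.forward = wrapped_forward.__get__(module, module.__class__)


layer = Linear(4, 4)
wrap_forward_with_flag(layer)
x = torch.randn(1, 4)

layer.quantization_enabled = False
baseline = layer(x)                         # falls back to the original forward
layer.quantization_enabled = True
assert not torch.equal(layer(x), baseline)  # wrapper path now applies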
src/compressed_tensors/quantization/lifecycle/helpers.py (new file)

Lines changed: 53 additions & 0 deletions
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Miscellaneous helpers for the quantization lifecycle
"""


from torch.nn import Module


__all__ = [
    "update_layer_weight_quant_params",
    "enable_quantization",
    "disable_quantization",
]


def update_layer_weight_quant_params(layer: Module):
    weight = getattr(layer, "weight", None)
    scale = getattr(layer, "weight_scale", None)
    zero_point = getattr(layer, "weight_zero_point", None)
    observer = getattr(layer, "weight_observer", None)

    if weight is None or observer is None or scale is None or zero_point is None:
        # scale, zp, or observer not calibratable or weight not available
        return

    updated_scale, updated_zero_point = observer(weight)

    # update scale and zero point
    device = next(layer.parameters()).device
    scale.data = updated_scale.to(device)
    zero_point.data = updated_zero_point.to(device)


def enable_quantization(module: Module):
    module.quantization_enabled = True


def disable_quantization(module: Module):
    module.quantization_enabled = False
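Note that update_layer_weight_quant_params silently no-ops unless the layer already carries weight, weight_scale, weight_zero_point, and weight_observer attributes, which the library attaches during calibration initialization. A minimal usage sketch follows; DummyObserver is a hypothetical stand-in for a real weight observer, not the library's implementation.

import torch
from torch.nn import Linear, Module, Parameter

from compressed_tensors.quantization import update_layer_weight_quant_params


class DummyObserver(Module):
    # toy min/max-style observer: returns one scale/zero-point per tensor
    def forward(self, weight):
        scale = weight.abs().max() / 127.0
        zero_point = torch.zeros_like(scale, dtype=torch.int8)
        return scale, zero_point


layer = Linear(8, 8)
layer.weight_scale = Parameter(torch.tensor(1.0), requires_grad=False)
layer.weight_zero_point = Parameter(
    torch.tensor(0, dtype=torch.int8), requires_grad=False
)
layer.weight_observer = DummyObserver()

# after the weight changes (e.g. a training step), refresh the quant params
layer.weight.data.mul_(2.0)
update_layer_weight_quant_params(layer)
print(layer.weight_scale)  # recomputed from the new weight values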
New test file

Lines changed: 57 additions & 0 deletions
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from copy import deepcopy

import torch
from compressed_tensors.quantization import (
    QuantizationConfig,
    apply_quantization_config,
    disable_quantization,
    enable_quantization,
)
from torch.nn import Linear


def test_quantization_enabled_disabled():
    inp = torch.randn(16)
    model = Linear(16, 16)
    quantized_model = deepcopy(model)
    apply_quantization_config(
        model=quantized_model,
        config=QuantizationConfig(
            config_groups=dict(W4A16=["Linear"]),
            quantization_status="calibration",
        ),
    )

    # run one calibration pass
    quantized_model(inp)

    model_output = model(inp)
    quantized_model_output = quantized_model(inp)

    # quantized and non quantized outputs should be different
    assert not torch.all(model_output == quantized_model_output)

    # disable quantization
    quantized_model.apply(disable_quantization)
    # check that quantized model now matches model output
    assert torch.all(model_output == quantized_model(inp))

    # re-enable quantization
    quantized_model.apply(enable_quantization)
    # check that quantized model matches original quantized output
    assert torch.all(quantized_model_output == quantized_model(inp))
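Because Module.apply visits every submodule recursively, disable_quantization and enable_quantization toggle the flag across the whole model in a single call, as the test above exploits. A hypothetical convenience wrapper built on the two helpers (not part of this commit or the library) could look like:

from contextlib import contextmanager

from compressed_tensors.quantization import disable_quantization, enable_quantization
from torch.nn import Module


@contextmanager
def quantization_disabled(model: Module):
    model.apply(disable_quantization)     # Module.apply recurses over all submodules
    try:
        yield model
    finally:
        model.apply(enable_quantization)  # restore the quantized forward path


# usage:
# with quantization_disabled(quantized_model):
#     full_precision_output = quantized_model(inp)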
