
Commit b2abe72

[Observer Restructure]: Remove MemoryLess Observer; use helper function for dynamic quantization (#187)
* remove memoryless observer; use helper function for dynamic quantization
* update init
* clean-up
* update test case
* fix arg
* validation + update name
* update preset schemes; swap condition check
1 parent b876a60 commit b2abe72

File tree: 8 files changed (+85, -74 lines)

src/compressed_tensors/quantization/lifecycle/forward.py
Lines changed: 6 additions & 4 deletions

@@ -18,7 +18,10 @@
 
 import torch
 from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-from compressed_tensors.quantization.observers.helpers import calculate_range
+from compressed_tensors.quantization.observers.helpers import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -376,9 +379,8 @@ def maybe_calibrate_or_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
 
     if args.dynamic:
-        # dynamic quantization - get scale and zero point directly from observer
-        observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value, g_idx=g_idx)
+        # dynamic quantization - no need to invoke observer
+        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
    else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
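For illustration, here is a minimal sketch (not part of this commit; the tensor shape and argument values are made up) of the new dynamic path: quantization parameters are derived from the live activation tensor and fed straight into fake_quantize, with no observer submodule or registered scale/zero-point involved.

import torch

from compressed_tensors.quantization.lifecycle.forward import fake_quantize
from compressed_tensors.quantization.observers.helpers import compute_dynamic_scales_and_zp
from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationStrategy

# per-token dynamic int8 activation quantization, as in the preset schemes below
args = QuantizationArgs(
    num_bits=8,
    symmetric=True,
    strategy=QuantizationStrategy.TOKEN,
    dynamic=True,
    observer=None,
)

x = torch.randn(4, 128, 512)  # hypothetical (batch, seq_len, hidden) activations

# qparams are recomputed from the tensor on every call
scale, zero_point = compute_dynamic_scales_and_zp(value=x, args=args)
x_fq = fake_quantize(x, scale, zero_point, args)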

src/compressed_tensors/quantization/lifecycle/initialize.py
Lines changed: 7 additions & 6 deletions

@@ -153,12 +153,16 @@ def _initialize_scale_zero_point_observer(
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
 ):
+
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
+    # no need to register an observer for dynamic quantization
+    if observer:
+        module.register_module(f"{base_name}_observer", observer)
 
+    # no need to register a scale and zero point for a dynamic quantization
     if quantization_args.dynamic:
-        return  # no need to register a scale and zero point for a dynamic observer
+        return
 
     device = next(module.parameters()).device
     if is_module_offloaded(module):
@@ -173,10 +177,7 @@ def _initialize_scale_zero_point_observer(
         expected_shape = (weight_shape[0], 1)
     elif quantization_args.strategy == QuantizationStrategy.GROUP:
         num_groups = weight_shape[1] // quantization_args.group_size
-        expected_shape = (
-            weight_shape[0],
-            max(num_groups, 1)
-        )
+        expected_shape = (weight_shape[0], max(num_groups, 1))
 
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
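As a hedged illustration of the effect (not part of the commit, and assuming the public lifecycle helpers initialize_module_for_quantization and QuantizationScheme behave as at this revision), initializing a layer whose input activations are dynamically quantized now leaves it with no input observer and no persistent input qparams:

import torch

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    initialize_module_for_quantization,
)

layer = torch.nn.Linear(64, 64)
scheme = QuantizationScheme(
    targets=["Linear"],
    input_activations=QuantizationArgs(
        num_bits=8,
        strategy=QuantizationStrategy.TOKEN,
        dynamic=True,
        observer=None,
    ),
)
initialize_module_for_quantization(layer, scheme)

# dynamic input activations: no observer submodule, no registered qparams
assert not hasattr(layer, "input_observer")
assert not hasattr(layer, "input_scale")
assert not hasattr(layer, "input_zero_point")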

src/compressed_tensors/quantization/observers/__init__.py
Lines changed: 0 additions & 1 deletion

@@ -17,6 +17,5 @@
 
 from .helpers import *
 from .base import *
-from .memoryless import *
 from .min_max import *
 from .mse import *

src/compressed_tensors/quantization/observers/helpers.py
Lines changed: 40 additions & 2 deletions

@@ -13,18 +13,56 @@
 # limitations under the License.
 
 from collections import Counter
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.quant_args import (
     FP8_DTYPE,
     QuantizationArgs,
+    QuantizationStrategy,
     QuantizationType,
 )
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+__all__ = [
+    "calculate_qparams",
+    "get_observer_token_count",
+    "calculate_range",
+    "compute_dynamic_scales_and_zp",
+]
+
+
+def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+    """
+    Returns the computed scales and zero points for dynamic activation
+    quantization.
+
+    :param value: tensor to calculate quantization parameters for
+    :param args: quantization args
+    :param reduce_dims: optional tuple of dimensions to reduce along,
+        returned scale and zero point will be shaped (1,) along the
+        reduced dimensions
+    :return: tuple of scale and zero point derived from the observed tensor
+    """
+    if args.strategy == QuantizationStrategy.TOKEN:
+        dim = {1, 2}
+        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
+    elif args.strategy == QuantizationStrategy.TENSOR:
+        reduce_dims = None
+    else:
+        raise ValueError(
+            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
+            "must be used for dynamic quantization",
+        )
+
+    if not reduce_dims:
+        min_val, max_val = torch.aminmax(value)
+    else:
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+
+    return calculate_qparams(min_val, max_val, args)
 
 
 def get_observer_token_count(module: torch.nn.Module) -> Counter:
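A short usage sketch of the new helper (not part of the diff; the tensor shape is made up). TENSOR reduces every dimension to a single scale/zero-point pair, while TOKEN keeps dims 1 and 2 and reduces the rest; any other strategy raises a ValueError.

import torch

from compressed_tensors.quantization.observers.helpers import compute_dynamic_scales_and_zp
from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationStrategy

activations = torch.randn(2, 16, 64)  # hypothetical (batch, tokens, hidden) input

# per-tensor: one scale / zero point for the whole tensor
tensor_args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR, dynamic=True, observer=None)
scale, zp = compute_dynamic_scales_and_zp(value=activations, args=tensor_args)

# per-token: dims 1 and 2 are kept, all remaining dims are reduced
token_args = QuantizationArgs(strategy=QuantizationStrategy.TOKEN, dynamic=True, observer=None)
scale, zp = compute_dynamic_scales_and_zp(value=activations, args=token_args)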

src/compressed_tensors/quantization/observers/memoryless.py

Lines changed: 0 additions & 56 deletions
This file was deleted.

src/compressed_tensors/quantization/quant_args.py
Lines changed: 28 additions & 4 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from enum import Enum
 from typing import Any, Dict, Optional, Union
 
@@ -94,7 +95,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[str] = None
     dynamic: bool = False
     actorder: Union[ActivationOrdering, bool, None] = None
-    observer: str = Field(
+    observer: Optional[str] = Field(
         default="minmax",
         description=(
             "The class to use to compute the quantization param - "
@@ -115,10 +116,10 @@ def get_observer(self):
         """
         from compressed_tensors.quantization.observers.base import Observer
 
+        # No observer required for the dynamic case
         if self.dynamic:
-            # override defualt observer for dynamic, you never want minmax which
-            # keeps state across samples for dynamic
-            self.observer = "memoryless"
+            self.observer = None
+            return self.observer
 
         return Observer.load_from_registry(self.observer, quantization_args=self)
 
@@ -171,6 +172,8 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
         strategy = model.strategy
         group_size = model.group_size
         actorder = model.actorder
+        dynamic = model.dynamic
+        observer = model.observer
 
         # infer strategy
         if strategy is None:
@@ -207,6 +210,27 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
                 "activation ordering"
             )
 
+        if dynamic:
+            if strategy not in (
+                QuantizationStrategy.TOKEN,
+                QuantizationStrategy.TENSOR,
+            ):
+                raise ValueError(
+                    f"One of {QuantizationStrategy.TOKEN} or "
+                    f"{QuantizationStrategy.TENSOR} must be used for dynamic ",
+                    "quantization",
+                )
+            if observer is not None:
+                warnings.warn(
+                    "No observer is used for dynamic quantization, setting to None"
+                )
+                model.observer = None
+
+        # if we have not set an observer and we
+        # are running static quantization, use minmax
+        if not observer and not dynamic:
+            model.observer = "minmax"
+
         # write back modified values
         model.strategy = strategy
         return model
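A hedged sketch of the new validation behavior (not part of the diff): with dynamic=True an explicitly configured observer is cleared with a warning, dynamic quantization is restricted to the TOKEN and TENSOR strategies, and static args with no observer fall back to "minmax".

from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationStrategy

# dynamic + explicit observer: a warning is emitted and the observer is cleared
args = QuantizationArgs(strategy=QuantizationStrategy.TOKEN, dynamic=True, observer="minmax")
assert args.observer is None

# static args with observer=None fall back to the "minmax" default
args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR, observer=None)
assert args.observer == "minmax"

# dynamic quantization with any other strategy (e.g. GROUP) fails validation,
# since the validator raises a ValueError for that combination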

src/compressed_tensors/quantization/quant_scheme.py
Lines changed: 3 additions & 0 deletions

@@ -122,6 +122,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -164,6 +165,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -200,6 +202,7 @@ def is_preset_scheme(name: str) -> bool:
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
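For example (a hedged check, not part of the commit; it assumes the W8A8 preset is one of the three dynamic-activation presets touched above and that preset_name_to_scheme keeps its current signature), the preset lookup now yields input-activation args with no observer attached:

from compressed_tensors.quantization.quant_scheme import preset_name_to_scheme

scheme = preset_name_to_scheme("W8A8", targets=["Linear"])
assert scheme.input_activations.dynamic is True
assert scheme.input_activations.observer is None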

tests/test_quantization/lifecycle/test_dynamic_lifecycle.py
Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def _test_layer_dynamic_quantization_status(
     # check inputs always have an observer if quantized but never scale/zp
     assert not hasattr(module, "input_scale")
     assert not hasattr(module, "input_zero_point")
-    assert hasattr(module, "input_observer") == inputs
+    assert not hasattr(module, "input_observer")
 
     # check weights always have scale/zp and observer only if not frozen
     assert hasattr(module, "weight_scale") == weights
