
Commit b5dc1e9

Authored by Kyle Sayers <[email protected]>
[Quantization] Refactor initialize for activation shape inference (#476)
* refactor
* reduce diff
* reduce diff
* initialize_qparams
* simplify activation shape
* increase num of required observed dims
* remove attention head

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 2dd1b62 commit b5dc1e9
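
Note (not part of the commit): the refactor sizes both weight and activation qparams from the module's weight, with each base name observing only the relevant weight dimensions. A minimal sketch with made-up Linear dimensions:

import math
import torch

# hypothetical Linear weight: (out_features, in_features)
weight = torch.empty(11008, 4096, dtype=torch.bfloat16)

# observed shapes passed to initialize_qparams by initialize_module_for_quantization
input_observed = weight.shape[-1:]   # torch.Size([4096]), the activation feature dim
weight_observed = weight.shape       # torch.Size([11008, 4096])
output_observed = weight.shape[:-1]  # torch.Size([11008])

# e.g. GROUP strategy with group_size=128: scale shape is (*observed[:-1], ceil(last_dim / 128))
num_groups = math.ceil(weight_observed[-1] / 128)
weight_scale_shape = (*weight_observed[:-1], num_groups)  # (11008, 32)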

File tree: 2 files changed, +150 -104 lines
  src/compressed_tensors/quantization/lifecycle/initialize.py
  src/compressed_tensors/quantization/utils/helpers.py


src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 125 additions & 104 deletions
@@ -14,14 +14,13 @@
 
 
 import logging
-import math
-import warnings
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization import (
     FP8_E4M3_DATA,
     ActivationOrdering,
+    DynamicType,
     KVCacheScaleType,
     QuantizationArgs,
     QuantizationMetadata,
@@ -32,7 +31,11 @@
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme
+from compressed_tensors.quantization.utils import (
+    is_fp4,
+    is_kv_cache_quant_scheme,
+    strategy_cdiv,
+)
 from compressed_tensors.utils import (
     disable_hf_hook,
     get_execution_device,
@@ -44,6 +47,7 @@
 __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
+    "initialize_qparams",
 ]
 
 
@@ -69,10 +73,8 @@ def initialize_module_for_quantization(
     :param force_zero_point: whether to force initialization of a zero point for
         symmetric quantization
     """
-    # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
-        # no scheme passed and layer not targeted for quantization - skip
        return
 
     QuantizationMetadata.clear_all_qparams(module)
@@ -82,38 +84,52 @@
         _initialize_attn_scales(module)
 
     else:
+        if not isinstance(module, torch.nn.Linear):
+            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")
+
+        # use weight to determine observed shapes and dtype
+        if hasattr(module, "weight"):
+            weight = module.weight
+            assert isinstance(weight, torch.Tensor)
+        else:
+            # Note that a weight is required for both weight and activation
+            # quantization in order to know the dtype of activation scales
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for quantization but "
+                f"has no attribute weight, skipping quantization for {type(module)}"
+            )
+            return
+
         if scheme.input_activations is not None:
-            _initialize_scale_zero_point(
+            initialize_qparams(
                 module,
                 "input",
                 scheme.input_activations,
+                observed_shape=weight.shape[-1:],
+                observed_dtype=weight.dtype,
                 force_zero_point=force_zero_point,
             )
 
         if scheme.weights is not None:
-            if hasattr(module, "weight"):
-                weight_shape = None
-                if isinstance(module, torch.nn.Linear):
-                    weight_shape = module.weight.shape
-                _initialize_scale_zero_point(
-                    module,
-                    "weight",
-                    scheme.weights,
-                    weight_shape=weight_shape,
-                    force_zero_point=force_zero_point,
-                )
-            else:
-                _LOGGER.warning(
-                    f"module type {type(module)} targeted for weight quantization but "
-                    "has no attribute weight, skipping weight quantization "
-                    f"for {type(module)}"
-                )
-
-        if scheme.output_activations is not None:
-            if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations
-                )
+            initialize_qparams(
+                module,
+                "weight",
+                scheme.weights,
+                observed_shape=weight.shape,
+                observed_dtype=weight.dtype,
+                force_zero_point=force_zero_point,
+            )
+
+        output_is_kv_cache = is_kv_cache_quant_scheme(scheme)
+        if scheme.output_activations is not None and not output_is_kv_cache:
+            initialize_qparams(
+                module,
+                "output",
+                scheme.output_activations,
+                observed_shape=weight.shape[:-1],
+                observed_dtype=weight.dtype,
+                force_zero_point=force_zero_point,
+            )
 
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
@@ -132,22 +148,40 @@ def is_attention_module(module: Module):
     )
 
 
-def _initialize_scale_zero_point(
+def initialize_qparams(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
+    observed_shape: Tuple[int],
+    observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
-    if quantization_args.dynamic is True:
-        return
+    """
+    Initialize quantization parameters for a given basename according to the passed
+    quantization args. The shape and dtype of the observed weight/activation must also
+    be provided.
+
+    Scales will always be initialized. Global scales are initialized depending on args.
+    Zero points will be initialized if not symmetric or if `force_zero_point` is True.
+
+    :param module: module to register qparams to
+    :param base_name: base name of qparams, for example "input", "weight", "k", "v"
+    :param quantization_args: arguments for quantization
+    :param observed_shape: last (right-most) known dimensions of the observed weight/act
+    :param observed_dtype: dtype of the observed weight/actt
+    :param force_zero_point: force the zero_point parameter to be initialized
+    """
+    strategy = quantization_args.strategy
+    dynamic = quantization_args.dynamic
+    actorder = quantization_args.actorder
+    device = get_execution_device(module)  # avoid performing intialization ops on cpu
 
-    # initialize on execution device to avoid performing quantized ops on cpu
-    device = get_execution_device(module)
+    # Skip all intialization for fully dynamic quantization
+    if dynamic is True:
+        return
 
-    # 1. Create global_scales for tensor_group - generates
-    # a per tensor scale
-    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    # 0. Create global scale for tensor-group quantization
+    if strategy == QuantizationStrategy.TENSOR_GROUP:
         init_global_scale = Parameter(
             torch.empty(1, dtype=torch.float32, device=device),
             requires_grad=False,
@@ -156,56 +190,55 @@ def _initialize_scale_zero_point(
         module, f"{base_name}_global_scale", init_global_scale
     )
 
-    # 2. Infer expected scale/zero point shape
-    if quantization_args.strategy == QuantizationStrategy.TOKEN:
+    # Skip scale/zp initialization for locally dynamic quantization
+    if dynamic == DynamicType.LOCAL:
+        return
+
+    # 1. Infer expected scale/zp shape
+    if strategy == QuantizationStrategy.TENSOR:
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.TOKEN:
         expected_shape = (1, 1)
+
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 2:
+            raise ValueError("Channel quant requires at least 2 observed dimensions")
+
+        expected_shape = (observed_shape[-2], 1)
+
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
+        expected_shape = (*observed_shape[:-1], num_groups)
+
+        # initialize activation ordering if applicable
+        if actorder == ActivationOrdering.GROUP:
+            init_g_idx = Parameter(
+                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
+                requires_grad=False,
+            )
+            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
+
+    elif strategy == QuantizationStrategy.BLOCK:
+        assert quantization_args.block_structure is not None
+        if len(observed_shape) < 2:
+            raise ValueError("Block quant requires at least 2 observed dimensions")
+
+        block_structure = quantization_args.block_structure
+        num_rows = strategy_cdiv(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = strategy_cdiv(observed_shape[-1], block_structure[-1], strategy)
+        expected_shape = (num_rows, num_cols)
+
     else:
-        expected_shape = 1
-
-    if base_name == "weight" and weight_shape is not None:
-        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1) - only for weights
-            expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy in (
-            QuantizationStrategy.TENSOR_GROUP,
-            QuantizationStrategy.GROUP,
-        ):
-            # GROUP/TENSOR_GROUP for both weights and activations
-            num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
-            expected_shape = (weight_shape[0], max(num_groups, 1))
-        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # For block quantization, scale shape should match number of blocks - only
-            # for weights
-            if quantization_args.block_structure is None:
-                raise ValueError(
-                    "Block quantization requires block_structure to be specified"
-                )
-            block_height, block_width = quantization_args.block_structure
-            rows, cols = weight_shape[-2], weight_shape[-1]
-            num_rows_blocks = math.ceil(rows / block_height)
-            num_cols_blocks = math.ceil(cols / block_width)
-
-            # Warn if dimensions don't divide evenly
-            if rows % block_height != 0 or cols % block_width != 0:
-                warnings.warn(
-                    f"Block quantization: tensor shape {weight_shape} does not divide"
-                    f"evenly by block structure {quantization_args.block_structure}. "
-                    f"Some blocks will be incomplete which may affect quantization"
-                    "quality.",
-                    UserWarning,
-                )
-
-            expected_shape = (num_rows_blocks, num_cols_blocks)
-    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        warnings.warn(
-            f"BLOCK quantization not supported for {base_name} activations. "
-            f"Falling back to tensor-level quantization.",
-            UserWarning,
-        )
-        expected_shape = 1
+        assert False, f"Unknown strategy {strategy}"
 
-    # 3. Identify quantization scale and zp dtype
-    scale_dtype = module.weight.dtype
+    # 2. Identify quantization scale and zp dtype
+    scale_dtype = observed_dtype
 
     if is_fp4(quantization_args=quantization_args):
        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
@@ -221,14 +254,12 @@
            scale_dtype = torch.bfloat16
        zp_dtype = quantization_args.pytorch_dtype()
 
-    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
-    # do not init scales for quantzation_args.dynamic == DynamicType.local
-    if not quantization_args.dynamic:
-        init_scale = Parameter(
-            torch.empty(expected_shape, dtype=scale_dtype, device=device),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    # 3. Initializes scale/zp for the module
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
@@ -237,16 +268,6 @@
         )
         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
 
-    # only grouped activation ordering has g_idx
-    if quantization_args.actorder == ActivationOrdering.GROUP:
-        g_idx_shape = (weight_shape[1],)
-        g_idx_dtype = torch.int
-        init_g_idx = Parameter(
-            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
-
 
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 25 additions & 0 deletions
@@ -27,6 +27,7 @@
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.utils import deprecated
+from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 
@@ -47,6 +48,7 @@
     "calculate_qparams",
     "generate_gparam",
     "is_fp4",
+    "strategy_cdiv",
 ]
 
 # target the self_attn layer
@@ -461,3 +463,26 @@ def generate_gparam(
     max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
     global_scale = scale_data.max * quant_data.max / max_val_pos
     return global_scale.to(dtype).reshape([1])
+
+
+def strategy_cdiv(
+    value: int,
+    divisor: int,
+    strategy: Optional[QuantizationStrategy],
+    strict: bool = False,
+) -> int:
+    dividend = math.ceil(value / divisor)
+    if dividend * divisor != value:
+        message = (
+            f"{strategy} quantization strategy requires strict division of "
+            f"weight/activation size {value} and group/block size {divisor}. "
+            "consider reducing the group/block size or ignoring modules with "
+            f"weights not divisible by {divisor}"
+        )
+        if strict:
+            raise ValueError(message)
+
+        else:
+            logger.bind(log_once=True).warning(message)
+
+    return dividend
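
A quick illustration of strategy_cdiv (assuming it is importable from compressed_tensors.quantization.utils, as the __all__ entry above suggests); the numbers are arbitrary:

from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.utils import strategy_cdiv

strategy_cdiv(4096, 128, QuantizationStrategy.GROUP)  # 32: divides evenly, no warning
strategy_cdiv(4100, 128, QuantizationStrategy.GROUP)  # 33: logs a one-time warning about the ragged group

try:
    strategy_cdiv(4100, 128, QuantizationStrategy.GROUP, strict=True)
except ValueError as err:
    print(err)  # strict mode raises instead of warning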

0 commit comments
