
Commit fde779c

refactor

Kyle Sayers committed
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 5148300 commit fde779c

File tree

3 files changed: +140 -128 lines changed


src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 5 additions & 3 deletions
@@ -223,14 +223,16 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     # This is because the normal workflow of using the weight's dtype
     # will be incorrect as the model weight will be compressed
     # Therfore, use the dtype set by the user using the PretrainedModel
-    scale_dtype = None
+    force_scale_dtype = None
     if status == QuantizationStatus.FROZEN:
         if hasattr(model, "dtype"):
-            scale_dtype = model.dtype
+            force_scale_dtype = model.dtype

     model.apply(
         lambda module: initialize_module_for_quantization(
-            module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
+            module,
+            force_zero_point=force_zero_point_init,
+            force_scale_dtype=force_scale_dtype,
         )
     )

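Per module, the lambda above now forwards the model-level dtype as force_scale_dtype instead of scale_dtype. The sketch below makes the same call directly on a toy torch.nn.Linear; the layer, the scheme values, and the exact import locations are illustrative assumptions rather than code from this commit, but the keyword and its effect follow the signature shown in initialize.py below.

    import torch
    from torch.nn import Linear

    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )
    from compressed_tensors.quantization.quant_scheme import QuantizationScheme
    from compressed_tensors.quantization.lifecycle.initialize import (
        initialize_module_for_quantization,
    )

    # toy layer: weight is stored as (out_features, in_features) = (256, 512)
    layer = Linear(in_features=512, out_features=256, dtype=torch.float32)

    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=4,
            symmetric=True,
            strategy=QuantizationStrategy.GROUP,
            group_size=128,
        ),
    )

    # force_scale_dtype overrides the observed weight dtype (float32 here),
    # mirroring what apply_quantization_status passes once the model is FROZEN
    initialize_module_for_quantization(layer, scheme, force_scale_dtype=torch.bfloat16)

    print(layer.weight_scale.dtype)       # scale created with the forced dtype
    print(layer.weight_zero_point.dtype)  # zero point, since force_zero_point defaults to True
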
src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 112 additions & 108 deletions
@@ -26,6 +26,7 @@
 from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
     ActivationOrdering,
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
 )
@@ -58,8 +59,8 @@ class KVCacheScaleType(Enum):
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
+    force_scale_dtype: Optional[torch.dtype] = None,
     force_zero_point: bool = True,
-    scale_dtype: Optional[torch.dtype] = None,
 ):
     """
     attaches appropriate scales, zero points, and observers to a layer
@@ -76,51 +77,58 @@ def initialize_module_for_quantization(
     :param scale_dtype: dtype to used for the scales, if overriding the
         weight dtype as the scale dtype
     """
-    # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
-        # no scheme passed and layer not targeted for quantization - skip
         return

     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)

     else:
-        if scheme.input_activations is not None:
-            _initialize_scale_zero_point(
-                module,
-                "input",
-                scheme.input_activations,
-                force_zero_point=force_zero_point,
-                scale_dtype=scale_dtype,
+        if not isinstance(module, torch.nn.Linear):
+            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")
+
+        # use weight to determine observed shapes and dtype
+        if hasattr(module, "weight"):
+            weight = module.weight
+            assert isinstance(weight, torch.Tensor)
+        else:
+            # Note that a weight is required for both weight and activation
+            # quantization in order to know the dtype of activation scales
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for quantization but "
+                f"has no attribute weight, skipping quantization for {type(module)}"
             )
+            return
+
+        if scheme.input_activations is not None:
+            base_name = "input"
+            args = scheme.input_activations
+            observed_shape = weight.shape[-1:]
+            observed_dtype = force_scale_dtype or weight.dtype

         if scheme.weights is not None:
-            if hasattr(module, "weight"):
-                weight_shape = None
-                if isinstance(module, torch.nn.Linear):
-                    weight_shape = module.weight.shape
-                _initialize_scale_zero_point(
-                    module,
-                    "weight",
-                    scheme.weights,
-                    weight_shape=weight_shape,
-                    force_zero_point=force_zero_point,
-                    scale_dtype=scale_dtype,
-                )
-            else:
-                _LOGGER.warning(
-                    f"module type {type(module)} targeted for weight quantization but "
-                    "has no attribute weight, skipping weight quantization "
-                    f"for {type(module)}"
-                )
+            base_name = "weight"
+            args = scheme.weights
+            observed_shape = weight.shape
+            observed_dtype = force_scale_dtype or weight.dtype

         if scheme.output_activations is not None:
-            if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations, scale_dtype=scale_dtype
-                )
+            base_name = "output"
+            args = scheme.output_activations
+            observed_shape = weight.shape[:-1]
+            observed_dtype = force_scale_dtype or weight.dtype
+
+        if not is_kv_cache_quant_scheme(scheme):
+            _initialize_scale_zero_point(
+                module,
+                base_name,
+                args,
+                observed_shape=observed_shape,
+                observed_dtype=observed_dtype,
+                force_zero_point=force_zero_point,
+            )

     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
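
For reference, the refactor derives a single observed shape per quantization target from the weight tensor: input activations observe the last weight dimension, weights observe the full weight shape, and output activations observe everything except the last dimension. For a torch.nn.Linear, whose weight is stored as (out_features, in_features), this works out as in the short illustration below (plain tensor shapes only, no compressed-tensors calls):

    import torch

    # weight of a hypothetical Linear(in_features=512, out_features=256)
    weight = torch.empty(256, 512)

    print(weight.shape[-1:])  # input activations observe (512,)  -> in_features
    print(weight.shape)       # weights observe (256, 512)        -> full weight shape
    print(weight.shape[:-1])  # output activations observe (256,) -> out_features
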
@@ -143,19 +151,21 @@ def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
+    observed_shape: torch.Size,
+    observed_dtype: torch.dtype,
     force_zero_point: bool = True,
-    scale_dtype: Optional[torch.dtype] = None,
 ):
-    if quantization_args.dynamic is True:
-        return
+    strategy = quantization_args.strategy
+    dynamic = quantization_args.dynamic
+    actorder = quantization_args.actorder
+    device = get_execution_device(module)  # avoid performing intialization ops on cpu

-    # initialize on execution device to avoid performing quantized ops on cpu
-    device = get_execution_device(module)
+    # Skip all intialization for fully dynamic quantization
+    if dynamic is True:
+        return

-    # 1. Create global_scales for tensor_group - generates
-    # a per tensor scale
-    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    # 0. Create global scale for tensor-group quantization
+    if strategy == QuantizationStrategy.TENSOR_GROUP:
         init_global_scale = Parameter(
             torch.empty(1, dtype=torch.float32, device=device),
             requires_grad=False,
@@ -164,56 +174,49 @@ def _initialize_scale_zero_point(
         module, f"{base_name}_global_scale", init_global_scale
     )

-    # 2. Infer expected scale/zero point shape
-    if quantization_args.strategy == QuantizationStrategy.TOKEN:
-        expected_shape = (1, 1)
-    else:
-        expected_shape = 1
-
-    if base_name == "weight" and weight_shape is not None:
-        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1) - only for weights
-            expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy in (
-            QuantizationStrategy.TENSOR_GROUP,
-            QuantizationStrategy.GROUP,
-        ):
-            # GROUP/TENSOR_GROUP for both weights and activations
-            num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
-            expected_shape = (weight_shape[0], max(num_groups, 1))
-        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # For block quantization, scale shape should match number of blocks - only
-            # for weights
-            if quantization_args.block_structure is None:
-                raise ValueError(
-                    "Block quantization requires block_structure to be specified"
-                )
-            block_height, block_width = quantization_args.block_structure
-            rows, cols = weight_shape[-2], weight_shape[-1]
-            num_rows_blocks = math.ceil(rows / block_height)
-            num_cols_blocks = math.ceil(cols / block_width)
-
-            # Warn if dimensions don't divide evenly
-            if rows % block_height != 0 or cols % block_width != 0:
-                warnings.warn(
-                    f"Block quantization: tensor shape {weight_shape} does not divide"
-                    f"evenly by block structure {quantization_args.block_structure}. "
-                    f"Some blocks will be incomplete which may affect quantization"
-                    "quality.",
-                    UserWarning,
-                )
-
-            expected_shape = (num_rows_blocks, num_cols_blocks)
-    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        warnings.warn(
-            f"BLOCK quantization not supported for {base_name} activations. "
-            f"Falling back to tensor-level quantization.",
-            UserWarning,
-        )
-        expected_shape = 1
+    # Skip scale/zp initialization for locally dynamic quantization
+    if dynamic == DynamicType.LOCAL:
+        return
+
+    # 1. Infer expected scale/zp shape
+    if strategy in (QuantizationStrategy.TENSOR, QuantizationStrategy.TOKEN):
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 1:
+            raise ValueError("Channel quant requires at least 1 observed dimension")
+
+        expected_shape = (observed_shape[-1], 1)

-    # 3. Identify quantization scale and zp dtype
-    scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = _strict_divide(observed_shape[-1], group_size, strategy)
+        expected_shape = (num_groups, group_size)
+
+        # initialize activation ordering if applicable
+        if actorder == ActivationOrdering.GROUP:
+            init_g_idx = Parameter(
+                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
+                requires_grad=False,
+            )
+            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
+
+    elif strategy == QuantizationStrategy.BLOCK:
+        assert quantization_args.block_structure is not None
+        if len(observed_shape) < 2:
+            raise ValueError("Block quant requires at least 2 observed dimensions")
+
+        block_structure = quantization_args.block_structure
+        num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy)
+        expected_shape = (num_rows, num_cols)
+
+    # 2. Identify quantization scale and zp dtype
+    scale_dtype = observed_dtype

     if is_fp4(quantization_args=quantization_args):
         scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
@@ -229,14 +232,12 @@ def _initialize_scale_zero_point(
         scale_dtype = torch.float16
         zp_dtype = quantization_args.pytorch_dtype()

-    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
-    # do not init scales for quantzation_args.dynamic == DynamicType.local
-    if not quantization_args.dynamic:
-        init_scale = Parameter(
-            torch.empty(expected_shape, dtype=scale_dtype, device=device),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    # 3. Initializes scale/zp for the module
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)

     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
@@ -245,16 +246,6 @@ def _initialize_scale_zero_point(
         )
         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

-    # only grouped activation ordering has g_idx
-    if quantization_args.actorder == ActivationOrdering.GROUP:
-        g_idx_shape = (weight_shape[1],)
-        g_idx_dtype = torch.int
-        init_g_idx = Parameter(
-            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
-

 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
@@ -276,3 +267,16 @@ def _initialize_attn_scales(module: Module) -> None:
             requires_grad=False,
         )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
+
+
+def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int:
+    out = observed // divisor
+    if out * divisor != observed:
+        raise ValueError(
+            f"{strategy} quantization strategy requires strict division of "
+            f"weight/activation size {observed} and group/block size {divisor}. "
+            "consider reducing the group/block size or ignoring modules with weights "
+            f"not divisible by {divisor}"
+        )
+
+    return out
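
Unlike the math.ceil rounding it replaces, the new _strict_divide helper rejects sizes that do not divide evenly, so an incompatible group or block size now fails at initialization time instead of silently creating a partial group. A minimal sketch of that behavior, calling the private helper directly for illustration (in the library it is only reached through _initialize_scale_zero_point):

    from compressed_tensors.quantization.quant_args import QuantizationStrategy
    from compressed_tensors.quantization.lifecycle.initialize import _strict_divide

    # 4096 / 128 divides evenly -> 32 groups
    print(_strict_divide(4096, 128, QuantizationStrategy.GROUP))

    # 4000 is not divisible by 128 -> ValueError (the old code used math.ceil here)
    try:
        _strict_divide(4000, 128, QuantizationStrategy.GROUP)
    except ValueError as err:
        print(err)
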

src/compressed_tensors/quantization/quant_args.py

Lines changed: 23 additions & 17 deletions
@@ -262,6 +262,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         actorder = model.actorder
         dynamic = model.dynamic
         observer = model.observer
+        block_structure = model.block_structure

         # infer strategy
         if strategy is None:
@@ -277,23 +278,28 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
                 "strategy='group' and group_size = -1 for 'channel'"
             )

-        # validate strategy and group
-        if strategy == QuantizationStrategy.GROUP:
-            if group_size is None or group_size <= 0:
-                raise ValueError(
-                    f"strategy {strategy} requires group_size to be "
-                    "set to a positive value"
-                )
-        if (
-            group_size is not None
-            and group_size > 0
-            and strategy
-            not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP)
-        ):
-            raise ValueError("group_size requires strategy to be set to 'group'")
-
-        # validate activation ordering and strategy
-        if actorder is not None and strategy != QuantizationStrategy.GROUP:
+        # validate block strategy and structure
+        has_block_strategy = strategy == QuantizationStrategy.BLOCK
+        has_block_structure = block_structure is not None
+        if has_block_strategy != has_block_structure:
+            raise ValueError(
+                "Block strategy requires `block_structure`, and vice versa. "
+                f"Instead got ({strategy}, {block_structure})"
+            )
+
+        # validate group strategy
+        has_group_strategy = strategy in (
+            QuantizationStrategy.GROUP,
+            QuantizationStrategy.TENSOR_GROUP,
+        )
+        has_group_size = group_size is not None and group_size > 0
+        has_actorder = actorder is not None
+        if has_group_strategy != has_group_size:
+            raise ValueError(
+                "Group strategies require `group_size`, and vice versa. "
+                f"Instead got ({strategy}, {group_size})"
+            )
+        if has_actorder and not has_group_strategy:
             raise ValueError(
                 "Must use group quantization strategy in order to apply "
                 "activation ordering"

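The rewritten validator enforces these pairings in both directions: BLOCK requires block_structure, GROUP and TENSOR_GROUP require a positive group_size, and actorder is only legal together with a group strategy. A small sketch of the new failure mode, assuming the pydantic-style construction of QuantizationArgs used elsewhere in the package (the exact message text comes from the validator above):

    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )

    # group strategy paired with a positive group_size -> accepted
    args = QuantizationArgs(num_bits=4, strategy=QuantizationStrategy.GROUP, group_size=128)

    # group strategy without group_size -> rejected by validate_model_after
    try:
        QuantizationArgs(num_bits=4, strategy=QuantizationStrategy.GROUP)
    except ValueError as err:
        print(err)  # Group strategies require `group_size`, and vice versa. ...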