Commit 1c217e4

refactor
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 1016a75 commit 1c217e4

File tree: 3 files changed (+139, -120 lines)

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 2 additions & 1 deletion
@@ -221,7 +221,8 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
 
         model.apply(
             lambda module: initialize_module_for_quantization(
-                module, force_zero_point=force_zero_point_init
+                module,
+                force_zero_point=force_zero_point_init,
             )
         )

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 114 additions & 102 deletions
@@ -26,6 +26,7 @@
 from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
     ActivationOrdering,
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
 )
@@ -73,49 +74,58 @@ def initialize_module_for_quantization(
     :param force_zero_point: whether to force initialization of a zero point for
         symmetric quantization
     """
-    # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
-        # no scheme passed and layer not targeted for quantization - skip
         return
 
     if is_attention_module(module):
         # quantized actions based on calltime status
         _initialize_attn_scales(module)
 
     else:
-        if scheme.input_activations is not None:
-            _initialize_scale_zero_point(
-                module,
-                "input",
-                scheme.input_activations,
-                force_zero_point=force_zero_point,
+        if not isinstance(module, torch.nn.Linear):
+            _LOGGER.warning(f"Attempting to quantize module of type {type(module)}")
+
+        # use weight to determine observed shapes and dtype
+        if hasattr(module, "weight"):
+            weight = module.weight
+            assert isinstance(weight, torch.Tensor)
+        else:
+            # Note that a weight is required for both weight and activation
+            # quantization in order to know the dtype of activation scales
+            _LOGGER.warning(
+                f"module type {type(module)} targeted for quantization but "
+                f"has no attribute weight, skipping quantization for {type(module)}"
             )
+            return
+
+        if scheme.input_activations is not None:
+            base_name = "input"
+            args = scheme.input_activations
+            observed_shape = weight.shape[-1:]
+            observed_dtype = weight.dtype
 
         if scheme.weights is not None:
-            if hasattr(module, "weight"):
-                weight_shape = None
-                if isinstance(module, torch.nn.Linear):
-                    weight_shape = module.weight.shape
-                _initialize_scale_zero_point(
-                    module,
-                    "weight",
-                    scheme.weights,
-                    weight_shape=weight_shape,
-                    force_zero_point=force_zero_point,
-                )
-            else:
-                _LOGGER.warning(
-                    f"module type {type(module)} targeted for weight quantization but "
-                    "has no attribute weight, skipping weight quantization "
-                    f"for {type(module)}"
-                )
+            base_name = "weight"
+            args = scheme.weights
+            observed_shape = weight.shape
+            observed_dtype = weight.dtype
 
         if scheme.output_activations is not None:
-            if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point(
-                    module, "output", scheme.output_activations
-                )
+            base_name = "output"
+            args = scheme.output_activations
+            observed_shape = weight.shape[:-1]
+            observed_dtype = weight.dtype
+
+        if not is_kv_cache_quant_scheme(scheme):
+            _initialize_scale_zero_point(
+                module,
+                base_name,
+                args,
+                observed_shape=observed_shape,
+                observed_dtype=observed_dtype,
+                force_zero_point=force_zero_point,
+            )
 
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
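
Note on the hunk above: input, weight, and output initialization are now funneled through a single _initialize_scale_zero_point call, distinguished only by the observed shape and dtype taken from the module's weight. A minimal sketch of the shapes this implies, using a hypothetical Linear layer (the layer and sizes are illustrative assumptions, not part of this commit):

    import torch

    linear = torch.nn.Linear(in_features=64, out_features=128)
    weight = linear.weight  # shape (128, 64)

    # per the refactored initialize_module_for_quantization above:
    observed = {
        "input": weight.shape[-1:],   # torch.Size([64])  - activations observed over in_features
        "weight": weight.shape,       # torch.Size([128, 64])
        "output": weight.shape[:-1],  # torch.Size([128]) - activations observed over out_features
    }
    for base_name, shape in observed.items():
        print(base_name, tuple(shape), weight.dtype)
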
@@ -138,18 +148,21 @@ def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
-    weight_shape: Optional[torch.Size] = None,
+    observed_shape: torch.Size,
+    observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
-    if quantization_args.dynamic is True:
-        return
+    strategy = quantization_args.strategy
+    dynamic = quantization_args.dynamic
+    actorder = quantization_args.actorder
+    device = get_execution_device(module)  # avoid performing intialization ops on cpu
 
-    # initialize on execution device to avoid performing quantized ops on cpu
-    device = get_execution_device(module)
+    # Skip all intialization for fully dynamic quantization
+    if dynamic is True:
+        return
 
-    # 1. Create global_scales for tensor_group - generates
-    # a per tensor scale
-    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+    # 0. Create global scale for tensor-group quantization
+    if strategy == QuantizationStrategy.TENSOR_GROUP:
         init_global_scale = Parameter(
             torch.empty(1, dtype=torch.float32, device=device),
             requires_grad=False,
@@ -158,56 +171,54 @@ def _initialize_scale_zero_point(
             module, f"{base_name}_global_scale", init_global_scale
         )
 
-    # 2. Infer expected scale/zero point shape
-    if quantization_args.strategy == QuantizationStrategy.TOKEN:
-        expected_shape = (1, 1)
-    else:
-        expected_shape = 1
-
-    if base_name == "weight" and weight_shape is not None:
-        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1) - only for weights
-            expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy in (
-            QuantizationStrategy.TENSOR_GROUP,
-            QuantizationStrategy.GROUP,
-        ):
-            # GROUP/TENSOR_GROUP for both weights and activations
-            num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
-            expected_shape = (weight_shape[0], max(num_groups, 1))
-        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # For block quantization, scale shape should match number of blocks - only
-            # for weights
-            if quantization_args.block_structure is None:
-                raise ValueError(
-                    "Block quantization requires block_structure to be specified"
-                )
-            block_height, block_width = quantization_args.block_structure
-            rows, cols = weight_shape[-2], weight_shape[-1]
-            num_rows_blocks = math.ceil(rows / block_height)
-            num_cols_blocks = math.ceil(cols / block_width)
-
-            # Warn if dimensions don't divide evenly
-            if rows % block_height != 0 or cols % block_width != 0:
-                warnings.warn(
-                    f"Block quantization: tensor shape {weight_shape} does not divide"
-                    f"evenly by block structure {quantization_args.block_structure}. "
-                    f"Some blocks will be incomplete which may affect quantization"
-                    "quality.",
-                    UserWarning,
-                )
-
-            expected_shape = (num_rows_blocks, num_cols_blocks)
-    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        warnings.warn(
-            f"BLOCK quantization not supported for {base_name} activations. "
-            f"Falling back to tensor-level quantization.",
-            UserWarning,
-        )
-        expected_shape = 1
+    # Skip scale/zp initialization for locally dynamic quantization
+    if dynamic == DynamicType.LOCAL:
+        return
+
+    # 1. Infer expected scale/zp shape
+    if strategy in (QuantizationStrategy.TENSOR, QuantizationStrategy.TOKEN):
+        expected_shape = (1,)
+
+    elif strategy == QuantizationStrategy.CHANNEL:
+        if len(observed_shape) < 1:
+            raise ValueError("Channel quant requires at least 1 observed dimension")
 
+        expected_shape = (observed_shape[-1], 1)
+
+<<<<<<< HEAD
     # 3. Identify quantization scale and zp dtype
     scale_dtype = module.weight.dtype
+=======
+    elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        assert quantization_args.group_size is not None
+        if len(observed_shape) < 1:
+            raise ValueError("Group quant requires at least 1 observed dimension")
+
+        group_size = quantization_args.group_size
+        num_groups = _strict_divide(observed_shape[-1], group_size, strategy)
+        expected_shape = (num_groups, group_size)
+
+        # initialize activation ordering if applicable
+        if actorder == ActivationOrdering.GROUP:
+            init_g_idx = Parameter(
+                torch.full((observed_shape[-1],), -1, device=device, dtype=torch.int),
+                requires_grad=False,
+            )
+            register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
+
+    elif strategy == QuantizationStrategy.BLOCK:
+        assert quantization_args.block_structure is not None
+        if len(observed_shape) < 2:
+            raise ValueError("Block quant requires at least 2 observed dimensions")
+
+        block_structure = quantization_args.block_structure
+        num_rows = _strict_divide(observed_shape[-2], block_structure[-2], strategy)
+        num_cols = _strict_divide(observed_shape[-1], block_structure[-1], strategy)
+        expected_shape = (num_rows, num_cols)
+
+    # 2. Identify quantization scale and zp dtype
+    scale_dtype = observed_dtype
+>>>>>>> fde779c (refactor)
 
     if is_fp4(quantization_args=quantization_args):
         scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
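
On the incoming (fde779c) side of the hunk above, group and block scale shapes are derived from observed_shape with strict division instead of math.ceil. A short sketch of the shapes this produces; the layer size, group size, and block structure below are illustrative assumptions:

    # hypothetical observed shape, e.g. a Linear weight of shape (out_features, in_features)
    observed_shape = (128, 64)

    # GROUP / TENSOR_GROUP with group_size=16: groups are taken along the last dimension
    group_size = 16
    num_groups = observed_shape[-1] // group_size   # 4; a non-exact division raises ValueError
    group_scale_shape = (num_groups, group_size)    # (4, 16), as written in this commit

    # BLOCK with block_structure=[64, 32]: one scale per (row block, column block)
    block_structure = (64, 32)
    block_scale_shape = (
        observed_shape[-2] // block_structure[-2],  # 2 row blocks
        observed_shape[-1] // block_structure[-1],  # 2 column blocks
    )

    print(group_scale_shape, block_scale_shape)     # (4, 16) (2, 2)
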
@@ -223,14 +234,12 @@ def _initialize_scale_zero_point(
         scale_dtype = torch.bfloat16
         zp_dtype = quantization_args.pytorch_dtype()
 
-    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
-    # do not init scales for quantzation_args.dynamic == DynamicType.local
-    if not quantization_args.dynamic:
-        init_scale = Parameter(
-            torch.empty(expected_shape, dtype=scale_dtype, device=device),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    # 3. Initializes scale/zp for the module
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
         init_zero_point = Parameter(
@@ -239,16 +248,6 @@ def _initialize_scale_zero_point(
         )
         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
 
-    # only grouped activation ordering has g_idx
-    if quantization_args.actorder == ActivationOrdering.GROUP:
-        g_idx_shape = (weight_shape[1],)
-        g_idx_dtype = torch.int
-        init_g_idx = Parameter(
-            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
-            requires_grad=False,
-        )
-        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
-
 
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
@@ -270,3 +269,16 @@ def _initialize_attn_scales(module: Module) -> None:
         requires_grad=False,
     )
     register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
+
+
+def _strict_divide(observed: int, divisor: int, strategy: QuantizationStrategy) -> int:
+    out = observed // divisor
+    if out * divisor != observed:
+        raise ValueError(
+            f"{strategy} quantization strategy requires strict division of "
+            f"weight/activation size {observed} and group/block size {divisor}. "
+            "consider reducing the group/block size or ignoring modules with weights "
+            f"not divisible by {divisor}"
+        )
+
+    return out
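
The new _strict_divide helper accepts only exact divisions. A small self-contained sketch that mirrors its logic, so it can be run without the package (sizes are illustrative):

    def strict_divide(observed: int, divisor: int, strategy: str = "group") -> int:
        # mirrors the _strict_divide helper added above
        out = observed // divisor
        if out * divisor != observed:
            raise ValueError(
                f"{strategy} quantization strategy requires strict division of "
                f"weight/activation size {observed} and group/block size {divisor}."
            )
        return out

    print(strict_divide(4096, 128))  # 32 groups of size 128

    try:
        strict_divide(4100, 128)     # 4100 is not a multiple of 128
    except ValueError as err:
        print(err)
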

src/compressed_tensors/quantization/quant_args.py

Lines changed: 23 additions & 17 deletions
@@ -262,6 +262,7 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         actorder = model.actorder
         dynamic = model.dynamic
         observer = model.observer
+        block_structure = model.block_structure
 
         # infer strategy
         if strategy is None:
@@ -277,23 +278,28 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
                     "strategy='group' and group_size = -1 for 'channel'"
                 )
 
-        # validate strategy and group
-        if strategy == QuantizationStrategy.GROUP:
-            if group_size is None or group_size <= 0:
-                raise ValueError(
-                    f"strategy {strategy} requires group_size to be "
-                    "set to a positive value"
-                )
-        if (
-            group_size is not None
-            and group_size > 0
-            and strategy
-            not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP)
-        ):
-            raise ValueError("group_size requires strategy to be set to 'group'")
-
-        # validate activation ordering and strategy
-        if actorder is not None and strategy != QuantizationStrategy.GROUP:
+        # validate block strategy and structure
+        has_block_strategy = strategy == QuantizationStrategy.BLOCK
+        has_block_structure = block_structure is not None
+        if has_block_strategy != has_block_structure:
+            raise ValueError(
+                "Block strategy requires `block_structure`, and vice versa. "
+                f"Instead got ({strategy}, {block_structure})"
+            )
+
+        # validate group strategy
+        has_group_strategy = strategy in (
+            QuantizationStrategy.GROUP,
+            QuantizationStrategy.TENSOR_GROUP,
+        )
+        has_group_size = group_size is not None and group_size > 0
+        has_actorder = actorder is not None
+        if has_group_strategy != has_group_size:
+            raise ValueError(
+                "Group strategies require `group_size`, and vice versa. "
+                f"Instead got ({strategy}, {group_size})"
+            )
+        if has_actorder and not has_group_strategy:
             raise ValueError(
                 "Must use group quantization strategy in order to apply "
                 "activation ordering"
