
Commit 0f5823c

Activation Ordering Strategies (#146)
1 parent 4328b41 commit 0f5823c

File tree

5 files changed: +119 -55 lines


src/compressed_tensors/compressors/base.py

Lines changed: 2 additions & 0 deletions
@@ -108,13 +108,15 @@ def compress(
             prefix = name[: -(len(weight_suffix))]
             scale = model_state.get(merge_names(prefix, "weight_scale"), None)
             zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+            g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
             if scale is not None:
                 # weight is quantized, compress it
                 quant_args = names_to_scheme[prefix]
                 compressed_data = self.compress_weight(
                     weight=value,
                     scale=scale,
                     zero_point=zp,
+                    g_idx=g_idx,
                     quantization_args=quant_args,
                     device="cpu",
                 )
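
Note on the change above: the compressor now treats `weight_g_idx` exactly like `weight_scale` and `weight_zero_point`, fetching it with a `None` fallback so modules quantized without activation ordering still compress. A minimal standalone sketch of the lookup pattern, assuming `merge_names` dot-joins its arguments (the dict contents and module name here are hypothetical):

import torch


def merge_names(prefix: str, suffix: str) -> str:
    # assumption: mirrors compressed_tensors' merge_names, which joins
    # parameter name components with a dot
    return f"{prefix}.{suffix}"


# hypothetical state dict entries for one quantized Linear module
model_state = {
    "model.layers.0.mlp.weight_scale": torch.ones(8, 1),
    "model.layers.0.mlp.weight_g_idx": torch.arange(16, dtype=torch.int) // 4,
}

prefix = "model.layers.0.mlp"
g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
print(g_idx)  # per-column group indices, or None when actorder was not used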

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 3 additions & 2 deletions
@@ -21,6 +21,7 @@
     wrap_module_forward_quantized,
 )
 from compressed_tensors.quantization.quant_args import (
+    ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
 )
@@ -179,8 +180,8 @@ def _initialize_scale_zero_point_observer(
         )
         module.register_parameter(f"{base_name}_zero_point", init_zero_point)
 
-    # initialize with empty for actorder, to be populated by GPTQ or state_dict
-    if quantization_args.actorder:
+    # only grouped activation ordering has g_idx
+    if quantization_args.actorder == ActivationOrdering.GROUP:
        g_idx_shape = (weight_shape[1],)
        g_idx_dtype = torch.int
        init_g_idx = Parameter(
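
For intuition, `g_idx` carries one entry per input column (`weight_shape[1]`), naming the quantization group each column belongs to; GPTQ with `actorder="group"` later fills it with an activation-ordered permutation. A hedged sketch of the shape and dtype registered here (the identity mapping below is illustrative, not the library's initial values):

import torch
from torch.nn import Parameter

num_columns, group_size = 16, 4  # toy sizes

# identity grouping: column i falls in group i // group_size; GPTQ would
# overwrite this with a permutation ordered by activation magnitude
init_g_idx = Parameter(
    torch.arange(num_columns, dtype=torch.int) // group_size,
    requires_grad=False,
)
print(init_g_idx.shape)  # torch.Size([16]) -- matches (weight_shape[1],)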

src/compressed_tensors/quantization/quant_args.py

Lines changed: 54 additions & 19 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 from pydantic import BaseModel, Field, field_validator, model_validator
@@ -25,6 +25,7 @@
     "QuantizationStrategy",
     "QuantizationArgs",
     "round_to_quantized_type",
+    "ActivationOrdering",
 ]
 
 FP8_DTYPE = torch.float8_e4m3fn
@@ -51,6 +52,19 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
+class ActivationOrdering(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+
+    Group: reorder groups and weight\n
+    Weight: only reorder weight, not groups. Slightly lower latency and
+        accuracy compared to group actorder\n
+    """
+
+    GROUP = "group"
+    WEIGHT = "weight"
+
+
 class QuantizationArgs(BaseModel, use_enum_values=True):
     """
     User facing arguments used to define a quantization config for weights or
@@ -69,17 +83,17 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         quantization. Note that enabling dynamic quantization will change the default
         observer to a memoryless one
     :param actorder: whether to apply group quantization in decreasing order of
-        activation. Defaults to False for arbitrary ordering
+        activation. Defaults to None for arbitrary ordering
     """
 
     num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT.value
+    type: QuantizationType = QuantizationType.INT
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
-    actorder: bool = False
+    actorder: Optional[ActivationOrdering] = None
     observer: str = Field(
         default="minmax",
         description=(
@@ -108,8 +122,15 @@ def get_observer(self):
 
         return Observer.load_from_registry(self.observer, quantization_args=self)
 
+    @field_validator("type", mode="before")
+    def validate_type(cls, value) -> QuantizationType:
+        if isinstance(value, str):
+            return QuantizationType(value.lower())
+
+        return value
+
     @field_validator("group_size", mode="before")
-    def validate_group(cls, value) -> int:
+    def validate_group(cls, value) -> Union[int, None]:
         if value is None:
             return value
 
@@ -121,18 +142,29 @@ def validate_group(cls, value) -> int:
 
         return value
 
-    @model_validator(mode="before")
-    def validate_strategy(values) -> Dict[str, Any]:
-        model_fields = QuantizationArgs.model_fields
-        strategy = values.get("strategy", model_fields["strategy"].default)
-        group_size = values.get("group_size", model_fields["group_size"].default)
-        actorder = values.get("actorder", model_fields["actorder"].default)
+    @field_validator("strategy", mode="before")
+    def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
+        if isinstance(value, str):
+            return QuantizationStrategy(value.lower())
 
-        if strategy is not None:
-            strategy = QuantizationStrategy(strategy.lower())
+        return value
 
-        else:
-            # use group_size to determinine strategy if not given explicity
+    @field_validator("actorder", mode="before")
+    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
+        if isinstance(value, str):
+            return ActivationOrdering(value.lower())
+
+        return value
+
+    @model_validator(mode="after")
+    def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
+        # extract user-passed values from dictionary
+        strategy = model.strategy
+        group_size = model.group_size
+        actorder = model.actorder
+
+        # infer strategy
+        if strategy is None:
             if group_size is None:
                 strategy = QuantizationStrategy.TENSOR
             elif group_size > 0:
@@ -145,21 +177,24 @@ def validate_strategy(values) -> Dict[str, Any]:
                     "strategy='group' and group_size = -1 for 'channel'"
                 )
 
+        # validate strategy and group
         if strategy == QuantizationStrategy.GROUP:
             if group_size is None or group_size <= 0:
                 raise ValueError(
                     f"strategy {strategy} requires group_size to be "
                     "set to a positive value"
                 )
 
-        if actorder and strategy != QuantizationStrategy.GROUP:
+        # validate activation ordering and strategy
+        if actorder is not None and strategy != QuantizationStrategy.GROUP:
             raise ValueError(
-                "Group quantization must be specified in order to apply "
+                "Must use group quantization strategy in order to apply "
                 "activation ordering"
             )
 
-        values["strategy"] = strategy
-        return values
+        # write back modified values
+        model.strategy = strategy
+        return model
 
     def pytorch_dtype(self) -> torch.dtype:
         if self.type == QuantizationType.FLOAT:
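
Taken together, the new validators lower-case string inputs into enum members before field validation, and the `mode="after"` model validator infers the strategy and rejects `actorder` outside group quantization. A short sketch of the resulting behavior (import path taken from the tests in this commit):

from compressed_tensors.quantization import (
    ActivationOrdering,
    QuantizationArgs,
    QuantizationStrategy,
)

# case-insensitive strings are coerced into enum members
args = QuantizationArgs(group_size=128, actorder="GROUP")
assert args.strategy == QuantizationStrategy.GROUP  # inferred from group_size
assert args.actorder == ActivationOrdering.GROUP

# activation ordering without a group strategy is rejected
try:
    QuantizationArgs(strategy="tensor", actorder="weight")
except ValueError as err:  # pydantic's ValidationError subclasses ValueError
    print("rejected:", type(err).__name__)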

tests/test_compressors/test_pack_quant.py

Lines changed: 31 additions & 24 deletions
@@ -33,33 +33,32 @@
     apply_quantization_status,
 )
 from compressed_tensors.quantization.lifecycle.forward import fake_quantize
+from compressed_tensors.quantization.quant_args import ActivationOrdering
 from safetensors.torch import save_file
 from torch.nn.modules import Linear, Sequential
 
 
-def get_dummy_quant_config(num_bits=4, strategy=None, group_size=None):
+def get_dummy_quant_config(num_bits=4, strategy=None, group_size=None, actorder=None):
     config_groups = {
         "group_1": QuantizationScheme(
             targets=["Linear"],
             weights=QuantizationArgs(
                 num_bits=num_bits,
                 strategy=strategy,
                 group_size=group_size,
+                actorder=actorder,
             ),
         ),
     }
-    ignore = ["lm_head"]
-    quant_config = QuantizationConfig(
-        config_groups=config_groups,
-        ignore=ignore,
-    )
-
-    return quant_config
+    return QuantizationConfig(config_groups=config_groups)
 
 
 def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
     perm = torch.randperm(columns)
-    return torch.tensor([index // group_size for index in range(columns)])[perm]
+    return torch.nn.Parameter(
+        (torch.arange(columns, dtype=torch.int) // group_size)[perm],
+        requires_grad=False,
+    )
 
 
 @pytest.mark.parametrize(
@@ -199,29 +198,34 @@ def test_reload_match(tmp_path, num_bits):
 
 
 @pytest.mark.parametrize(
-    "apply_gptq",
-    [True, False],
+    "actorder",
+    [
+        ActivationOrdering.GROUP,
+        ActivationOrdering.WEIGHT,
+        None,
+    ],
 )
-def test_actorder_reload_match(apply_gptq, tmp_path):
-    model = Sequential(
-        OrderedDict(
-            [
-                ("dummy", Linear(512, 1024, bias=None)),
-            ]
-        )
-    )
+def test_actorder_reload_match(actorder, tmp_path):
+    model = Sequential(OrderedDict([("dummy", Linear(512, 1024, bias=None))]))
     group_size = 128
-    quant_config = get_dummy_quant_config(strategy="group", group_size=group_size)
+    quant_config = get_dummy_quant_config(
+        strategy="group", group_size=group_size, actorder=actorder
+    )
     apply_quantization_config(model, quant_config)
-    apply_quantization_status(model, QuantizationStatus.CALIBRATION)
-
-    if apply_gptq:
-        model.dummy.weight_g_idx = make_dummy_g_idx(512, group_size)
 
+    # run calibration
+    apply_quantization_status(model, QuantizationStatus.CALIBRATION)
     for _ in range(16):
         inputs = torch.rand((512, 512))
         _ = model(inputs)
+    apply_quantization_status(model, QuantizationStatus.FROZEN)
+
+    # apply gptq
+    if actorder == ActivationOrdering.GROUP:
+        init_g_idx = make_dummy_g_idx(512, group_size)
+        model.dummy.register_parameter("weight_g_idx", init_g_idx)
 
+    # compress
     compressor = PackedQuantizationCompressor(config=quant_config)
     quantized_modules_to_args = {
         "dummy": quant_config.config_groups["group_1"].weights,
@@ -230,6 +234,8 @@ def test_actorder_reload_match(apply_gptq, tmp_path):
         model.state_dict(), names_to_scheme=quantized_modules_to_args
     )
     save_file(compressed_state_dict, tmp_path / "model.safetensors")
+
+    # decompress
     reconstructed_dense_gen = compressor.decompress(
         tmp_path, names_to_scheme=quantized_modules_to_args
     )
@@ -241,6 +247,7 @@ def test_actorder_reload_match(apply_gptq, tmp_path):
         model.dummy.weight,
         scale=model.dummy.weight_scale,
         zero_point=model.dummy.weight_zero_point,
+        g_idx=getattr(model.dummy, "weight_g_idx", None),
         args=quantized_modules_to_args["dummy"],
     )
     assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy.weight"])
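
For intuition on the fixture above: `make_dummy_g_idx` builds a randomly permuted column-to-group mapping, the shape of tensor GPTQ would produce under group activation ordering, and the round trip then asserts that `fake_quantize` with that `g_idx` matches the decompressed weights. A standalone sketch of the fixture's output:

import torch


def make_dummy_g_idx(columns: int, group_size: int) -> torch.nn.Parameter:
    # same construction as the test fixture: assign column i to group
    # i // group_size, then shuffle the assignment with a random permutation
    perm = torch.randperm(columns)
    return torch.nn.Parameter(
        (torch.arange(columns, dtype=torch.int) // group_size)[perm],
        requires_grad=False,
    )


g_idx = make_dummy_g_idx(columns=8, group_size=4)
print(g_idx)  # e.g. Parameter containing: tensor([1, 0, 0, 1, 1, 0, 1, 0], ...)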

tests/test_quantization/test_quant_args.py

Lines changed: 29 additions & 10 deletions
@@ -14,6 +14,7 @@
 
 import pytest
 from compressed_tensors.quantization import (
+    ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -39,6 +40,9 @@ def test_group():
     assert group.strategy == QuantizationStrategy.GROUP
     assert group.group_size == kwargs["group_size"]
 
+    with pytest.raises(ValueError):
+        QuantizationArgs(strategy=QuantizationStrategy.GROUP, group_size=-1)
+
 
 def test_block():
     kwargs = {"strategy": "block", "block_structure": "2x4"}
@@ -56,25 +60,40 @@ def test_infer_strategy():
     assert args.strategy == QuantizationStrategy.CHANNEL
 
 
+def test_enums():
+    assert QuantizationArgs(
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.GROUP,
+        actorder=ActivationOrdering.WEIGHT,
+        group_size=1,
+    ) == QuantizationArgs(type="InT", strategy="GROUP", actorder="weight", group_size=1)
+
+
 def test_actorder():
-    args = QuantizationArgs(group_size=128, actorder=True)
+    # test group inference with actorder
+    args = QuantizationArgs(group_size=128, actorder=ActivationOrdering.GROUP)
     assert args.strategy == QuantizationStrategy.GROUP
-    assert args.actorder
 
+    # test invalid pairings
     with pytest.raises(ValueError):
-        args = QuantizationArgs(group_size=None, actorder=True)
-
+        QuantizationArgs(group_size=None, actorder="weight")
     with pytest.raises(ValueError):
-        args = QuantizationArgs(group_size=-1, actorder=True)
-
+        QuantizationArgs(group_size=-1, actorder="weight")
     with pytest.raises(ValueError):
-        args = QuantizationArgs(strategy="tensor", actorder=True)
+        QuantizationArgs(strategy="tensor", actorder="weight")
+
+    # test boolean defaulting
+    assert (
+        QuantizationArgs(group_size=1, actorder="weight").actorder
+        == ActivationOrdering.WEIGHT
+    )
+    assert QuantizationArgs(group_size=1, actorder=None).actorder is None
 
 
 def test_invalid():
     with pytest.raises(ValidationError):
-        _ = QuantizationArgs(type="invalid")
+        QuantizationArgs(type="invalid")
     with pytest.raises(ValidationError):
-        _ = QuantizationArgs(strategy="invalid")
+        QuantizationArgs(strategy="invalid")
     with pytest.raises(ValidationError):
-        _ = QuantizationArgs(strategy=QuantizationStrategy.GROUP)
+        QuantizationArgs(strategy=QuantizationStrategy.GROUP)
