Commit 4b790ec
[MOE] Add a set of functionalities to support quantization of MOE models (#46)
* Update base.py
* add token counter
* implemented token counting
* observer to count throughout calibration
* cleanup tests
* avoid circular dep on import
* Update helpers.py
* fix tests
* post rebase fixes
* Update src/compressed_tensors/quantization/observers/helpers.py

Co-authored-by: [email protected] <[email protected]>
1 parent aecb127 commit 4b790ec
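In short, this commit teaches every activation Observer to count the rows of each 2-D tensor it sees during calibration, and adds a helper that collects those per-layer counts. A standalone sketch of the accounting idea (not the library's actual classes, just the logic the diffs below implement):

import torch

class TokenCountingObserverSketch:
    # hypothetical stand-in for the Observer changes in this commit
    def __init__(self):
        self._num_observed_tokens = None  # None until a 2-D input is seen

    def observe(self, batch_tensor: torch.Tensor):
        if batch_tensor.ndim != 2:
            return  # only (batch_size * seq_len, num_features) inputs count
        if self._num_observed_tokens is None:
            self._num_observed_tokens = 0
        self._num_observed_tokens += batch_tensor.shape[0]

obs = TokenCountingObserverSketch()
obs.observe(torch.randn(8, 4))  # 8 tokens
obs.observe(torch.randn(3, 4))  # 3 more
assert obs._num_observed_tokens == 11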

8 files changed: +190 -10 lines changed

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 5 additions & 0 deletions

@@ -293,6 +293,11 @@ def maybe_calibrate_or_quantize(
     }:
         return value
 
+    if value.numel() == 0:
+        # if the tensor is empty,
+        # skip quantization
+        return value
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
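This guard matters for MoE calibration: an expert that is routed zero tokens in a batch receives an empty activation, and running the observer on it would be pointless. A quick illustration of the condition being checked:

import torch

# an expert that received no tokens sees an activation with zero elements
empty_activation = torch.empty(0, 4)  # (0 tokens, 4 features)
assert empty_activation.numel() == 0  # the new early return leaves it untouched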

src/compressed_tensors/quantization/observers/base.py

Lines changed: 39 additions & 0 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
@@ -24,6 +25,9 @@
 from torch.nn import Module
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 __all__ = ["Observer"]
 
 
@@ -39,6 +43,7 @@ def __init__(self, quantization_args: QuantizationArgs):
         super().__init__()
         self._scale = None
         self._zero_point = None
+        self._num_observed_tokens = None
 
     @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
@@ -48,6 +53,7 @@ def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         from
         :return: tuple of scale and zero point based on last observed value
         """
+        self.record_observed_tokens(observed)
         return self.get_qparams(observed=observed)
 
     def calculate_qparams(
@@ -132,3 +138,36 @@ def get_qparams_along_dim(
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
         )
+
+    def record_observed_tokens(self, batch_tensor: Tensor):
+        """
+        Counts the number of tokens observed during the
+        forward passes. The count is aggregated in the
+        _num_observed_tokens attribute of the class.
+
+        Note: The batch_tensor is expected to have two dimensions
+        (batch_size * sequence_length, num_features). This is the
+        general shape expected by the forward pass of the expert
+        layers in a MOE model. If the input tensor does not have
+        two dimensions, the count is left unchanged (and stays
+        None if nothing 2-D has been observed yet).
+        """
+        if not isinstance(batch_tensor, Tensor):
+            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")
+
+        if batch_tensor.ndim != 2:
+            _LOGGER.debug(
+                "The input tensor is expected to have two dimensions "
+                "(batch_size * sequence_length, num_features). "
+                f"The input tensor has {batch_tensor.ndim} dimensions."
+            )
+            return
+
+        if self._num_observed_tokens is None:
+            # initialize the count
+            self._num_observed_tokens = 0
+
+        # batch_tensor (batch_size * sequence_length, num_features)
+        # observed_tokens (batch_size * sequence_length)
+        observed_tokens, _ = batch_tensor.shape
+        self._num_observed_tokens += observed_tokens
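Note the asymmetry the shape check creates: expert linears in a MoE receive flattened 2-D inputs and are counted, while layers fed 3-D (batch_size, seq_len, hidden) activations are skipped, so their count stays None. That is why callers of the token-count helper filter out None values (see the test further down). A minimal sketch of the check, using a hypothetical helper:

import torch

def counted_tokens(batch_tensor: torch.Tensor):
    # mirrors record_observed_tokens' ndim check
    return batch_tensor.shape[0] if batch_tensor.ndim == 2 else None

assert counted_tokens(torch.randn(8, 4)) == 8        # flattened expert input
assert counted_tokens(torch.randn(2, 4, 8)) is None  # (batch, seq, hidden) input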

src/compressed_tensors/quantization/observers/helpers.py

Lines changed: 21 additions & 3 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import Counter
 from typing import Tuple
 
 import torch
@@ -23,16 +24,33 @@
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["calculate_qparams", "calculate_range"]
+__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+
+
+def get_observer_token_count(module: torch.nn.Module) -> Counter:
+    """
+    Parse the module and return the number of tokens observed by
+    each module's observer.
+
+    :param module: module to parse
+    :return: counter with the number of tokens observed by each observer
+    """
+    token_counts = Counter()
+    for name, submodule in module.named_modules():
+        if name.endswith(".input_observer"):
+            token_counts[
+                name.replace(".input_observer", "")
+            ] = submodule._num_observed_tokens
+    return token_counts
 
 
 def calculate_qparams(
     min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
 ) -> Tuple[FloatTensor, IntTensor]:
     """
-    :param min_vals: tensor of min value(s) to caluclate scale(s) and zero point(s)
+    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
         from
-    :param max_vals: tensor of max value(s) to caluclate scale(s) and zero point(s)
+    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
         from
     :param quantization_args: settings to quantization
     :return: tuple of the calculated scale(s) and zero point(s)
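After calibration, get_observer_token_count can be used to audit how many tokens each quantized layer actually saw. A minimal usage sketch, assuming `model` is a module already prepared with apply_quantization_config and run on calibration data (the test further down shows the full flow):

from compressed_tensors.quantization.observers.helpers import get_observer_token_count

counter = get_observer_token_count(model)  # `model` assumed to be in scope
# drop observers that never saw a 2-D input (their count is None)
counter = {name: n for name, n in counter.items() if n is not None}
for name, n in sorted(counter.items()):
    print(f"{name}: {n} tokens observed")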

tests/test_quantization/lifecycle/test_forward.py

Lines changed: 7 additions & 6 deletions

@@ -57,23 +57,24 @@ def test_maybe_calibrate_or_quantize(create_quantization_scheme, quantization_st
     quantization_args = QuantizationArgs(num_bits=num_bits, symmetric=True)
     layer = Linear(4, 4)
     layer.weight.data *= 100
+
+    dummy_tensor = torch.randn(8, 4)  # (num_tokens, num_features)
     layer.quantization_status = QuantizationStatus(quantization_status)
 
     initialize_module_for_quantization(layer, quantization_scheme)
 
     # only calibration updates the scale and zero-point
     if layer.quantization_status == QuantizationStatus.INITIALIZED:
         out = maybe_calibrate_or_quantize(
-            layer, layer.weight.data, "input", quantization_args
+            layer, dummy_tensor, "input", quantization_args
         )
-        assert torch.allclose(out, layer.weight.data)
+        assert torch.allclose(out, dummy_tensor)
     elif layer.quantization_status == QuantizationStatus.CALIBRATION:
-
         out = maybe_calibrate_or_quantize(
-            layer, layer.weight.data, "input", quantization_args
+            layer, dummy_tensor, "input", quantization_args
         )
-        assert torch.allclose(out, layer.weight.data, atol=0.2)
-
+        assert torch.allclose(out, dummy_tensor, atol=0.2)
+        assert layer.input_observer._num_observed_tokens == dummy_tensor.shape[0]
     elif layer.quantization_status == QuantizationStatus.FROZEN:
         # scale and zero points are empty -- cannot quantize
         with pytest.raises(Exception):
Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from compressed_tensors.quantization import (
+    QuantizationConfig,
+    apply_quantization_config,
+)
+from compressed_tensors.quantization.observers.helpers import get_observer_token_count
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def test_get_observer_token_count():
+    model = AutoModelForCausalLM.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
+    tokenizer = AutoTokenizer.from_pretrained("Isotonic/TinyMixtral-4x248M-MoE")
+    model.eval()
+    config = QuantizationConfig(
+        format="fakequant",
+        quantization_status="calibration",
+        config_groups={
+            "group_1": {
+                "input_activations": {
+                    "num_bits": 8,
+                    "type": "int",
+                    "symmetric": False,
+                    "strategy": "tensor",
+                },
+                "targets": ["Linear"],
+            },
+        },
+    )
+    apply_quantization_config(model, config)
+
+    # start calibration
+    calib_list = [
+        "I am a string that",
+        "is used for calibration so",
+        "that your model is",
+        "quantized properly.",
+    ]
+
+    total_num_tokens_observed = 0
+    for calib_sample in calib_list:
+        calib_tensor = tokenizer(calib_sample, return_tensors="pt")
+        _ = model(**calib_tensor)
+        total_num_tokens_observed += len(calib_tensor.input_ids.flatten())
+
+    counter = get_observer_token_count(model)
+
+    # filter out the None values
+    # (observers that never saw a 2-D input and thus counted no tokens)
+    counter = {k: v for k, v in counter.items() if v is not None}
+
+    # iterate over all the layers in the model where a token count
+    # in the proper format has been observed
+    for i in range(model.config.num_hidden_layers):
+        # fetch the tokens observed by the router
+        tokens_observed_by_router = counter.pop(
+            f"model.layers.{i}.block_sparse_moe.gate"
+        )
+        assert tokens_observed_by_router == total_num_tokens_observed
+
+        # fetch the sum of tokens observed by all the experts
+        sum_tokens_observed_by_experts = 0
+        keys_for_this_layer = [
+            k
+            for k in counter.keys()
+            if f"model.layers.{i}.block_sparse_moe.experts" in k
+        ]
+        for key in keys_for_this_layer:
+            sum_tokens_observed_by_experts += counter.pop(key)
+
+        # each Mixtral expert consists of 3 linear layers,
+        # so we need to multiply by 3
+        assert (
+            sum_tokens_observed_by_experts
+            == total_num_tokens_observed * model.config.num_experts_per_tok * 3
+        )
+
+    # no information should be left in the counter
+    assert len(counter) == 0
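The final assertions encode a simple accounting identity: the router sees every token, each token is routed to num_experts_per_tok experts, and each Mixtral expert applies three Linear layers (w1, w2, w3 in the HF implementation), each with its own input observer. For example, assuming 15 calibration tokens and top-2 routing:

# hypothetical numbers, just to make the identity concrete
total_num_tokens_observed = 15  # tokens across all calibration samples
num_experts_per_tok = 2         # Mixtral routes each token to 2 experts
linears_per_expert = 3          # w1, w2, w3 inside each expert MLP

expected = total_num_tokens_observed * num_experts_per_tok * linears_per_expert
assert expected == 90  # what sum_tokens_observed_by_experts should add up to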

tests/test_quantization/test_observers/test_min_max.py

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ def test_min_max_observer_value_update():
 
     delta = 1e-6
 
-    # udpate the min, max twice total
+    # update the min, max twice total
     tensors = [
        inp,
        inp,

tests/test_utils/__init__.py

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
