Skip to content

Commit 724d5ce

Browse files
authored
Implement aliasable mixin and alias activation ordering (#213)
* implement aliasable mixin and alias activation ordering
  Signed-off-by: Kyle Sayers <[email protected]>
* update docstring
  Signed-off-by: Kyle Sayers <[email protected]>
* fix docstring
  Signed-off-by: Kyle Sayers <[email protected]>
* uncomment
  Signed-off-by: Kyle Sayers <[email protected]>
* rename and make abstract
  Signed-off-by: Kyle Sayers <[email protected]>
---------
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 525ef3a commit 724d5ce

File tree

4 files changed

+100
-4
lines changed

4 files changed

+100
-4
lines changed

src/compressed_tensors/quantization/quant_args.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any, Dict, Optional, Union
1818

1919
import torch
20+
from compressed_tensors.utils import Aliasable
2021
from pydantic import BaseModel, Field, field_validator, model_validator
2122

2223

@@ -53,17 +54,30 @@ class QuantizationStrategy(str, Enum):
5354
TOKEN = "token"
5455

5556

56-
class ActivationOrdering(str, Enum):
57+
class ActivationOrdering(Aliasable, str, Enum):
    """
    Enum storing strategies for activation ordering

    Group: reorder groups and weight\n
    Weight: only reorder weight, not groups. Slightly lower accuracy but also lower
    latency when compared to group actorder\n
    Dynamic: alias for Group\n
    Static: alias for Weight\n
    """

    GROUP = "group"
    WEIGHT = "weight"
    # aliases
    DYNAMIC = "dynamic"
    STATIC = "static"

    # NOTE: declared as a plain property. Stacking `@property` on top of
    # `@staticmethod` only works on Python >= 3.10 (where staticmethod objects
    # became callable) and contradicts the `self` parameter.
    @property
    def aliases(self) -> Dict[str, str]:
        """
        :return: mapping from each alias value to the canonical value it
            stands for ("dynamic" -> "group", "static" -> "weight")
        """
        return {
            "dynamic": "group",
            "static": "weight",
        }
6781

6882

6983
class QuantizationArgs(BaseModel, use_enum_values=True):

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
6262

6363
return model
6464

65+
6566
"""
6667
Pre-Set Quantization Scheme Args
6768
"""

src/compressed_tensors/utils/helpers.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Optional
15+
from abc import abstractmethod
16+
from typing import Any, Dict, Optional
1617

1718
import torch
1819
from transformers import AutoConfig
@@ -24,6 +25,7 @@
2425
"tensor_follows_mask_structure",
2526
"replace_module",
2627
"is_compressed_tensors_config",
28+
"Aliasable",
2729
]
2830

2931
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -119,3 +121,36 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
119121
return isinstance(compression_config, CompressedTensorsConfig)
120122
except ImportError:
121123
return False
124+
125+
126+
class Aliasable:
    """
    A mixin for enums to allow aliasing of enum members

    Two members compare equal (and hash equally) when their values map to the
    same canonical value through ``aliases``. A member also compares equal to
    a raw value that is either its canonical value or one of its aliases.

    Example:
    >>> class MyClass(Aliasable, int, Enum):
    >>>     ...
    """

    # NOTE: a plain abstract property. `@property` over `@staticmethod` is only
    # usable on Python >= 3.10 and conflicts with the `self` parameter.
    @property
    @abstractmethod
    def aliases(self) -> Dict[str, str]:
        """
        :return: mapping from alias value to the canonical value it stands for
        """
        raise NotImplementedError()

    def __eq__(self, other):
        """
        Compare by canonical value

        :param other: another member of the same enum, or a raw value
        :return: True if both sides resolve to the same canonical value;
            NotImplemented if `other` cannot be looked up (unhashable)
        """
        aliases = self.aliases
        # unwrap members of the same enum so both sides are raw values
        other_value = other.value if isinstance(other, self.__class__) else other

        try:
            return aliases.get(self.value, self.value) == aliases.get(
                other_value, other_value
            )
        except TypeError:
            # unhashable objects (lists, dicts, ...) can never be an alias key;
            # defer to Python's default handling, which yields "not equal"
            return NotImplemented

    def __ne__(self, other):
        """
        Inverse of `__eq__`

        Defined explicitly because a `str`/`int` mixin later in the MRO
        provides its own `__ne__`, which would ignore aliases and make `==`
        and `!=` disagree
        """
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        # hash the canonical value so aliased members land in the same bucket,
        # keeping `__hash__` consistent with `__eq__`
        canonical_value = self.aliases.get(self.value, self.value)
        return hash(canonical_value)

tests/test_quantization/test_quant_args.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,28 @@ def test_actorder():
8383
# test group inference with actorder
8484
args = QuantizationArgs(group_size=128, actorder=ActivationOrdering.GROUP)
8585
assert args.strategy == QuantizationStrategy.GROUP
86+
args = QuantizationArgs(group_size=128, actorder=ActivationOrdering.DYNAMIC)
87+
assert args.strategy == QuantizationStrategy.GROUP
8688

8789
# test invalid pairings
90+
with pytest.raises(ValueError):
91+
QuantizationArgs(group_size=None, actorder="group")
8892
with pytest.raises(ValueError):
8993
QuantizationArgs(group_size=None, actorder="weight")
94+
with pytest.raises(ValueError):
95+
QuantizationArgs(group_size=None, actorder="static")
96+
with pytest.raises(ValueError):
97+
QuantizationArgs(group_size=-1, actorder="group")
9098
with pytest.raises(ValueError):
9199
QuantizationArgs(group_size=-1, actorder="weight")
100+
with pytest.raises(ValueError):
101+
QuantizationArgs(group_size=-1, actorder="static")
102+
with pytest.raises(ValueError):
103+
QuantizationArgs(strategy="tensor", actorder="group")
92104
with pytest.raises(ValueError):
93105
QuantizationArgs(strategy="tensor", actorder="weight")
106+
with pytest.raises(ValueError):
107+
QuantizationArgs(strategy="tensor", actorder="static")
94108

95109
# test boolean and none defaulting
96110
assert (
@@ -101,6 +115,38 @@ def test_actorder():
101115
assert QuantizationArgs(group_size=1, actorder=None).actorder is None
102116

103117

118+
def test_actorder_aliases():
    """
    Check that aliased activation orderings compare equal to their canonical
    members and to the matching raw strings (in both operand orders), and
    compare unequal to strings of the other ordering family
    """
    group, dynamic = ActivationOrdering.GROUP, ActivationOrdering.DYNAMIC
    weight, static = ActivationOrdering.WEIGHT, ActivationOrdering.STATIC

    # member vs member
    assert group == dynamic == group
    assert weight == static == weight

    # group family vs its own strings
    for raw in ("dynamic", "group"):
        for member in (group, dynamic):
            assert member == raw == member

    # weight family vs its own strings
    for raw in ("static", "weight"):
        for member in (weight, static):
            assert member == raw == member

    # weight family never matches group-family strings
    for raw in ("dynamic", "group"):
        for member in (weight, static):
            assert member != raw != member

    # group family never matches weight-family strings
    for raw in ("static", "weight"):
        for member in (group, dynamic):
            assert member != raw != member
148+
149+
104150
def test_invalid():
105151
with pytest.raises(ValidationError):
106152
QuantizationArgs(type="invalid")

0 commit comments

Comments
 (0)