Skip to content

Commit 2dcbc9d

Browse files
authored
Implement aliasable mixin and alias activation ordering (python3.9 fix) (#218)
* implement aliasable mixin and alias activation ordering Signed-off-by: Kyle Sayers <[email protected]> * update docstring Signed-off-by: Kyle Sayers <[email protected]> * fix docstring Signed-off-by: Kyle Sayers <[email protected]> * uncomment Signed-off-by: Kyle Sayers <[email protected]> * rename and make abstract Signed-off-by: Kyle Sayers <[email protected]> * remove property for clarity and to support python3.9 Signed-off-by: Kyle Sayers <[email protected]> --------- Signed-off-by: Kyle Sayers <[email protected]>
1 parent 8571339 commit 2dcbc9d

File tree

4 files changed

+96
-4
lines changed

4 files changed

+96
-4
lines changed

src/compressed_tensors/quantization/quant_args.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any, Dict, Optional, Union
1818

1919
import torch
20+
from compressed_tensors.utils import Aliasable
2021
from pydantic import BaseModel, Field, field_validator, model_validator
2122

2223

@@ -53,17 +54,29 @@ class QuantizationStrategy(str, Enum):
5354
TOKEN = "token"
5455

5556

56-
class ActivationOrdering(str, Enum):
57+
class ActivationOrdering(Aliasable, str, Enum):
    """
    Enumerates the available activation ordering strategies

    Group: reorder groups and weight\n
    Weight: only reorder weight, not groups. Compared to group actorder this
        gives slightly lower accuracy but also lower latency\n
    Dynamic: alias for Group\n
    Static: alias for Weight\n
    """

    # canonical members
    GROUP = "group"
    WEIGHT = "weight"

    # alias members, resolved to canonical values through get_aliases()
    DYNAMIC = "dynamic"
    STATIC = "static"

    @staticmethod
    def get_aliases() -> Dict[str, str]:
        """
        :return: mapping from each alias value to the canonical value it
            stands for
        """
        alias_to_canonical = dict(dynamic="group", static="weight")
        return alias_to_canonical
6780

6881

6982
class QuantizationArgs(BaseModel, use_enum_values=True):

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
6262

6363
return model
6464

65+
6566
"""
6667
Pre-Set Quantization Scheme Args
6768
"""

src/compressed_tensors/utils/helpers.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Optional
15+
from typing import Any, Dict, Optional
1616

1717
import torch
1818
from transformers import AutoConfig
@@ -24,6 +24,7 @@
2424
"tensor_follows_mask_structure",
2525
"replace_module",
2626
"is_compressed_tensors_config",
27+
"Aliasable",
2728
]
2829

2930
FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -119,3 +120,34 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
119120
return isinstance(compression_config, CompressedTensorsConfig)
120121
except ImportError:
121122
return False
123+
124+
125+
class Aliasable:
    """
    A mixin for enums to allow aliasing of enum members

    Subclasses implement :meth:`get_aliases`, which maps each alias value to
    the canonical value it stands for. Two members (or a member and a raw
    value) compare equal when their canonical values are equal, and members
    hash by their canonical value so that aliases collapse to a single entry
    in sets and dict keys.

    Note: an alias member also compares equal to its own raw alias value
    (for example ``DYNAMIC == "dynamic"``) even though it hashes like the
    canonical value, so raw alias strings should not be mixed with members
    as keys of the same dict or set.

    Example:
    >>> class MyClass(Aliasable, int, Enum):
    >>>     ...
    """

    @staticmethod
    def get_aliases() -> Dict[str, str]:
        """
        :return: mapping from alias value to canonical value
        :raises NotImplementedError: if the subclass does not override it
        """
        raise NotImplementedError()

    def _resolve_alias(self, value: Any) -> Any:
        """
        :param value: a raw enum value, or any object being compared against
        :return: the canonical value if `value` is a known alias, otherwise
            `value` unchanged
        """
        aliases = self.get_aliases()
        try:
            return aliases.get(value, value)
        except TypeError:
            # unhashable objects cannot be dict keys, so they are never aliases
            return value

    def __eq__(self, other):
        # compare members of the same enum by their underlying values
        if isinstance(other, self.__class__):
            other = other.value

        return self._resolve_alias(self.value) == self._resolve_alias(other)

    def __ne__(self, other):
        # defined explicitly: when mixed with a data type such as `str`, the
        # data type's own `__ne__` would otherwise be found first in the MRO
        # and ignore aliases, making `a == b` and `a != b` both true
        return not self.__eq__(other)

    def __hash__(self):
        # hash the canonical value so that members which compare equal
        # through aliasing also hash equal
        return hash(self._resolve_alias(self.value))

tests/test_quantization/test_quant_args.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,28 @@ def test_actorder():
8383
# test group inference with actorder
8484
args = QuantizationArgs(group_size=128, actorder=ActivationOrdering.GROUP)
8585
assert args.strategy == QuantizationStrategy.GROUP
86+
args = QuantizationArgs(group_size=128, actorder=ActivationOrdering.DYNAMIC)
87+
assert args.strategy == QuantizationStrategy.GROUP
8688

8789
# test invalid pairings
90+
with pytest.raises(ValueError):
91+
QuantizationArgs(group_size=None, actorder="group")
8892
with pytest.raises(ValueError):
8993
QuantizationArgs(group_size=None, actorder="weight")
94+
with pytest.raises(ValueError):
95+
QuantizationArgs(group_size=None, actorder="static")
96+
with pytest.raises(ValueError):
97+
QuantizationArgs(group_size=-1, actorder="group")
9098
with pytest.raises(ValueError):
9199
QuantizationArgs(group_size=-1, actorder="weight")
100+
with pytest.raises(ValueError):
101+
QuantizationArgs(group_size=-1, actorder="static")
102+
with pytest.raises(ValueError):
103+
QuantizationArgs(strategy="tensor", actorder="group")
92104
with pytest.raises(ValueError):
93105
QuantizationArgs(strategy="tensor", actorder="weight")
106+
with pytest.raises(ValueError):
107+
QuantizationArgs(strategy="tensor", actorder="static")
94108

95109
# test boolean and none defaulting
96110
assert (
@@ -101,6 +115,38 @@ def test_actorder():
101115
assert QuantizationArgs(group_size=1, actorder=None).actorder is None
102116

103117

118+
def test_actorder_aliases():
    """
    Check that alias members and alias strings compare equal to their
    canonical counterparts, in both operand orders, and that members from
    different alias families never compare equal
    """
    group = ActivationOrdering.GROUP
    dynamic = ActivationOrdering.DYNAMIC
    weight = ActivationOrdering.WEIGHT
    static = ActivationOrdering.STATIC

    # member-to-member equality holds in both directions
    assert group == dynamic and dynamic == group
    assert weight == static and static == weight

    # each member equals every spelling within its own family
    for name in ("dynamic", "group"):
        for member in (group, dynamic):
            assert member == name and name == member

    for name in ("static", "weight"):
        for member in (weight, static):
            assert member == name and name == member

    # members never equal a spelling from the other family
    for name in ("dynamic", "group"):
        for member in (weight, static):
            assert member != name and name != member

    for name in ("static", "weight"):
        for member in (group, dynamic):
            assert member != name and name != member
148+
149+
104150
def test_invalid():
105151
with pytest.raises(ValidationError):
106152
QuantizationArgs(type="invalid")

0 commit comments

Comments
 (0)