Skip to content

Commit fb6aa9a

Browse files
add ALL_QPARAM_KEYS var
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 88b8865 commit fb6aa9a

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"initialize_module_for_quantization",
4646
"is_attention_module",
4747
"KVCacheScaleType",
48+
"ALL_QPARAM_KEYS",
4849
]
4950

5051

@@ -56,6 +57,18 @@ class KVCacheScaleType(Enum):
5657
VALUE = "v_scale"
5758

5859

60+
ALL_QPARAM_KEYS = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [
61+
f"{base_name}_{suffix}"
62+
for base_name in ("input", "weight", "output")
63+
for suffix in (
64+
"global_scale",
65+
"scale",
66+
"zero_point",
67+
"g_idx",
68+
)
69+
]
70+
71+
5972
def initialize_module_for_quantization(
6073
module: Module,
6174
scheme: Optional[QuantizationScheme] = None,
@@ -146,17 +159,7 @@ def _clear_all_qparams(
146159
147160
:param module: module to clear qparams from
148161
"""
149-
keys = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [
150-
f"{base_name}_{suffix}"
151-
for base_name in ("input", "weight", "output")
152-
for suffix in (
153-
"global_scale",
154-
"scale",
155-
"zero_point",
156-
"g_idx",
157-
)
158-
]
159-
for key in keys:
162+
for key in ALL_QPARAM_KEYS:
160163
if hasattr(module, key):
161164
delete_offload_parameter(module, key)
162165

0 commit comments

Comments
 (0)