
Commit 42ee086

initialize_qparams
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 70299f3 commit 42ee086

1 file changed: +20 -4 lines changed


src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 20 additions & 4 deletions
@@ -47,6 +47,7 @@
 __all__ = [
     "initialize_module_for_quantization",
     "is_attention_module",
+    "initialize_qparams",
 ]


@@ -100,7 +101,7 @@ def initialize_module_for_quantization(
         return

     if scheme.input_activations is not None:
-        _initialize_scale_zero_point(
+        initialize_qparams(
             module,
             "input",
             scheme.input_activations,
@@ -110,7 +111,7 @@ def initialize_module_for_quantization(
         )

     if scheme.weights is not None:
-        _initialize_scale_zero_point(
+        initialize_qparams(
             module,
             "weight",
             scheme.weights,
@@ -121,7 +122,7 @@ def initialize_module_for_quantization(

     output_is_kv_cache = is_kv_cache_quant_scheme(scheme)
     if scheme.output_activations is not None and not output_is_kv_cache:
-        _initialize_scale_zero_point(
+        initialize_qparams(
             module,
             "output",
             scheme.output_activations,
@@ -147,14 +148,29 @@ def is_attention_module(module: Module):
     )


-def _initialize_scale_zero_point(
+def initialize_qparams(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
     observed_shape: Tuple[int],
     observed_dtype: torch.dtype,
     force_zero_point: bool = True,
 ):
+    """
+    Initialize quantization parameters for a given basename according to the passed
+    quantization args. The shape and dtype of the observed weight/activation must also
+    be provided.
+
+    Scales will always be initialized. Global scales are initialized depending on args.
+    Zero points will be initialized if not symmetric or if `force_zero_point` is True.
+
+    :param module: module to register qparams to
+    :param base_name: base name of qparams, for example "input", "weight", "k", "v"
+    :param quantization_args: arguments for quantization
+    :param observed_shape: last (right-most) known dimensions of the observed weight/act
+    :param observed_dtype: dtype of the observed weight/act
+    :param force_zero_point: force the zero_point parameter to be initialized
+    """
     strategy = quantization_args.strategy
     dynamic = quantization_args.dynamic
     actorder = quantization_args.actorder

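For reference, a minimal usage sketch of the newly exported initialize_qparams follows. Only the function signature comes from this diff; the example Linear layer and the QuantizationArgs field values are illustrative assumptions.

# Usage sketch (not part of this commit): register quantization parameters for a
# Linear layer's weight via the newly public initialize_qparams.
from torch.nn import Linear

from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle.initialize import initialize_qparams

layer = Linear(in_features=512, out_features=256)

# assumed example settings: symmetric 8-bit weight quantization
args = QuantizationArgs(num_bits=8, symmetric=True)

# registers the scale (and, only if asymmetric or forced, the zero point)
# on `layer` under the "weight" base name
initialize_qparams(
    layer,
    "weight",
    args,
    observed_shape=layer.weight.shape,
    observed_dtype=layer.weight.dtype,
    force_zero_point=False,
)

Per the new docstring, with symmetric quantization and force_zero_point=False only the scale parameter would be created.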