Skip to content

Commit bf74a58

Browse files
irenabirenab
authored andcommitted
remove activation_quantization_params_fn from NodeActivationQuantizationConf
1 parent 011c876 commit bf74a58

15 files changed

Lines changed: 146 additions & 229 deletions

File tree

model_compression_toolkit/core/common/network_editors/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16-
from model_compression_toolkit.core.common.network_editors.actions import ChangeCandidatesWeightsQuantConfigAttr, ChangeFinalWeightsQuantConfigAttr, ChangeCandidatesActivationQuantConfigAttr, ChangeQuantizationParamFunction, ChangeCandidatesActivationQuantizationMethod, ChangeFinalWeightsQuantizationMethod, ChangeCandidatesWeightsQuantizationMethod, ChangeFinalActivationQuantConfigAttr
1716
from model_compression_toolkit.core.common.network_editors.actions import EditRule
1817
from model_compression_toolkit.core.common.network_editors.node_filters import NodeTypeFilter, NodeNameScopeFilter, \
1918
NodeNameFilter

model_compression_toolkit/core/common/network_editors/actions.py

Lines changed: 1 addition & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@
2222
from model_compression_toolkit.logger import Logger
2323

2424

25-
from model_compression_toolkit.core.common.framework_info import get_fw_info
2625
from model_compression_toolkit.core.common.graph.base_node import BaseNode
2726
from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
28-
get_activation_quantization_params_fn, get_weights_quantization_params_fn
27+
get_weights_quantization_params_fn
2928
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
3029
get_weights_quantization_fn
3130

@@ -174,47 +173,6 @@ def apply(self, node: BaseNode, graph):
174173
node.final_activation_quantization_cfg.set_quant_config_attr(parameter_name, parameter_value)
175174

176175

177-
class ChangeQuantizationParamFunction(BaseAction):
178-
"""
179-
Class ChangeQuantizationParamFunction to change a node's weights/activations quantization params function.
180-
"""
181-
182-
def __init__(self,
183-
attr_name: str = None,
184-
activation_quantization_params_fn: Callable = None,
185-
weights_quantization_params_fn: Callable = None):
186-
"""
187-
Init a ChangeQuantizationParamFunction object.
188-
189-
Args:
190-
attr_name: The weights attribute's name to set the weights quantization params function for (if setting weights params).
191-
activation_quantization_params_fn: a params function for a node's activations.
192-
weights_quantization_params_fn: a params function for a node's weights.
193-
"""
194-
self.activation_quantization_params_fn = activation_quantization_params_fn
195-
self.weights_quantization_params_fn = weights_quantization_params_fn
196-
self.attr_name = attr_name
197-
198-
def apply(self, node: BaseNode, graph):
199-
"""
200-
Change the node's weights/activations quantization params function.
201-
202-
Args:
203-
node: Node object to change its quantization params function.
204-
graph: Graph to apply the action on.
205-
206-
Returns:
207-
The node after its quantization params function has been modified.
208-
"""
209-
for nqc in node.candidates_quantization_cfg:
210-
if self.activation_quantization_params_fn is not None:
211-
nqc.activation_quantization_cfg.set_activation_quantization_params_fn(
212-
self.activation_quantization_params_fn)
213-
if self.weights_quantization_params_fn is not None:
214-
attr_config = nqc.weights_quantization_cfg.get_attr_config(self.attr_name)
215-
attr_config.override_weights_quantization_params_fn(self.weights_quantization_params_fn)
216-
217-
218176
class ChangeFinalActivationQuantizationMethod(BaseAction):
219177
"""
220178
Class ChangeFinalActivationQuantizationMethod to change a node's weights/activations quantizer function.
@@ -243,13 +201,6 @@ def apply(self, node: BaseNode, graph):
243201
"""
244202

245203
if self.activation_quantization_method is not None and node.final_activation_quantization_cfg is not None:
246-
247-
activation_quantization_params_fn = get_activation_quantization_params_fn(
248-
self.activation_quantization_method)
249-
250-
node.final_activation_quantization_cfg.set_activation_quantization_params_fn(
251-
activation_quantization_params_fn)
252-
253204
node.final_activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
254205

255206

@@ -278,9 +229,6 @@ def apply(self, node: BaseNode, graph):
278229
"""
279230
if self.activation_quantization_method is not None:
280231
for qc in node.candidates_quantization_cfg:
281-
activation_quantization_params_fn = get_activation_quantization_params_fn(
282-
self.activation_quantization_method)
283-
qc.activation_quantization_cfg.set_activation_quantization_params_fn(activation_quantization_params_fn)
284232
qc.activation_quantization_cfg.activation_quantization_method = self.activation_quantization_method
285233

286234

model_compression_toolkit/core/common/quantization/node_quantization_config.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -81,22 +81,14 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
8181
"""
8282
Attributes for configuring the quantization of the activations of a node.
8383
"""
84-
def __init__(self,
85-
op_cfg: OpQuantizationConfig,
86-
activation_quantization_params_fn: Callable):
84+
def __init__(self, op_cfg: OpQuantizationConfig):
8785
"""
8886
8987
Args:
9088
op_cfg: OpQuantizationConfig of the node with quantizers types to use when creating node quantization configuration.
91-
activation_quantization_params_fn: Function to use when computing the threshold for quantizing a node's activations.
9289
"""
93-
94-
self.activation_quantization_params_fn = activation_quantization_params_fn
9590
self.activation_quantization_method = op_cfg.activation_quantization_method
9691
self.activation_n_bits = op_cfg.activation_n_bits
97-
self.activation_quantization_params = {}
98-
# TODO irena: computed by compute_activation_bias_correction. shouldnt really be here
99-
self.activation_bias_correction_term = None
10092
if op_cfg.enable_activation_quantization and op_cfg.quantization_preserving:
10193
raise ValueError("An OpQuantizationConfig can't have both enable_activation_quantization and quantization_preserving enabled.")
10294
if op_cfg.enable_activation_quantization:
@@ -107,6 +99,10 @@ def __init__(self,
10799
self.quant_mode = ActivationQuantizationMode.NO_QUANT
108100
self.signedness = op_cfg.signedness
109101

102+
self.activation_quantization_params = {}
103+
# TODO irena: computed by compute_activation_bias_correction. shouldnt really be here
104+
self.activation_bias_correction_term = None
105+
110106
# TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
111107
self.activation_error_method = None
112108
self.relu_bound_to_power_of_2 = None
@@ -146,16 +142,6 @@ def quantization_preserving(self):
146142
def fln_quantization(self):
147143
return self.quant_mode == ActivationQuantizationMode.FLN_QUANT
148144

149-
def set_activation_quantization_params_fn(self, activation_quantization_params_fn:Callable):
150-
"""
151-
Sets activation params function for the node.
152-
153-
Args:
154-
activation_quantization_params_fn: Function for calculating activation params.
155-
156-
"""
157-
self.activation_quantization_params_fn = activation_quantization_params_fn
158-
159145
def set_activation_quantization_param(self,
160146
activation_params: dict):
161147
"""
@@ -182,8 +168,7 @@ def __eq__(self, other: Any) -> bool:
182168
if not isinstance(other, NodeActivationQuantizationConfig):
183169
return False # pragma: no cover
184170

185-
return self.activation_quantization_params_fn == other.activation_quantization_params_fn and \
186-
self.activation_error_method == other.activation_error_method and \
171+
return self.activation_error_method == other.activation_error_method and \
187172
self.activation_quantization_method == other.activation_quantization_method and \
188173
self.activation_n_bits == other.activation_n_bits and \
189174
self.quant_mode == other.quant_mode and \
@@ -197,8 +182,7 @@ def __eq__(self, other: Any) -> bool:
197182
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
198183

199184
def __hash__(self):
200-
return hash((self.activation_quantization_params_fn,
201-
self.activation_error_method,
185+
return hash((self.activation_error_method,
202186
self.activation_quantization_method,
203187
self.activation_n_bits,
204188
self.quant_mode,

model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,38 +19,13 @@
1919
from mct_quantizers import QuantizationMethod
2020
from model_compression_toolkit.logger import Logger
2121
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import \
22-
lut_kmeans_tensor, lut_kmeans_histogram
22+
lut_kmeans_tensor
2323
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import \
24-
symmetric_selection_tensor, symmetric_selection_histogram
24+
symmetric_selection_tensor
2525
from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import \
26-
uniform_selection_histogram, uniform_selection_tensor
26+
uniform_selection_tensor
2727
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import \
28-
power_of_two_selection_tensor, power_of_two_selection_histogram
29-
30-
31-
def get_activation_quantization_params_fn(activation_quantization_method: QuantizationMethod) -> Callable:
32-
"""
33-
Generate a function for finding activation quantization parameters.
34-
35-
Args:
36-
activation_quantization_method: Which quantization method to use for activations.
37-
Returns:
38-
A function to find the quantization parameters.
39-
40-
"""
41-
if activation_quantization_method == QuantizationMethod.POWER_OF_TWO:
42-
params_fn = power_of_two_selection_histogram
43-
elif activation_quantization_method == QuantizationMethod.SYMMETRIC:
44-
params_fn = symmetric_selection_histogram
45-
elif activation_quantization_method == QuantizationMethod.UNIFORM:
46-
params_fn = uniform_selection_histogram
47-
elif activation_quantization_method == QuantizationMethod.LUT_POT_QUANTIZER:
48-
params_fn = lut_kmeans_histogram
49-
else:
50-
Logger.critical(
51-
f"No parameter function found for the specified quantization method: {activation_quantization_method}") # pragma: no cover
52-
return params_fn
53-
28+
power_of_two_selection_tensor
5429

5530
weights_quant_params_fns = {
5631
QuantizationMethod.POWER_OF_TWO: power_of_two_selection_tensor,

model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import power_of_two_no_clipping_selection_min_max, \
16-
power_of_two_selection_histogram, power_of_two_selection_tensor
17-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import lut_kmeans_tensor
18-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import symmetric_no_clipping_selection_min_max
19-
from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import uniform_no_clipping_selection_min_max
15+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.power_of_two_selection import (
16+
power_of_two_no_clipping_selection_min_max, power_of_two_selection_histogram, power_of_two_selection_tensor)
17+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.lut_kmeans_params import (
18+
lut_kmeans_tensor, lut_kmeans_histogram)
19+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.symmetric_selection import (
20+
symmetric_no_clipping_selection_min_max, symmetric_selection_histogram)
21+
from model_compression_toolkit.core.common.quantization.quantization_params_generation.uniform_selection import (
22+
uniform_no_clipping_selection_min_max, uniform_selection_histogram)
2023
from model_compression_toolkit.core.common.quantization.quantization_params_generation.outlier_filter import z_score_filter

0 commit comments

Comments
 (0)