|
39 | 39 | from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \ |
40 | 40 | FrameworkQuantizationCapabilities |
41 | 41 |
|
| 42 | + |
42 | 43 | def validate_graph_after_change(method: Callable) -> Callable: |
43 | 44 | """ |
44 | 45 | Decorator for graph-mutating methods. After the decorated method executes, |
@@ -120,28 +121,13 @@ def fusing_info(self) -> FusingInfo: |
120 | 121 | def fusing_info(self, fusing_info: FusingInfo): |
121 | 122 | self._fusing_info = fusing_info |
122 | 123 |
|
123 | | - def set_fqc(self, |
124 | | - fqc: FrameworkQuantizationCapabilities): |
| 124 | + def set_fqc(self, fqc: FrameworkQuantizationCapabilities): |
125 | 125 | """ |
126 | 126 | Set the graph's FQC. |
127 | 127 | Args: |
128 | 128 | fqc: FrameworkQuantizationCapabilities object. |
129 | 129 | """ |
130 | | - # validate graph nodes are either from the framework or a custom layer defined in the FQC |
131 | | - # Validate graph nodes are either built-in layers from the framework or custom layers defined in the FQC |
132 | | - fqc_layers = fqc.op_sets_to_layers.get_layers() |
133 | | - fqc_filtered_layers = [layer for layer in fqc_layers if isinstance(layer, LayerFilterParams)] |
134 | | - for n in self.nodes: |
135 | | - is_node_in_fqc = any([n.is_match_type(_type) for _type in fqc_layers]) or \ |
136 | | - any([n.is_match_filter_params(filtered_layer) for filtered_layer in fqc_filtered_layers]) |
137 | | - if n.is_custom: |
138 | | - if not is_node_in_fqc: |
139 | | - Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. ' |
140 | | - ' Please add the custom layer to Framework Quantization Capabilities (FQC), or file a feature ' |
141 | | - 'request or an issue if you believe this should be supported.') # pragma: no cover |
142 | | - if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(fqc).quantization_configurations]): |
143 | | - Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover |
144 | | - |
| 130 | + # TODO irena: this is only passed for negative shift activation. |
145 | 131 | self.fqc = fqc |
146 | 132 |
|
147 | 133 | def get_topo_sorted_nodes(self): |
@@ -578,7 +564,7 @@ def get_weights_configurable_nodes(self, |
578 | 564 | A list of nodes that their weights can be configured (namely, has one or more weight qc candidate). |
579 | 565 | """ |
580 | 566 | # configurability is only relevant for kernel attribute quantization |
581 | | - potential_conf_nodes = [n for n in list(self) if n.is_kernel_op] |
| 567 | + potential_conf_nodes = [n for n in self.nodes if n.kernel_attr] |
582 | 568 |
|
583 | 569 | def is_configurable(n): |
584 | 570 | return n.is_configurable_weight(n.kernel_attr) and (not n.reuse or include_reused_nodes) |
@@ -693,10 +679,8 @@ def get_final_weights_config(self) -> List[Tuple[BaseNode, int]]: |
693 | 679 | """ |
694 | 680 | Gets the final number of bits for quantization of each weights' configurable layer. |
695 | 681 |
|
696 | | - Args: |
697 | | - fw_info: fw_info: FrameworkInfo object with information about the specific framework's model. |
698 | | -
|
699 | | - Returns: A list of pairs of (node type, node's weights quantization bitwidth). |
| 682 | + Returns: |
| 683 | + A list of pairs of (node type, node's weights quantization bitwidth). |
700 | 684 |
|
701 | 685 | """ |
702 | 686 | sorted_conf_weights = self.get_sorted_weights_configurable_nodes() |
@@ -876,32 +860,36 @@ def _find_intermediate_and_exit_nodes(self, entry_node: BaseNode, fw_impl: Any) |
876 | 860 |
|
877 | 861 | return intermediate_nodes, next_node |
878 | 862 |
|
| 863 | + # TODO irena move to load_fqc and clean up tests (currently tests_pytest/common_tests/unit_tests/core/graph/test_base_graph.py) |
879 | 864 | def override_fused_node_activation_quantization_candidates(self): |
880 | 865 | """ |
881 | 866 | Override fused node activation quantization candidates for all nodes in fused operations, |
882 | 867 | except for the last node in each fused group. |
883 | 868 | Update the value of quantization_config with the value of op_quantization_cfg from FusingInfo. |
884 | 869 | """ |
885 | | - from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import CandidateNodeQuantizationConfig |
886 | | - |
887 | 870 | nodes_in_fln = self.fusing_info.get_inner_fln_nodes() |
888 | 871 | for node in nodes_in_fln: |
889 | 872 | fused_node_op_id = self.fusing_info.get_fused_op_id_for_node(node.name) |
890 | | - fusiong_op_quaitization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id) |
891 | | - org_candidate = node.candidates_quantization_cfg[0] |
892 | | - if fusiong_op_quaitization_cfg is not None and fusiong_op_quaitization_cfg.enable_activation_quantization: |
893 | | - # Set ActivationQuantizationMode to FLN_QUANT and update the value of quantization_config |
894 | | - activation_quantization_cfg = NodeActivationQuantizationConfig(qc=org_candidate, |
895 | | - op_cfg=fusiong_op_quaitization_cfg, |
896 | | - activation_quantization_fn=org_candidate.activation_quantization_cfg.activation_quantization_fn, |
897 | | - activation_quantization_params_fn=org_candidate.activation_quantization_cfg.activation_quantization_params_fn) |
898 | | - activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT |
899 | | - for qc in node.candidates_quantization_cfg: |
900 | | - qc.activation_quantization_cfg = activation_quantization_cfg |
| 873 | + fusing_op_quantization_cfg = self.fusing_info.get_fused_op_quantization_config(fused_node_op_id) |
| 874 | + if fusing_op_quantization_cfg is not None and fusing_op_quantization_cfg.enable_activation_quantization: |
| 875 | + def update(qc): |
| 876 | + qc.activation_quantization_cfg = NodeActivationQuantizationConfig( |
| 877 | + fusing_op_quantization_cfg, |
| 878 | + qc.activation_quantization_cfg.activation_quantization_fn, |
| 879 | + qc.activation_quantization_cfg.activation_quantization_params_fn |
| 880 | + ) |
| 881 | + qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT |
| 882 | + node.quantization_cfg.update_all(update) |
| 883 | + node.quantization_cfg.remove_duplicates() |
901 | 884 | else: |
902 | | - # Set ActivationQuantizationMode to FLN_NO_QUANT |
| 885 | + node.quantization_cfg.update_activation_quantization_mode(ActivationQuantizationMode.FLN_NO_QUANT) |
| 886 | + # Remove duplicate candidates. We cannot compare whole candidates since activation configs might not |
| 887 | + # be identical, but we do want to treat them as such. So we only check duplication by weight configs. |
| 888 | + uniq_qcs = [] |
903 | 889 | for qc in node.candidates_quantization_cfg: |
904 | | - qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_NO_QUANT |
| 890 | + if not any(qc.weights_quantization_cfg == uqc.weights_quantization_cfg for uqc in uniq_qcs): |
| 891 | + uniq_qcs.append(qc) |
| 892 | + node.quantization_cfg.candidates_quantization_cfg = uniq_qcs |
905 | 893 |
|
906 | 894 | def validate(self): |
907 | 895 | """ |
|
0 commit comments