|
| 1 | +# Copyright (c) 2024 Intel Corporation |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | +# Unless required by applicable law or agreed to in writing, software |
| 7 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +# See the License for the specific language governing permissions and |
| 10 | +# limitations under the License. |
| 11 | + |
| 12 | +from typing import Dict, List, Optional, Tuple |
| 13 | + |
| 14 | +import numpy as np |
| 15 | +import torch |
| 16 | +import torch.fx |
| 17 | +from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node |
| 18 | + |
| 19 | +import nncf.torch.graph.operator_metatypes as om |
| 20 | +from nncf.common.graph import NNCFGraph |
| 21 | +from nncf.common.graph import NNCFNode |
| 22 | +from nncf.common.graph.definitions import NNCFGraphNodeType |
| 23 | +from nncf.common.graph.transformations.commands import TargetType |
| 24 | +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector |
| 25 | +from nncf.experimental.tensor import Tensor |
| 26 | +from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand |
| 27 | +from nncf.experimental.torch_fx.transformations import bias_update_transformation_builder |
| 28 | +from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend |
| 29 | +from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand |
| 30 | +from nncf.torch.graph.transformations.commands import PTModelExtractionCommand |
| 31 | +from nncf.torch.graph.transformations.commands import PTTargetPoint |
| 32 | +from nncf.torch.nncf_network import NNCFNetwork |
| 33 | +from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector |
| 34 | + |
| 35 | + |
| 36 | +class FXFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): |
| 37 | + TARGET_TYPE_TO_PT_INS_TYPE_MAP = { |
| 38 | + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, |
| 39 | + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, |
| 40 | + } |
| 41 | + |
| 42 | + @staticmethod |
| 43 | + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: |
| 44 | + if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: |
| 45 | + port_id = None |
| 46 | + if target_type in FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP: |
| 47 | + target_type = FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type] |
| 48 | + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) |
| 49 | + |
| 50 | + @staticmethod |
| 51 | + def create_bias_correction_command( |
| 52 | + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph |
| 53 | + ) -> PTBiasCorrectionCommand: |
| 54 | + return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data)) |
| 55 | + |
| 56 | + @staticmethod |
| 57 | + def model_extraction_command( |
| 58 | + input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]] |
| 59 | + ) -> PTModelExtractionCommand: |
| 60 | + return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]]) |
| 61 | + |
| 62 | + @staticmethod |
| 63 | + def mean_statistic_collector( |
| 64 | + channel_axis: int, |
| 65 | + inplace: bool, |
| 66 | + num_samples: Optional[int] = None, |
| 67 | + window_size: Optional[int] = None, |
| 68 | + ) -> TensorCollector: |
| 69 | + return get_mean_statistic_collector(num_samples, channel_axis, window_size) |
| 70 | + |
| 71 | + @staticmethod |
| 72 | + def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]: |
| 73 | + # Pytorch does not have name for extracted node |
| 74 | + return None, None |
| 75 | + |
| 76 | + @staticmethod |
| 77 | + def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int) -> torch.Tensor: |
| 78 | + blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device) |
| 79 | + for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): |
| 80 | + index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) |
| 81 | + blob[index] = data[j].data |
| 82 | + return blob |
| 83 | + |
| 84 | + @staticmethod |
| 85 | + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor: |
| 86 | + # TODO: make a node_name_vs_node map to speed up the process |
| 87 | + from nncf.experimental.torch_fx.model_transformer import FXModelTransformer |
| 88 | + |
| 89 | + bias_node = nncf_graph.get_next_nodes(node)[0] |
| 90 | + graph_bias_node = FXModelTransformer.get_graph_node_by_name(model.graph, bias_node.node_name) |
| 91 | + return Tensor(_get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model)) |
| 92 | + |
| 93 | + @staticmethod |
| 94 | + def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: |
| 95 | + return 0, 0 |
| 96 | + |
| 97 | + @staticmethod |
| 98 | + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: |
| 99 | + return Tensor(raw_data) |
| 100 | + |
| 101 | + @staticmethod |
| 102 | + def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: |
| 103 | + weight_node = nncf_graph.get_previous_nodes(node)[1] |
| 104 | + return weight_node.node_type == "dequantize_per_channel" |
| 105 | + |
| 106 | + @staticmethod |
| 107 | + def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: |
| 108 | + # Assumes that all biases were unfused |
| 109 | + if node.metatype in (om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype): |
| 110 | + next_nodes = nncf_graph.get_next_nodes(node) |
| 111 | + if len(next_nodes) != 1: |
| 112 | + return False |
| 113 | + return next_nodes[0].metatype in (om.PTAddMetatype,) |
| 114 | + |
| 115 | + @staticmethod |
| 116 | + def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: |
| 117 | + return node.node_name, node.node_name |
0 commit comments