diff --git a/nncf/common/graph/graph.py b/nncf/common/graph/graph.py index 5c33662153e..14f6eee7520 100644 --- a/nncf/common/graph/graph.py +++ b/nncf/common/graph/graph.py @@ -65,6 +65,10 @@ def layer_name(self) -> LayerName: def layer_attributes(self) -> BaseLayerAttributes: return self.data.get(NNCFGraph.LAYER_ATTRIBUTES) + @layer_attributes.setter + def layer_attributes(self, data: Any) -> None: + self.data[NNCFGraph.LAYER_ATTRIBUTES] = data + @property def ignored_algorithms(self) -> List[str]: return self.data.get(NNCFGraph.IGNORED_ALGOS_ATTR, []) diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py index a46be241d41..79f168d1c2b 100644 --- a/nncf/common/graph/layer_attributes.py +++ b/nncf/common/graph/layer_attributes.py @@ -13,8 +13,7 @@ from abc import ABC from abc import abstractmethod from enum import Enum -from typing import List -from typing import Tuple +from typing import List, Tuple, Any class Dtype(Enum): @@ -29,15 +28,30 @@ class BaseLayerAttributes(ABC): """ +class MultipleInputLayerAttributes(BaseLayerAttributes): + """ + Represents a layer with multiple inputs. + """ + + def __init__(self, + axis: int): + self.axis = axis + + def __eq__(self, other: Any): + return isinstance(other, MultipleInputLayerAttributes) \ + and self.axis == other.axis + + class WeightedLayerAttributes(BaseLayerAttributes): """ Represents a layer with weights. """ + def __init__(self, weight_requires_grad: bool, dtype: Dtype = Dtype.FLOAT): self.weight_requires_grad = weight_requires_grad self.dtype = dtype - def __eq__(self, other): + def __eq__(self, other: Any): return isinstance(other, WeightedLayerAttributes) \ and self.weight_requires_grad == other.weight_requires_grad @@ -59,6 +73,7 @@ class GenericWeightedLayerAttributes(WeightedLayerAttributes): Represents a weighted layer for which there is no information ahead of time of the exact meaning of the weight indices. """ + def __init__(self, weight_requires_grad: bool, weight_shape: List[int], filter_dimension_idx: int = 0): super().__init__(weight_requires_grad) @@ -112,7 +127,7 @@ def __init__(self, self.transpose = transpose self.padding_values = padding_values - def __eq__(self, other): + def __eq__(self, other: Any): return isinstance(other, ConvolutionLayerAttributes) \ and super().__eq__(other) \ and self.in_channels == other.in_channels \ @@ -148,7 +163,7 @@ def __init__(self, self.num_channels = num_channels self.num_groups = num_groups - def __eq__(self, other): + def __eq__(self, other: Any): return isinstance(other, GroupNormLayerAttributes) \ and super().__eq__(other) \ and self.num_channels == other.num_channels \ diff --git a/nncf/common/graph/utils.py b/nncf/common/graph/utils.py new file mode 100644 index 00000000000..23a8cee862a --- /dev/null +++ b/nncf/common/graph/utils.py @@ -0,0 +1,43 @@ +""" + Copyright (c) 2021 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import List + +from nncf.common.utils.logger import logger + + +def get_concat_axis(input_shapes: List[List[int]], output_shapes: List[List[int]]) -> int: + """ + Returns concatenation axis by given input and output shape of concat node. + + :param input_shapes: Input_shapes of given concat node. + :param output_shapes: Input_shapes of given concat node. + :returns: Concatenation axis of given concat node. + """ + axis = None + none_dim = None + for idx, (dim_in, dim_out) in enumerate(zip(input_shapes[0], output_shapes[0])): + if dim_in != dim_out: + axis = idx + break + if dim_in is None: + none_dim = idx + + if axis is None: + if none_dim is None: + axis = -1 + logger.warning('Identity concat node detected') + else: + axis = none_dim + + return axis diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py deleted file mode 100644 index 1187def3e77..00000000000 --- a/nncf/common/pruning/export_helpers.py +++ /dev/null @@ -1,52 +0,0 @@ -""" - Copyright (c) 2021 Intel Corporation - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" -from nncf.common.graph import NNCFGraph -from nncf.common.graph import NNCFNode - - -class DefaultMetaOp: - """ - Determines meta operations which aggregate operations having common - properties of interaction with pruning masks - """ - - subtypes = [] - additional_types = [] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - """ - :return: accept_pruned_input - can this operation work with pruned input or not - """ - raise NotImplementedError - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - """ - Propagates the pruning mask through a node using pruning masks of all inputs and the current node (if any). - - :param node: The graph node to propagate mask through it - :param graph: The model graph to prune - """ - raise NotImplementedError - - @classmethod - def get_all_op_aliases(cls): - """ - :return: list of all aliases of types in metatype - """ - op_types = [] - for subtype in cls.subtypes: - op_types.extend(subtype.get_all_aliases()) - op_types = list(set(op_types)) + cls.additional_types - return op_types diff --git a/nncf/common/pruning/mask_propagation.py b/nncf/common/pruning/mask_propagation.py index aa1a13fe879..614a0c0ce8a 100644 --- a/nncf/common/pruning/mask_propagation.py +++ b/nncf/common/pruning/mask_propagation.py @@ -10,16 +10,10 @@ See the License for the specific language governing permissions and limitations under the License. """ -from typing import List -from typing import Union -from typing import TypeVar from nncf.common.graph import NNCFGraph -from nncf.common.graph import NNCFNode -from nncf.common.pruning.export_helpers import DefaultMetaOp from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry - -TensorType = TypeVar('TensorType') +from nncf.common.pruning.operations import BasePruningOp class MaskPropagationAlgorithm: @@ -40,7 +34,7 @@ def __init__(self, graph: NNCFGraph, pruning_operator_metatypes: PruningOperatio self._graph = graph self._pruning_operator_metatypes = pruning_operator_metatypes - def get_meta_operation_by_type_name(self, type_name: str) -> DefaultMetaOp: + def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp: """ Returns class of metaop that corresponds to `type_name` type. @@ -60,26 +54,3 @@ def mask_propagation(self): for node in self._graph.topological_sort(): cls = self.get_meta_operation_by_type_name(node.node_type) cls.mask_propagation(node, self._graph) - - -def get_input_masks(node: NNCFNode, graph: NNCFGraph) -> List[Union[TensorType, None]]: - """ - Returns input masks for all inputs of nx_node. - - :return: Input masks. - """ - input_masks = [input_node.data['output_mask'] for input_node in graph.get_previous_nodes(node)] - return input_masks - - -def identity_mask_propagation(node: NNCFNode, graph: NNCFGraph): - """ - Propagates input mask through nx_node. - """ - input_masks = get_input_masks(node, graph) - if not input_masks: - # In case for disconnected NNCFGraph - input_masks = [None] - assert len(input_masks) == 1 - node.data['input_masks'] = input_masks - node.data['output_mask'] = input_masks[0] diff --git a/nncf/common/pruning/model_analysis.py b/nncf/common/pruning/model_analysis.py index 4df7f4870a4..a24401c220f 100644 --- a/nncf/common/pruning/model_analysis.py +++ b/nncf/common/pruning/model_analysis.py @@ -17,7 +17,7 @@ from nncf.common.graph import NNCFNode from nncf.common.pruning.clusterization import Cluster from nncf.common.pruning.clusterization import Clusterization -from nncf.common.pruning.export_helpers import DefaultMetaOp +from nncf.common.pruning.operations import BasePruningOp from nncf.common.pruning.utils import find_next_nodes_not_of_types from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry @@ -143,7 +143,7 @@ def node_accept_different_inputs(self, nncf_node: NNCFNode) -> bool: """ return nncf_node.node_type in self._concat_op_metatype.get_all_op_aliases() - def get_meta_operation_by_type_name(self, type_name: str) -> DefaultMetaOp: + def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp: """ Returns class of metaop that corresponds to `type_name` type. diff --git a/nncf/common/pruning/operations.py b/nncf/common/pruning/operations.py new file mode 100644 index 00000000000..27393821ca5 --- /dev/null +++ b/nncf/common/pruning/operations.py @@ -0,0 +1,228 @@ +""" + Copyright (c) 2021 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Optional, List + +from nncf.common.graph import NNCFGraph +from nncf.common.graph import NNCFNode +from nncf.common.pruning.utils import is_grouped_conv +from nncf.common.pruning.utils import is_depthwise_conv +from nncf.common.pruning.utils import get_input_masks +from nncf.common.pruning.utils import identity_mask_propagation +from nncf.common.tensor import NNCFTensor +from nncf.common.graph.layer_attributes import GroupNormLayerAttributes + + +class BasePruningOp: + """ + Determines meta operations which aggregate operations having common + properties of interaction with pruning masks + """ + + subtypes = [] + additional_types = [] + + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + """ + :return: accept_pruned_input - can this operation work with pruned input or not + """ + raise NotImplementedError + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + """ + Propagates the pruning mask through a node using pruning masks of all inputs and the current node (if any). + + :param node: The graph node to propagate mask through it + :param graph: The model graph to prune + """ + raise NotImplementedError + + @classmethod + def get_all_op_aliases(cls) -> List[str]: + """ + :return: list of all aliases of types in metatype + """ + op_types = [] + for subtype in cls.subtypes: + op_types.extend(subtype.get_all_aliases()) + op_types = list(set(op_types)) + cls.additional_types + return op_types + + +class InputPruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + return False + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + node.data['output_mask'] = None + + +class OutputPruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + return True + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + node.data['output_mask'] = None + + +class IdentityMaskForwardPruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + return True + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + identity_mask_propagation(node, graph) + + +class ConvolutionPruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + accept_pruned_input = True + if is_grouped_conv(node): + if not is_depthwise_conv(node): + accept_pruned_input = False + return accept_pruned_input + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + input_masks = get_input_masks(node, graph) + output_mask = node.data.get('output_mask', None) + + if is_grouped_conv(node): + output_mask = None + if is_depthwise_conv(node): + output_mask = input_masks[0] + + node.data['output_mask'] = output_mask + + +class TransposeConvolutionPruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + accept_pruned_input = True + if is_grouped_conv(node): + if not is_depthwise_conv(node): + accept_pruned_input = False + return accept_pruned_input + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + input_masks = get_input_masks(node, graph) + output_mask = node.data.get('output_mask', None) + + # In case of group convs we can't prune by output filters + if is_grouped_conv(node): + output_mask = None + if is_depthwise_conv(node): + output_mask = input_masks[0] + + node.data['output_mask'] = output_mask + + +class BatchNormPruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + return True + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + identity_mask_propagation(node, graph) + + +class GroupNormPruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + # For Instance Normalization + return isinstance(node.layer_attributes, GroupNormLayerAttributes) \ + and node.layer_attributes.num_groups == node.layer_attributes.num_channels + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + if cls.accept_pruned_input(node): + identity_mask_propagation(node, graph) + else: + node.data['output_mask'] = None + + +class ConcatPruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + return True + + @classmethod + def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Optional[NNCFTensor]: + """ + Generate output mask from input masks with all None replaced by identity masks. + If all input masks is None return None. + + :param node: Node to determine it's sources. + :param graph: NNCF graph to work with. + :return: Filled input masks. + """ + input_edges = graph.get_input_edges(node) + previous_nodes = [edge.from_node for edge in input_edges] + input_masks = [input_node.data['output_mask'] for input_node in previous_nodes] + + not_empty_masks = [mask for mask in input_masks if mask is not None] # type: List[NNCFTensor] + if not not_empty_masks: + return None + + first_non_empty_mask = not_empty_masks[0] + tensor_processor = first_non_empty_mask.tensor_processor + device = first_non_empty_mask.device + filled_input_masks = [] + for i, mask in enumerate(input_masks): + if mask is None: + concat_axis = node.layer_attributes.axis + concat_dim = input_edges[i].tensor_shape[concat_axis] + mask = tensor_processor.ones(concat_dim, device) + filled_input_masks.append(mask) + result_mask = tensor_processor.concatenate(filled_input_masks, 0) + return result_mask + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + result_mask = cls.generate_output_mask(node, graph) + node.data['output_mask'] = result_mask + + +class ElementwisePruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + return True + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + input_masks = get_input_masks(node, graph) + + node.data['input_masks'] = input_masks # type: List[NNCFTensor] + if input_masks[0] is not None: + input_masks[0].tensor_processor.check_all_close(input_masks) + node.data['output_mask'] = input_masks[0] + + +class StopMaskForwardPruningOp(BasePruningOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode) -> bool: + return False + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph) -> None: + node.data['output_mask'] = None diff --git a/nncf/common/pruning/utils.py b/nncf/common/pruning/utils.py index b2041cad89b..7cef39a5a06 100644 --- a/nncf/common/pruning/utils.py +++ b/nncf/common/pruning/utils.py @@ -24,6 +24,7 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.graph import NNCFNodeName +from nncf.common.tensor import NNCFTensor from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.pruning.clusterization import Cluster @@ -407,3 +408,31 @@ def is_conv_with_downsampling(node: NNCFNode) -> bool: return not np.all(np.array(layer_attrs.stride) == 1) \ and not layer_attrs.transpose return False + + +def get_input_masks(node: NNCFNode, graph: NNCFGraph) -> List[Optional[NNCFTensor]]: + """ + Returns input masks for all inputs of given NNCFNode. + + :param node: Given NNCFNode. + :param graph: Graph to work with. + :return: Input masks. + """ + input_masks = [input_node.data['output_mask'] for input_node in graph.get_previous_nodes(node)] + return input_masks + + +def identity_mask_propagation(node: NNCFNode, graph: NNCFGraph) -> None: + """ + Propagates input mask through NNCFNode. + + :param node: Graph node to perform identity mask propagation on. + :param graph: Graph to work with. + """ + input_masks = get_input_masks(node, graph) + if not input_masks: + # In case for disconnected NNCFGraph + input_masks = [None] + assert len(input_masks) == 1 + node.data['input_masks'] = input_masks + node.data['output_mask'] = input_masks[0] diff --git a/nncf/common/tensor.py b/nncf/common/tensor.py new file mode 100644 index 00000000000..403f98522b0 --- /dev/null +++ b/nncf/common/tensor.py @@ -0,0 +1,79 @@ +""" + Copyright (c) 2021 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from abc import abstractmethod +from typing import TypeVar, List + +TensorType = TypeVar('TensorType') +DeviceType = TypeVar('DeviceType') + + +class NNCFTensor: + """ + An interface of framework specific tensors for common NNCF algorithms. + """ + + def __init__(self, tensor: TensorType, + tensor_processor: 'NNCFBaseTensorProcessor'): + self._tensor = tensor + self._tensor_processor = tensor_processor + + @property + def tensor(self) -> TensorType: + return self._tensor + + @property + def tensor_processor(self) -> 'NNCFBaseTensorProcessor': + return self._tensor_processor + + @property + @abstractmethod + def device(self) -> DeviceType: + pass + + +class NNCFBaseTensorProcessor: + """ + An interface of the processing methods set for NNCFTensors. + """ + + @classmethod + @abstractmethod + def concatenate(cls, tensors: List[NNCFTensor], axis: int) -> NNCFTensor: + """ + Join a list of NNCFTensors along an existing axis. + + :param tensors: List of NNCFTensors. + :param axis: The axis along which the tensors will be joined. + :returns: The concatenated List of the tensors. + """ + + @classmethod + @abstractmethod + def ones(cls, shape: List[int], device: DeviceType) -> NNCFTensor: + """ + Return a new float tensor of given shape, filled with ones. + + :param shape: Shape of the new tensor. + :param device: Device to put created tensor in. + :returns: Float tensor of ones with the given shape. + """ + + @classmethod + @abstractmethod + def check_all_close(cls, tensors: List[NNCFTensor]) -> None: + """ + Raises an AssertionError if two objects are not equal. + + :param tensors: List of tensors to check pairwise equality. + """ diff --git a/nncf/tensorflow/graph/converter.py b/nncf/tensorflow/graph/converter.py index d70ae44d857..b5aba4af7d9 100644 --- a/nncf/tensorflow/graph/converter.py +++ b/nncf/tensorflow/graph/converter.py @@ -27,10 +27,13 @@ from nncf.common.graph import NNCFNodeName from nncf.common.graph import OperatorMetatype from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.utils import get_concat_axis from nncf.common.utils.logger import logger as nncf_logger from nncf.tensorflow.graph.metatypes.common import DECONV_LAYER_METATYPES from nncf.tensorflow.graph.metatypes.common import DEPTHWISE_CONV_LAYER_METATYPES +from nncf.tensorflow.graph.metatypes.common import LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS from nncf.tensorflow.graph.metatypes.common import GENERAL_CONV_LAYER_METATYPES from nncf.tensorflow.graph.metatypes.matcher import get_keras_layer_metatype from nncf.tensorflow.graph.metatypes.matcher import get_op_metatype @@ -514,6 +517,8 @@ def convert(self) -> NNCFGraph: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=True) elif metatype in GENERAL_CONV_LAYER_METATYPES: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=False) + elif metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS: + layer_attributes = _get_multiple_input_layer_attributes(layer) else: layer_attributes = None is_shared = len(self._layer_name_to_node_names[layer_name]) > 1 @@ -604,6 +609,8 @@ def convert(self) -> NNCFGraph: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=True) elif layer_metatype in GENERAL_CONV_LAYER_METATYPES: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=False) + elif layer_metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS: + layer_attributes = _get_multiple_input_layer_attributes(model_layer) if layer_attributes is not None: attrs.update({NNCFGraph.LAYER_ATTRIBUTES: layer_attributes}) @@ -651,6 +658,16 @@ def convert(self) -> NNCFGraph: return nncf_graph +def _get_multiple_input_layer_attributes(layer: tf.keras.layers.Layer) -> MultipleInputLayerAttributes: + if hasattr(layer, 'axis'): + axis = layer.axis + else: + input_shape = layer.input_shape + output_shape = layer.output_shape + axis = get_concat_axis(input_shape, output_shape) + return MultipleInputLayerAttributes(axis) + + def _get_conv_layer_attributes(layer: tf.keras.layers.Layer, is_depthwise: bool = False) -> ConvolutionLayerAttributes: channel_axis = get_input_channel_axis(layer) layer_ = unwrap_layer(layer) diff --git a/nncf/tensorflow/pruning/base_algorithm.py b/nncf/tensorflow/pruning/base_algorithm.py index 2e354cb28da..8ff91d1d39f 100644 --- a/nncf/tensorflow/pruning/base_algorithm.py +++ b/nncf/tensorflow/pruning/base_algorithm.py @@ -40,9 +40,10 @@ from nncf.tensorflow.graph.transformations.layout import TFTransformationLayout from nncf.tensorflow.graph.utils import get_layer_identifier from nncf.tensorflow.graph.utils import collect_wrapped_layers -from nncf.tensorflow.pruning.export_helpers import TFElementwise -from nncf.tensorflow.pruning.export_helpers import TFIdentityMaskForwardOps -from nncf.tensorflow.pruning.export_helpers import TF_PRUNING_OPERATOR_METATYPES +from nncf.tensorflow.tensor import TFNNCFTensor +from nncf.tensorflow.pruning.operations import TFElementwisePruningOp +from nncf.tensorflow.pruning.operations import TFIdentityMaskForwardPruningOp +from nncf.tensorflow.pruning.operations import TF_PRUNING_OPERATOR_METATYPES from nncf.tensorflow.pruning.utils import get_filter_axis from nncf.tensorflow.pruning.utils import get_filters_num from nncf.tensorflow.sparsity.magnitude.operation import BinaryMask @@ -125,7 +126,7 @@ def get_transformation_layout(self, model: tf.keras.Model) -> TFTransformationLa # Add output_mask to elements to run mask_propagation # and detect spec_nodes that will be pruned. # It should be done for all elements of shared layer. - node.data['output_mask'] = tf.ones(node.layer_attributes.out_channels) + node.data['output_mask'] = TFNNCFTensor(tf.ones(node.layer_attributes.out_channels)) if layer_name in shared_layers: continue if node.is_shared(): @@ -207,7 +208,7 @@ def _get_insertion_command_binary_mask(self, layer_name: str, @staticmethod def _get_bn_for_node(node: NNCFNode, bn_nodes: List[NNCFNode]) -> Tuple[bool, List[NNCFNode]]: is_finished = False - propagating_ops = [op_name for meta_op in [TFIdentityMaskForwardOps, TFElementwise] + propagating_ops = [op_name for meta_op in [TFIdentityMaskForwardPruningOp, TFElementwisePruningOp] for op_name in meta_op.get_all_op_aliases()] if node.node_type == 'BatchNormalization': is_finished = True diff --git a/nncf/tensorflow/pruning/export_helpers.py b/nncf/tensorflow/pruning/export_helpers.py deleted file mode 100644 index fac53e8fbda..00000000000 --- a/nncf/tensorflow/pruning/export_helpers.py +++ /dev/null @@ -1,240 +0,0 @@ -""" - Copyright (c) 2021 Intel Corporation - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" -from typing import Dict -from typing import List -from typing import Union - -import tensorflow as tf - -from nncf.common.pruning.utils import is_depthwise_conv -from nncf.tensorflow.graph.pattern_operations import KERAS_ACTIVATIONS_OPERATIONS -from nncf.tensorflow.graph.pattern_operations import ELEMENTWISE_OPERATIONS -from nncf.tensorflow.graph.pattern_operations import TF_ACTIVATIONS_OPERATIONS -from nncf.common.graph.definitions import NNCFGraphNodeType -from nncf.common.graph import NNCFGraph -from nncf.common.graph import NNCFNode -from nncf.common.pruning.export_helpers import DefaultMetaOp -from nncf.common.pruning.mask_propagation import identity_mask_propagation -from nncf.common.pruning.mask_propagation import get_input_masks -from nncf.common.pruning.utils import get_sources_of_node -from nncf.common.pruning.utils import is_grouped_conv -from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry - -TF_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") - - -def _get_types(operations_dict: Dict) -> List[str]: - return operations_dict['type'] - - -@TF_PRUNING_OPERATOR_METATYPES.register('model_input') -class TFInput(DefaultMetaOp): - additional_types = ['InputLayer', NNCFGraphNodeType.INPUT_NODE] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return False - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['output_mask'] = None - -@TF_PRUNING_OPERATOR_METATYPES.register('model_output') -class TFOutput(DefaultMetaOp): - additional_types = [NNCFGraphNodeType.OUTPUT_NODE] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['output_mask'] = None - -@TF_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class TFIdentityMaskForwardOps(DefaultMetaOp): - additional_types = _get_types(KERAS_ACTIVATIONS_OPERATIONS) + _get_types(TF_ACTIVATIONS_OPERATIONS) \ - + ['AvgPool2D', 'GlobalAvgPool2D', 'AveragePooling2D', 'GlobalAveragePooling2D'] \ - + ['MaxPooling2D', 'GlobalMaxPooling2D', 'MaxPool2D', 'GlobalMaxPool2D'] \ - + ['Dropout', 'ZeroPadding2D', 'Identity', 'Pad', 'UpSampling2D'] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - - -@TF_PRUNING_OPERATOR_METATYPES.register('convolution') -class TFConvolution(DefaultMetaOp): - additional_types = ['Conv1D', 'Conv2D', 'Conv3D', 'DepthwiseConv2D'] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) - - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = input_masks[0] - - node.data['output_mask'] = output_mask - - -@TF_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') -class TFTransposeConvolution(DefaultMetaOp): - additional_types = ['Conv1DTranspose', 'Conv2DTranspose', 'Conv3DTranspose'] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) - - # In case of group convs we can't prune by output filters - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = input_masks[0] - - node.data['output_mask'] = output_mask - - -@TF_PRUNING_OPERATOR_METATYPES.register('batch_norm') -class TFBatchNorm(DefaultMetaOp): - additional_types = ['BatchNormalization', 'SyncBatchNormalization'] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - - -@TF_PRUNING_OPERATOR_METATYPES.register('concat') -class TFConcat(DefaultMetaOp): - additional_types = ['Concatenate', 'ConcatV2'] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: - """ - Return whether all input sources of node is convolutions or not. - - :param node: Node to determine it's sources - :param graph: NNCF graph to work with - :return: True if all input sources of node is convolutions - """ - - for input_node in graph.get_previous_nodes(node): - # If input has mask -> it went from convolution (source of this node is a convolution) - if input_node.data.get('output_mask', None) is None: - continue - - source_nodes = get_sources_of_node(input_node, graph, TFConvolution.get_all_op_aliases() + - TFStopMaskForwardOps.get_all_op_aliases() + - TFInput.get_all_op_aliases()) - sources_types = [node.node_type for node in source_nodes] - if any(t in sources_types for t in TFStopMaskForwardOps.get_all_op_aliases()): - return False - return True - - @classmethod - def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[tf.Tensor, None]: - """ - Generate output mask from input masks with all None replaced by identity masks. - If all input masks is None return None. - - :param node: Node to determine it's sources - :param graph: NNCF graph to work with - :return: Output mask - """ - input_edges = graph.get_input_edges(node) - previous_nodes = [edge.from_node for edge in input_edges] - input_masks = [input_node.data['output_mask'] for input_node in previous_nodes] - - if all(mask is None for mask in input_masks): - return None - - device = [m for m in input_masks if m is not None][0].device - - filled_input_masks = [] - for i, mask in enumerate(input_masks): - if mask is None: - with tf.device(device): - mask = tf.ones(input_edges[i].tensor_shape[-1]) - filled_input_masks.append(mask) - result_mask = tf.concat(filled_input_masks, 0) - return result_mask - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - result_mask = None - - if cls.check_concat(node, graph): - result_mask = cls.generate_output_mask(node, graph) - - node.data['output_mask'] = result_mask - - -@TF_PRUNING_OPERATOR_METATYPES.register('elementwise') -class TFElementwise(DefaultMetaOp): - additional_types = _get_types(ELEMENTWISE_OPERATIONS) - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - if input_masks[0] is not None: - for input_mask in input_masks[1:]: - tf.debugging.assert_near(input_masks[0], input_mask) - node.data['output_mask'] = input_masks[0] - - -@TF_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') -class TFStopMaskForwardOps(DefaultMetaOp): - additional_types = ['Dense', 'MatMul'] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return False - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['output_mask'] = None diff --git a/nncf/tensorflow/pruning/filter_pruning/algorithm.py b/nncf/tensorflow/pruning/filter_pruning/algorithm.py index af5e775f85d..c542cf672f1 100644 --- a/nncf/tensorflow/pruning/filter_pruning/algorithm.py +++ b/nncf/tensorflow/pruning/filter_pruning/algorithm.py @@ -47,16 +47,17 @@ from nncf.tensorflow.graph.utils import collect_wrapped_layers from nncf.tensorflow.graph.utils import get_original_name_and_instance_idx from nncf.tensorflow.graph.utils import unwrap_layer +from nncf.tensorflow.tensor import TFNNCFTensor from nncf.tensorflow.layers.data_layout import get_input_channel_axis from nncf.tensorflow.layers.wrapper import NNCFWrapper from nncf.tensorflow.loss import TFZeroCompressionLoss from nncf.tensorflow.pruning.base_algorithm import BasePruningAlgoBuilder from nncf.tensorflow.pruning.base_algorithm import BasePruningAlgoController from nncf.tensorflow.pruning.base_algorithm import PrunedLayerInfo -from nncf.tensorflow.pruning.export_helpers import TF_PRUNING_OPERATOR_METATYPES -from nncf.tensorflow.pruning.export_helpers import TFConvolution -from nncf.tensorflow.pruning.export_helpers import TFElementwise -from nncf.tensorflow.pruning.export_helpers import TFTransposeConvolution +from nncf.tensorflow.pruning.operations import TF_PRUNING_OPERATOR_METATYPES +from nncf.tensorflow.pruning.operations import TFConvolutionPruningOp +from nncf.tensorflow.pruning.operations import TFElementwisePruningOp +from nncf.tensorflow.pruning.operations import TFTransposeConvolutionPruningOp from nncf.tensorflow.pruning.filter_pruning.functions import calculate_binary_mask from nncf.tensorflow.pruning.filter_pruning.functions import FILTER_IMPORTANCE_FUNCTIONS from nncf.tensorflow.pruning.filter_pruning.functions import tensor_l2_normalizer @@ -87,11 +88,11 @@ def _is_pruned_layer(self, layer: tf.keras.layers.Layer) -> bool: return layer.__class__.__name__ in self._prunable_types def _get_op_types_of_pruned_layers(self) -> List[str]: - return [op_name for meta_op in [TFConvolution, TFTransposeConvolution] + return [op_name for meta_op in [TFConvolutionPruningOp, TFTransposeConvolutionPruningOp] for op_name in meta_op.get_all_op_aliases()] def _get_types_of_grouping_ops(self) -> List[str]: - return TFElementwise.get_all_op_aliases() + return TFElementwisePruningOp.get_all_op_aliases() @ADAPTIVE_COMPRESSION_CONTROLLERS.register('tf_filter_pruning') @@ -292,7 +293,7 @@ def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_rate: float): filter_mask = calculate_binary_mask(cumulative_filters_importance, threshold) for node in group.elements: nncf_node = self._original_graph.get_node_by_id(node.nncf_node_id) - nncf_node.data['output_mask'] = filter_mask + nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask) # 2. Propagating masks across the graph mask_propagator = MaskPropagationAlgorithm(self._original_graph, TF_PRUNING_OPERATOR_METATYPES) @@ -304,7 +305,7 @@ def _set_binary_masks_for_pruned_layers_groupwise(self, pruning_rate: float): nncf_node = [n for n in nncf_sorted_nodes if layer.name == n.layer_name][0] if nncf_node.data['output_mask'] is not None: - self._set_operation_masks([layer], nncf_node.data['output_mask']) + self._set_operation_masks([layer], nncf_node.data['output_mask'].tensor) # Calculate actual flops and weights number with new masks self._update_benchmark_statistics() @@ -339,7 +340,7 @@ def _set_binary_masks_for_pruned_layers_globally(self, pruning_rate: float): filter_mask = calculate_binary_mask(filter_importances[group.id], threshold) for node in group.elements: nncf_node = self._original_graph.get_node_by_id(node.nncf_node_id) - nncf_node.data['output_mask'] = filter_mask + nncf_node.data['output_mask'] = TFNNCFTensor(filter_mask) # 2. Propagate masks across the graph mask_propagator = MaskPropagationAlgorithm(self._original_graph, TF_PRUNING_OPERATOR_METATYPES) @@ -351,7 +352,7 @@ def _set_binary_masks_for_pruned_layers_globally(self, pruning_rate: float): nncf_node = [n for n in nncf_sorted_nodes if layer.name == n.layer_name][0] if nncf_node.data['output_mask'] is not None: - self._set_operation_masks([layer], nncf_node.data['output_mask']) + self._set_operation_masks([layer], nncf_node.data['output_mask'].tensor) # Calculate actual flops with new masks self._update_benchmark_statistics() @@ -371,7 +372,7 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self, for layer in wrapped_layers: nncf_node = [n for n in nncf_sorted_nodes if layer.name == n.layer_name][0] - nncf_node.data['output_mask'] = tf.ones(get_filters_num(layer)) + nncf_node.data['output_mask'] = TFNNCFTensor(tf.ones(get_filters_num(layer))) # 1. Calculate importances for all groups of filters. Initialize masks. filter_importances = [] @@ -415,7 +416,7 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self, for group in self._pruned_layer_groups_info.get_all_clusters(): for node in group.elements: nncf_node = self._original_graph.get_node_by_id(node.nncf_node_id) - nncf_node.data['output_mask'] = masks[group.id] + nncf_node.data['output_mask'] = TFNNCFTensor(masks[group.id]) mask_propagator = MaskPropagationAlgorithm(self._original_graph, TF_PRUNING_OPERATOR_METATYPES) mask_propagator.mask_propagation() @@ -428,7 +429,7 @@ def _set_binary_masks_for_pruned_modules_globally_by_flops_target(self, nncf_node = [n for n in nncf_sorted_nodes if layer.name == n.layer_name][0] if nncf_node.data['output_mask'] is not None: - self._set_operation_masks([layer], nncf_node.data['output_mask']) + self._set_operation_masks([layer], nncf_node.data['output_mask'].tensor) return raise RuntimeError(f'Unable to prune model to required flops pruning rate:' f' {target_flops_pruning_rate}') @@ -511,7 +512,7 @@ def _calculate_flops_and_weights_pruned_model_by_masks(self): for group in self._pruned_layer_groups_info.get_all_clusters(): assert all(tmp_out_channels[group.elements[0].node_name] == tmp_out_channels[node.node_name] for node in group.elements) - mask = self._original_graph.get_node_by_id(group.elements[0].nncf_node_id).data['output_mask'] + mask = self._original_graph.get_node_by_id(group.elements[0].nncf_node_id).data['output_mask'].tensor new_out_channels_num = int(sum(mask)) num_of_sparse_elems = len(mask) - new_out_channels_num for node in group.elements: diff --git a/nncf/tensorflow/pruning/operations.py b/nncf/tensorflow/pruning/operations.py new file mode 100644 index 00000000000..c17910a42d0 --- /dev/null +++ b/nncf/tensorflow/pruning/operations.py @@ -0,0 +1,86 @@ +""" + Copyright (c) 2021 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Dict +from typing import List + +from nncf.tensorflow.graph.pattern_operations import KERAS_ACTIVATIONS_OPERATIONS +from nncf.tensorflow.graph.pattern_operations import ELEMENTWISE_OPERATIONS +from nncf.tensorflow.graph.pattern_operations import TF_ACTIVATIONS_OPERATIONS +from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry +from nncf.common.pruning.operations import ( + InputPruningOp, + OutputPruningOp, + IdentityMaskForwardPruningOp, + ConvolutionPruningOp, + TransposeConvolutionPruningOp, + BatchNormPruningOp, + ConcatPruningOp, + ElementwisePruningOp, + StopMaskForwardPruningOp +) + +TF_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") + + +def _get_types(operations_dict: Dict) -> List[str]: + return operations_dict['type'] + + +@TF_PRUNING_OPERATOR_METATYPES.register('model_input') +class TFInputPruningOp(InputPruningOp): + additional_types = ['InputLayer', NNCFGraphNodeType.INPUT_NODE] + + +@TF_PRUNING_OPERATOR_METATYPES.register('model_output') +class TFOutputPruningOp(OutputPruningOp): + additional_types = [NNCFGraphNodeType.OUTPUT_NODE] + + +@TF_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') +class TFIdentityMaskForwardPruningOp(IdentityMaskForwardPruningOp): + additional_types = _get_types(KERAS_ACTIVATIONS_OPERATIONS) + _get_types(TF_ACTIVATIONS_OPERATIONS) \ + + ['AvgPool2D', 'GlobalAvgPool2D', 'AveragePooling2D', 'GlobalAveragePooling2D'] \ + + ['MaxPooling2D', 'GlobalMaxPooling2D', 'MaxPool2D', 'GlobalMaxPool2D'] \ + + ['Dropout', 'ZeroPadding2D', 'Identity', 'Pad', 'UpSampling2D'] + + +@TF_PRUNING_OPERATOR_METATYPES.register('convolution') +class TFConvolutionPruningOp(ConvolutionPruningOp): + additional_types = ['Conv1D', 'Conv2D', 'Conv3D', 'DepthwiseConv2D'] + + +@TF_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') +class TFTransposeConvolutionPruningOp(TransposeConvolutionPruningOp): + additional_types = ['Conv1DTranspose', 'Conv2DTranspose', 'Conv3DTranspose'] + + +@TF_PRUNING_OPERATOR_METATYPES.register('batch_norm') +class TFBatchNormPruningOp(BatchNormPruningOp): + additional_types = ['BatchNormalization', 'SyncBatchNormalization'] + + +@TF_PRUNING_OPERATOR_METATYPES.register('elementwise') +class TFElementwisePruningOp(ElementwisePruningOp): + additional_types = _get_types(ELEMENTWISE_OPERATIONS) + + +@TF_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') +class TFStopMaskForwardPruningOp(StopMaskForwardPruningOp): + additional_types = ['Dense', 'MatMul'] + + +@TF_PRUNING_OPERATOR_METATYPES.register('concat') +class TFConcatPruningOp(ConcatPruningOp): + additional_types = ['Concatenate', 'ConcatV2'] diff --git a/nncf/tensorflow/tensor.py b/nncf/tensorflow/tensor.py new file mode 100644 index 00000000000..f20ee420d47 --- /dev/null +++ b/nncf/tensorflow/tensor.py @@ -0,0 +1,58 @@ +""" + Copyright (c) 2021 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import tensorflow as tf + +from typing import List + +from nncf.common.tensor import NNCFTensor +from nncf.common.tensor import NNCFBaseTensorProcessor + + +class TFNNCFTensorProcessor(NNCFBaseTensorProcessor): + """ + A realization of the processing methods set for TFNNCFTensors. + """ + + @classmethod + def concatenate(cls, tensors: List[NNCFTensor], axis: int) -> NNCFTensor: + # pylint: disable=E1120,E1123 + ret_tensor = tf.concat([t.tensor for t in tensors], axis=axis) + return TFNNCFTensor(ret_tensor) + + @classmethod + def ones(cls, shape: List[int], device) -> NNCFTensor: + return TFNNCFTensor(tf.ones(shape)) + + @classmethod + def check_all_close(cls, tensors: List[NNCFTensor]) -> None: + for input_mask in tensors[1:]: + tf.debugging.assert_near(tensors[0].tensor, input_mask.tensor) + + +class TFNNCFTensor(NNCFTensor): + """ + A realisation of tensorflow tensors wrapper for common NNCF algorithms. + """ + + def __init__(self, tensor: tf.Variable): + # In case somebody attempts to wrap + # tensor twice + if isinstance(tensor, self.__class__): + tensor = tensor.tensor + + super().__init__(tensor, TFNNCFTensorProcessor) + + @property + def device(self) -> tf.device: + return self._tensor.device diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index 238020122b6..1b0e02f2ea1 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -55,7 +55,7 @@ def ignore_scope(cls): return cls -OP_NAMES_REQUIRING_MODULE_ATTRS = [v.op_func_name for v in NNCF_MODULES_DICT] + ["group_norm"] +OP_NAMES_REQUIRING_MODULE_ATTRS = [v.op_func_name for v in NNCF_MODULES_DICT] + ['group_norm'] def wrap_operator(operator, operator_info: 'PatchedOperatorInfo'): diff --git a/nncf/torch/graph/graph_builder.py b/nncf/torch/graph/graph_builder.py index e9ea4794221..9f60f354276 100644 --- a/nncf/torch/graph/graph_builder.py +++ b/nncf/torch/graph/graph_builder.py @@ -21,10 +21,13 @@ from nncf.common.graph import INPUT_NOOP_METATYPES from nncf.common.graph import LayerName +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes +from nncf.common.graph.utils import get_concat_axis from nncf.torch.dynamic_graph.graph import DynamicGraph from nncf.torch.dynamic_graph.graph_tracer import GraphTracer from nncf.torch.dynamic_graph.graph_tracer import ModelInputInfo from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.operator_metatypes import CatMetatype from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES @@ -90,4 +93,19 @@ def convert(dynamic_graph: DynamicGraph, input_infos: List[ModelInputInfo] = Non output_port_id=dynamic_graph_edge.output_port_id, dtype=dynamic_graph_edge.dtype ) + + for node in nncf_graph.get_all_nodes(): + if node.metatype is CatMetatype: + input_edges = nncf_graph.get_input_edges(node) + output_edges = nncf_graph.get_output_edges(node) + # Case of intermediate node + if input_edges and output_edges: + input_shapes = [edge.tensor_shape for edge in input_edges] + output_shapes = [edge.tensor_shape for edge in output_edges] + # Case node is stack + if len(input_shapes[0]) != len(output_shapes[0]): + continue + axis = get_concat_axis(input_shapes, output_shapes) + layer_attributes = MultipleInputLayerAttributes(axis) + node.layer_attributes = layer_attributes return nncf_graph diff --git a/nncf/torch/layers.py b/nncf/torch/layers.py index e8d47791bcf..d3cc9a2d704 100644 --- a/nncf/torch/layers.py +++ b/nncf/torch/layers.py @@ -202,6 +202,7 @@ def from_module(module): dict_update(nncf_embedding_bag.__dict__, module.__dict__) return nncf_embedding_bag + NNCF_MODULES_DICT = { NNCFConv1d: nn.Conv1d, NNCFConv2d: nn.Conv2d, diff --git a/nncf/torch/pruning/base_algo.py b/nncf/torch/pruning/base_algo.py index 69753190250..89dfe48c755 100644 --- a/nncf/torch/pruning/base_algo.py +++ b/nncf/torch/pruning/base_algo.py @@ -30,7 +30,7 @@ from nncf.torch.graph.transformations.commands import TransformationPriority from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.graph.transformations.commands import PTInsertionCommand -from nncf.torch.pruning.export_helpers import PT_PRUNING_OPERATOR_METATYPES +from nncf.torch.pruning.operations import PT_PRUNING_OPERATOR_METATYPES from nncf.torch.pruning.filter_pruning.layers import apply_filter_binary_mask from nncf.common.pruning.clusterization import Clusterization from nncf.common.pruning.clusterization import Cluster diff --git a/nncf/torch/pruning/filter_pruning/algo.py b/nncf/torch/pruning/filter_pruning/algo.py index 5ff6cd2d435..d9e9ea324b6 100644 --- a/nncf/torch/pruning/filter_pruning/algo.py +++ b/nncf/torch/pruning/filter_pruning/algo.py @@ -63,8 +63,8 @@ from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.pruning.base_algo import BasePruningAlgoBuilder from nncf.torch.pruning.base_algo import BasePruningAlgoController -from nncf.torch.pruning.export_helpers import PTElementwise -from nncf.torch.pruning.export_helpers import PT_PRUNING_OPERATOR_METATYPES +from nncf.torch.pruning.operations import PTElementwisePruningOp +from nncf.torch.pruning.operations import PT_PRUNING_OPERATOR_METATYPES from nncf.torch.pruning.filter_pruning.functions import FILTER_IMPORTANCE_FUNCTIONS from nncf.torch.pruning.filter_pruning.functions import calculate_binary_mask from nncf.torch.pruning.filter_pruning.functions import tensor_l2_normalizer @@ -111,7 +111,7 @@ def get_op_types_of_pruned_modules(self) -> List[str]: return types def get_types_of_grouping_ops(self) -> List[str]: - return PTElementwise.get_all_op_aliases() + return PTElementwisePruningOp.get_all_op_aliases() @ADAPTIVE_COMPRESSION_CONTROLLERS.register('pt_filter_pruning') @@ -680,7 +680,8 @@ def _apply_binary_mask_to_module_weight_and_bias(module: torch.nn.Module, continue node_module = self.model.get_containing_module(node.node_name) if node.data['output_mask'] is not None and node_module not in pruned_node_modules: - _apply_binary_mask_to_module_weight_and_bias(node_module, node.data['output_mask'], node.node_name) + _apply_binary_mask_to_module_weight_and_bias(node_module, node.data['output_mask'].tensor, + node.node_name) pruned_node_modules.append(node_module) def prepare_for_export(self): diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/operations.py similarity index 61% rename from nncf/torch/pruning/export_helpers.py rename to nncf/torch/pruning/operations.py index 53b6feda802..779eb2b407c 100644 --- a/nncf/torch/pruning/export_helpers.py +++ b/nncf/torch/pruning/operations.py @@ -10,21 +10,13 @@ See the License for the specific language governing permissions and limitations under the License. """ -from typing import Union -from typing import List import torch from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode -from nncf.common.pruning.export_helpers import DefaultMetaOp -from nncf.common.pruning.utils import is_grouped_conv -from nncf.common.pruning.utils import get_sources_of_node from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry -from nncf.common.pruning.mask_propagation import identity_mask_propagation -from nncf.common.pruning.mask_propagation import get_input_masks from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm -from nncf.common.graph.layer_attributes import GroupNormLayerAttributes from nncf.torch.graph.operator_metatypes import ( AddMetatype, AvgPool2dMetatype, @@ -58,6 +50,18 @@ SubMetatype, TanhMetatype, ) +from nncf.common.pruning.operations import ( + InputPruningOp, + OutputPruningOp, + IdentityMaskForwardPruningOp, + ConvolutionPruningOp, + TransposeConvolutionPruningOp, + BatchNormPruningOp, + GroupNormPruningOp, + ConcatPruningOp, + ElementwisePruningOp, + StopMaskForwardPruningOp +) from nncf.common.utils.logger import logger as nncf_logger from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT @@ -66,22 +70,9 @@ PT_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") -class PTDefaultMetaOp(DefaultMetaOp): - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - """ - Propagate mask through a node using masks of all inputs and pruning mask of current node (if any). - Should set the following attributes: - input_masks - list of masks of input nodes (None if there is no mask in some input); - output_mask - resulting mask of node operation. - - :param node: Node from NNCF graph to propagate mask through it. - :param graph: Graph of model to prune. - """ - raise NotImplementedError - +class PTPruner: @classmethod - def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None: """ Prune node by input_masks (if masks is not none and operation support it). @@ -91,7 +82,7 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): """ @classmethod - def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None: """ Prune node by output_mask (if mask is not none and operation support it). @@ -102,76 +93,28 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('model_input') -class PTInput(PTDefaultMetaOp): +class PTInputPruningOp(InputPruningOp, PTPruner): subtypes = [PTInputNoopMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return False - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['input_masks'] = [] - node.data['output_mask'] = None - @PT_PRUNING_OPERATOR_METATYPES.register('model_output') -class PTOutput(PTDefaultMetaOp): +class PTOutputPruningOp(OutputPruningOp, PTPruner): subtypes = [PTOutputNoopMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['input_masks'] = [] - node.data['output_mask'] = None - @PT_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class PTIdentityMaskForwardOps(PTDefaultMetaOp): +class PTIdentityMaskForwardPruningOp(IdentityMaskForwardPruningOp, PTPruner): subtypes = [HardTanhMetatype, TanhMetatype, RELUMetatype, PRELUMetatype, ELUMetatype, GELUMetatype, SigmoidMetatype, SoftmaxMetatype, AvgPool2dMetatype, MaxPool2dMetatype, DropoutMetatype, SILUMetatype] additional_types = ['h_sigmoid', 'h_swish', 'RELU'] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - @PT_PRUNING_OPERATOR_METATYPES.register('convolution') -class PTConvolution(PTDefaultMetaOp): +class PTConvolutionPruningOp(ConvolutionPruningOp, PTPruner): subtypes = [Conv1dMetatype, Conv2dMetatype, Conv3dMetatype] @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) - - # In case of group convs we can't prune by output filters - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = input_masks[0] - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask - - @classmethod - def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None: input_mask = node.data['input_masks'][0] if input_mask is None: return @@ -200,7 +143,7 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): ' {}.'.format(node.data['key'], old_num_channels, new_num_channels)) @classmethod - def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None: mask = node.data['output_mask'] if mask is None: return @@ -221,33 +164,11 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') -class PTTransposeConvolution(PTDefaultMetaOp): +class PTTransposeConvolutionPruningOp(TransposeConvolutionPruningOp, PTPruner): subtypes = [ConvTranspose2dMetatype, ConvTranspose3dMetatype] @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) - - # In case of group convs we can't prune by output filters - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = input_masks[0] - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask - - @classmethod - def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None: input_mask = node.data['input_masks'][0] if input_mask is None: return @@ -263,7 +184,7 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): ' {}.'.format(node.data['key'], old_num_clannels, node_module.in_channels)) @classmethod - def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None: output_mask = node.data['output_mask'] if output_mask is None: return @@ -290,19 +211,11 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('batch_norm') -class PTBatchNorm(PTDefaultMetaOp): +class PTBatchNormPruningOp(BatchNormPruningOp, PTPruner): subtypes = [BatchNormMetatype] @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - - @classmethod - def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None: input_mask = node.data['input_masks'][0] if input_mask is None: return @@ -324,21 +237,11 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('group_norm') -class GroupNorm(PTDefaultMetaOp): +class PTGroupNormPruningOp(GroupNormPruningOp, PTPruner): subtypes = [GroupNormMetatype] @classmethod - def accept_pruned_input(cls, node: NNCFNode): - # For Instance Normalization - return isinstance(node.layer_attributes, GroupNormLayerAttributes) \ - and node.layer_attributes.num_groups == node.layer_attributes.num_channels - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - - @classmethod - def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None: input_mask = node.data['input_masks'][0] if input_mask is None: return @@ -359,96 +262,12 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): ' {}.'.format(node.data['key'], old_num_clannels, new_num_channels)) -@PT_PRUNING_OPERATOR_METATYPES.register('concat') -class PTConcat(PTDefaultMetaOp): - subtypes = [CatMetatype] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: - """ - Return whether all input sources of node is convolutions or not. - - :param node: Node to determine it's sources. - :param graph: NNCF graph to work with. - :return: True If all input sources of node is convolutions. - """ - - for input_node in graph.get_previous_nodes(node): - # If input has mask -> it went from convolution (source of this node is a convolution) - if input_node.data.get('output_mask', None) is None: - continue - - source_nodes = get_sources_of_node(input_node, graph, PTConvolution.get_all_op_aliases() + - PTStopMaskForwardOps.get_all_op_aliases() + - PTInput.get_all_op_aliases()) - sources_types = [node.node_type for node in source_nodes] - if any(t in sources_types for t in PTStopMaskForwardOps.get_all_op_aliases()): - return False - return True - - @classmethod - def fill_input_masks(cls, node: NNCFNode, graph: NNCFGraph) -> Union[List[torch.Tensor], None]: - """ - Fill input masks with all None replaced by identity masks. - If all input masks is None return None. - - :param node: Node to determine it's sources. - :param graph: NNCF graph to work with. - :return: Filled input masks. - """ - input_edges = graph.get_input_edges(node) - previous_nodes = [edge.from_node for edge in input_edges] - input_masks = [input_node.data['output_mask'] for input_node in previous_nodes] - - if all(mask is None for mask in input_masks): - return None - - device = [m for m in input_masks if m is not None][0].device - - filled_input_masks = [] - for i, mask in enumerate(input_masks): - if mask is None: - mask = torch.ones(input_edges[i].tensor_shape[1], device=device) - filled_input_masks.append(mask) - return filled_input_masks - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = None - output_mask = None - - if cls.check_concat(node, graph): - input_masks = cls.fill_input_masks(node, graph) - if input_masks: - output_mask = torch.cat(input_masks) - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask - - @PT_PRUNING_OPERATOR_METATYPES.register('elementwise') -class PTElementwise(PTDefaultMetaOp): +class PTElementwisePruningOp(ElementwisePruningOp, PTPruner): subtypes = [AddMetatype, SubMetatype, DivMetatype, MulMetatype] @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - - node.data['input_masks'] = input_masks - if input_masks[0] is not None: - assert all(torch.allclose(input_masks[0], mask) for mask in input_masks) - node.data['output_mask'] = input_masks[0] - - @classmethod - def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph) -> None: input_mask = node.data['input_masks'][0] if input_mask is None: return @@ -469,19 +288,13 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') -class PTStopMaskForwardOps(PTDefaultMetaOp): +class PTStopMaskForwardPruningOp(StopMaskForwardPruningOp, PTPruner): subtypes = [MeanMetatype, MaxMetatype, MinMetatype, LinearMetatype, MatMulMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return False - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - - node.data['input_masks'] = input_masks - node.data['output_mask'] = None +@PT_PRUNING_OPERATOR_METATYPES.register('concat') +class PTConcatPruningOp(ConcatPruningOp, PTPruner): + subtypes = [CatMetatype] class ModelPruner(MaskPropagationAlgorithm): diff --git a/nncf/torch/pruning/utils.py b/nncf/torch/pruning/utils.py index 9c09afdee7b..fd80b8a23e1 100644 --- a/nncf/torch/pruning/utils.py +++ b/nncf/torch/pruning/utils.py @@ -18,6 +18,7 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNodeName from nncf.torch.graph.graph import NNCFNode +from nncf.torch.tensor import PTNNCFTensor from nncf.torch.nncf_network import NNCFNetwork @@ -60,4 +61,4 @@ def init_output_masks_in_graph(graph: NNCFGraph, nodes: List): for minfo in nodes: mask = minfo.operand.binary_filter_pruning_mask nncf_node = graph.get_node_by_id(minfo.nncf_node_id) - nncf_node.data['output_mask'] = mask + nncf_node.data['output_mask'] = PTNNCFTensor(mask) diff --git a/nncf/torch/tensor.py b/nncf/torch/tensor.py new file mode 100644 index 00000000000..68528da2cee --- /dev/null +++ b/nncf/torch/tensor.py @@ -0,0 +1,57 @@ +""" + Copyright (c) 2021 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import torch + +from typing import List + +from nncf.common.tensor import NNCFTensor +from nncf.common.tensor import NNCFBaseTensorProcessor + + +class PTNNCFTensorProcessor(NNCFBaseTensorProcessor): + """ + A realization of the processing methods set for PTNNCFTensors. + """ + + @classmethod + def concatenate(cls, tensors: List[NNCFTensor], axis: int) -> NNCFTensor: + ret_tensor = torch.cat([t.tensor for t in tensors], dim=axis) + return PTNNCFTensor(ret_tensor) + + @classmethod + def ones(cls, shape: List[int], device) -> NNCFTensor: + return PTNNCFTensor(torch.ones(shape)) + + @classmethod + def check_all_close(cls, tensors: List[NNCFTensor]) -> None: + for input_mask in tensors[1:]: + assert torch.allclose(tensors[0].tensor, input_mask.tensor) + + +class PTNNCFTensor(NNCFTensor): + """ + A realisation of torch tensors wrapper for common NNCF algorithms. + """ + + def __init__(self, tensor: torch.tensor): + # In case somebody attempts to wrap + # tensor twice + if isinstance(tensor, self.__class__): + tensor = tensor.tensor + + super().__init__(tensor, PTNNCFTensorProcessor) + + @property + def device(self) -> torch.device: + return self._tensor.device diff --git a/tests/common/graph/test_utils.py b/tests/common/graph/test_utils.py new file mode 100644 index 00000000000..0f1b401cd76 --- /dev/null +++ b/tests/common/graph/test_utils.py @@ -0,0 +1,19 @@ +import pytest + +from nncf.common.graph.utils import get_concat_axis + + +TEST_CASES = [ + ([(1, 1), (1, 1)], [(2, 1)], [0]), + ([(None, 1, 1, 5)], [(None, 1, 1, 7)], [3, -1]), + ([(None, 1, 1, 5), (None, 1, 1, 5)], [(None, 1, 1, 10)], [3, -1]), + ([(1, 1, None), (1, 1, None)], [(1, 1, None)], [2, -1]), + ([(1, 1, 32, 1), (1, 1, 32, 1)], [(1, 1, 64, 1)], [2, -1]), + ([(1, 1, 5), (1, 1, 5)], [(1, 1, 5)], [-1]), +] + + +@pytest.mark.parametrize('input_shape,output_shape,possible_axes', TEST_CASES) +def test_get_concat_axis(input_shape, output_shape, possible_axes): + axis = get_concat_axis(input_shape, output_shape) + assert axis in possible_axes diff --git a/tests/common/pruning/dummy_types.py b/tests/common/pruning/dummy_types.py new file mode 100644 index 00000000000..131bca5a388 --- /dev/null +++ b/tests/common/pruning/dummy_types.py @@ -0,0 +1,120 @@ +from typing import List + +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry +from nncf.common.pruning.operations import ( + InputPruningOp, + OutputPruningOp, + IdentityMaskForwardPruningOp, + ConvolutionPruningOp, + TransposeConvolutionPruningOp, + BatchNormPruningOp, + GroupNormPruningOp, + ConcatPruningOp, + ElementwisePruningOp, + StopMaskForwardPruningOp, +) + + +class DummyDefaultMetatype(OperatorMetatype): + name = None + + @classmethod + def get_all_aliases(cls) -> List[str]: + return [cls.name] + + +class DummyInputMetatype(OperatorMetatype): + name = 'input' + + +class DummyOutputMetatype(OperatorMetatype): + name = 'output' + + +class DymmyIdentityMaskForwardMetatype(OperatorMetatype): + name = 'identity_mask_forward' + + +class DummyElementwiseMetatype(OperatorMetatype): + name = 'elementwise' + + +class DummyConvMetatype(OperatorMetatype): + name = 'conv' + + +class DummyTransposeConvolutionMetatype(OperatorMetatype): + name = 'transpose_conv' + + +class DummyBatchNormMetatype(OperatorMetatype): + name = 'batch_norm' + + +class DummyGroupNormMetatype(OperatorMetatype): + name = 'group_norm' + + +class DummyConcatMetatype(OperatorMetatype): + name = 'concat' + + +class DummyStopPropoagtionMetatype(OperatorMetatype): + name = 'stop_propagation_ops' + + +DUMMY_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyInputMetatype.name) +class DummyInputPruningOp(InputPruningOp): + additional_types = [DummyInputMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyOutputMetatype.name) +class DummyOutputPruningOp(OutputPruningOp): + additional_types = [DummyOutputMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DymmyIdentityMaskForwardMetatype.name) +class DummyIdentityMaskForward(IdentityMaskForwardPruningOp): + additional_types = [DymmyIdentityMaskForwardMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyStopPropoagtionMetatype.name) +class DummyStopMaskForward(StopMaskForwardPruningOp): + additional_types = [DummyStopPropoagtionMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyConvMetatype.name) +class DummyConvPruningOp(ConvolutionPruningOp): + additional_types = [DummyConvMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyTransposeConvolutionMetatype.name) +class DummyTransposeConvPruningOp(TransposeConvolutionPruningOp): + additional_types = [DummyTransposeConvolutionMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyBatchNormMetatype.name) +class DummyBatchNormPruningOp(BatchNormPruningOp): + additional_types = [DummyBatchNormMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyGroupNormMetatype.name) +class DummyGroupNormPruningOp(GroupNormPruningOp): + additional_types = [DummyGroupNormMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyElementwiseMetatype.name) +class DummyElementwisePruningOp(ElementwisePruningOp): + additional_types = [DummyElementwiseMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyConcatMetatype.name) +class DummyConcatPruningOp(ConcatPruningOp): + ConvolutionOp = DummyConvPruningOp + StopMaskForwardOp = DummyStopMaskForward + InputOp = DummyInputPruningOp + additional_types = [DummyConcatMetatype.name] diff --git a/tests/common/pruning/tensor.py b/tests/common/pruning/tensor.py new file mode 100644 index 00000000000..c13df07c580 --- /dev/null +++ b/tests/common/pruning/tensor.py @@ -0,0 +1,49 @@ +""" + Copyright (c) 2021 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import numpy as np + +from typing import List + +from nncf.common.tensor import NNCFTensor +from nncf.common.tensor import NNCFBaseTensorProcessor + + +class NPNNCFTensorProcessor(NNCFBaseTensorProcessor): + @classmethod + def concatenate(cls, tensors: List[NNCFTensor], axis: int) -> NNCFTensor: + ret_tensor = np.concatenate([t.tensor for t in tensors], axis=axis) + return NPNNCFTensor(ret_tensor) + + @classmethod + def ones(cls, shape: List[int], device) -> NNCFTensor: + return NPNNCFTensor(np.ones(shape)) + + @classmethod + def check_all_close(cls, tensors: List[NNCFTensor]) -> None: + for input_mask in tensors[1:]: + np.testing.assert_allclose(tensors[0].tensor, input_mask.tensor) + + +class NPNNCFTensor(NNCFTensor): + def __init__(self, tensor: np.array): + # In case somebody attempts to wrap + # tensor twice + if isinstance(tensor, self.__class__): + tensor = tensor.tensor + + super().__init__(tensor, NPNNCFTensorProcessor) + + @property + def device(self) -> None: + return None diff --git a/tests/common/pruning/test_pruning_operations.py b/tests/common/pruning/test_pruning_operations.py new file mode 100644 index 00000000000..1ab526f9292 --- /dev/null +++ b/tests/common/pruning/test_pruning_operations.py @@ -0,0 +1,263 @@ +import numpy as np +import pytest + +from functools import partial + + +from tests.common.pruning import dummy_types +from tests.common.pruning.tensor import NPNNCFTensor +from tests.common.pruning.tensor import NPNNCFTensorProcessor +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes +from nncf.common.graph.layer_attributes import GroupNormLayerAttributes +from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes +from nncf.common.pruning.operations import BasePruningOp +from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm + + +@pytest.mark.parametrize('dummy_op_class,accept_pruned_input', [(dummy_types.DummyInputPruningOp, False), + (dummy_types.DummyOutputPruningOp, True), + (dummy_types.DummyStopMaskForward, False)]) +def test_stop_propagate_ops(dummy_op_class, accept_pruned_input): + node = NNCFNode(0, 'dummy_node') + assert dummy_op_class.accept_pruned_input(node) == accept_pruned_input + dummy_op_class.mask_propagation(node, None) + assert node.data['output_mask'] is None + + +@pytest.mark.parametrize('dummy_op_class', [dummy_types.DummyIdentityMaskForward, dummy_types.DummyBatchNormPruningOp]) +def test_identity_mask_propogation_prune_ops(dummy_op_class): + assert dummy_op_class.accept_pruned_input(None) + graph = NNCFGraph() + conv_op = graph.add_nncf_node('conv_op', 'conv', dummy_types.DummyConvMetatype) + identity_ops = [] + for alias in dummy_op_class.get_all_op_aliases(): + identity_op = graph.add_nncf_node('identity', alias, dummy_types.DymmyIdentityMaskForwardMetatype) + graph.add_edge_between_nncf_nodes( + from_node_id=conv_op.node_id, + to_node_id=identity_op.node_id, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + identity_ops.append(identity_op) + # Check with and without masks + for output_mask in [None, NPNNCFTensor(np.ones((10,)))]: + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = output_mask + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + for identity_op in identity_ops: + identity_op = graph.get_node_by_id(identity_op.node_id) + assert np.all(identity_op.data['output_mask'] == output_mask) + + +@pytest.mark.parametrize('valid_masks', [None, True, False]) +def test_elementwise_prune_ops(valid_masks): + graph = NNCFGraph() + conv_op_0 = graph.add_nncf_node('conv_op_0', dummy_types.DummyConvMetatype.name, dummy_types.DummyConvMetatype) + conv_op_1 = graph.add_nncf_node('conv_op_1', dummy_types.DummyConvMetatype.name, dummy_types.DummyConvMetatype) + elementwise_op = graph.add_nncf_node('elementwise', dummy_types.DummyElementwiseMetatype.name, + dummy_types.DummyElementwiseMetatype) + add_node = partial(graph.add_edge_between_nncf_nodes, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # conv_op_0 -> elementwise + add_node(from_node_id=conv_op_0.node_id, + to_node_id=elementwise_op.node_id) + + # conv_op_1 -> elementwise + add_node(from_node_id=conv_op_1.node_id, + to_node_id=elementwise_op.node_id) + + masks = [NPNNCFTensor(np.ones((10,))), NPNNCFTensor(np.ones((10,)))] if valid_masks is not None else [None, None] + + def set_masks(masks, ops): + for conv_op, mask in zip(ops, masks): + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = mask + + if valid_masks is None or valid_masks: + if valid_masks: + set_masks(masks, [conv_op_0, conv_op_1]) + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + elementwise_op = graph.get_node_by_id(elementwise_op.node_id) + assert np.all(elementwise_op.data['output_mask'] == masks[0]) + else: + def check_wrong_masks(masks): + with pytest.raises(AssertionError): + set_masks(masks, [conv_op_0, conv_op_1]) + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + + masks[0].tensor[0] = 0 + check_wrong_masks(masks) + masks[0] = NPNNCFTensorProcessor.concatenate([masks[1], NPNNCFTensor(np.array([1]))], axis=0) + check_wrong_masks(masks) + + +@pytest.mark.parametrize('num_channels,num_groups,accept_pruned_input_ref', [(10, 10, True), + (10, 5, False), + (10, 1, False)]) +def test_group_norm_pruning_ops(num_channels, num_groups, accept_pruned_input_ref): + graph = NNCFGraph() + conv_op = graph.add_nncf_node('conv_op', 'conv', dummy_types.DummyConvMetatype) + group_norm_layer_attributes = GroupNormLayerAttributes(True, num_channels=num_channels, + num_groups=num_groups) + group_norm_op = graph.add_nncf_node('identity', dummy_types.DummyGroupNormMetatype.name, + dummy_types.DummyGroupNormMetatype, + layer_attributes=group_norm_layer_attributes) + assert dummy_types.DummyGroupNormPruningOp.accept_pruned_input(group_norm_op) == accept_pruned_input_ref + graph.add_edge_between_nncf_nodes( + from_node_id=conv_op.node_id, + to_node_id=group_norm_op.node_id, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # Check with and without masks + for output_mask in [None, NPNNCFTensor(np.ones((10,)))]: + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = output_mask + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + identity_op = graph.get_node_by_id(group_norm_op.node_id) + if not accept_pruned_input_ref: + output_mask = None + + assert np.all(identity_op.data['output_mask'] == output_mask) + + +class DummyMaskProducerMetatype(dummy_types.DummyDefaultMetatype): + name = 'mask_producer' + + +@dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyMaskProducerMetatype.name) +class MockOpMaskProducer(BasePruningOp): + additional_types = [DummyMaskProducerMetatype.name] + + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + pass + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + pass + + +@pytest.mark.parametrize('transpose', [True, False], ids=['transpose', 'not_transpose']) +@pytest.mark.parametrize('layer_attributes,ref_accept_pruned_input,conv_type', [ + ({'in_channels': 5, + 'out_channels': 10, + 'groups': 1}, True, 'usual_conv'), + ({'in_channels': 10, + 'out_channels': 20, + 'groups': 5}, False, 'grouped_conv_no_depthwise'), + ({'in_channels': 10, + 'out_channels': 20, + 'groups': 10}, True, 'depthwise_conv') +], + ids=['usual_conv', + 'grouped_conv_no_depthwise', + 'depthwise_conv'] +) +def test_conv_pruning_ops(transpose, layer_attributes, ref_accept_pruned_input, conv_type): + default_conv_params = { + 'weight_requires_grad': True, + 'kernel_size': (2, 2), + 'stride': (1, 1), + 'padding_values': [0, 0] + } + graph = NNCFGraph() + dummy_op_before = graph.add_nncf_node('dummy_op_before', DummyMaskProducerMetatype.name, + DummyMaskProducerMetatype) + target_conv_attributes = ConvolutionLayerAttributes(transpose=transpose, **layer_attributes, **default_conv_params) + conv_op_target = graph.add_nncf_node('conv_op_target', dummy_types.DummyConvMetatype.name, + dummy_types.DummyConvMetatype, + layer_attributes=target_conv_attributes) + graph.add_edge_between_nncf_nodes(from_node_id=dummy_op_before.node_id, + to_node_id=conv_op_target.node_id, + tensor_shape=[layer_attributes['in_channels']] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + pruning_op_class = dummy_types.DummyTransposeConvPruningOp if transpose else dummy_types.DummyConvPruningOp + assert pruning_op_class.accept_pruned_input(conv_op_target) == ref_accept_pruned_input + ones_input_mask = NPNNCFTensor(np.ones((layer_attributes['in_channels'],))) + ones_output_mask = NPNNCFTensor(np.ones((layer_attributes['out_channels'],))) + # Check all combinations of masks + for input_mask in [None, ones_input_mask]: + for output_mask in [None, ones_output_mask]: + dummy_op_before = graph.get_node_by_id(dummy_op_before.node_id) + conv_op_target = graph.get_node_by_id(conv_op_target.node_id) + dummy_op_before.data['output_mask'] = input_mask + conv_op_target.data['output_mask'] = output_mask + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + dummy_op_before = graph.get_node_by_id(dummy_op_before.node_id) + conv_op_target = graph.get_node_by_id(conv_op_target.node_id) + if conv_type == 'usual_conv': + assert np.all(conv_op_target.data['output_mask'] == output_mask) + elif conv_type == 'grouped_conv_no_depthwise': + assert conv_op_target.data['output_mask'] is None + else: + assert np.all(conv_op_target.data['output_mask'] == input_mask) + + +@pytest.mark.parametrize('empty_mask_left_branch', [False, True]) +@pytest.mark.parametrize('empty_mask_right_branch', [False, True]) +@pytest.mark.parametrize('right_branch_output_channels', [5, 10]) +def test_convs_elementwise_source_before_concat(empty_mask_right_branch, empty_mask_left_branch, + right_branch_output_channels): + graph = NNCFGraph() + conv_op_0 = graph.add_nncf_node('conv_op_0', 'conv', dummy_types.DummyConvMetatype) + conv_op_1 = graph.add_nncf_node('conv_op_1', 'conv', dummy_types.DummyConvMetatype) + conv_op_2 = graph.add_nncf_node('conv_op_2', 'conv', dummy_types.DummyConvMetatype) + elementwise_node = graph.add_nncf_node('elementwise_node', 'elementwise', dummy_types.DummyElementwiseMetatype) + concat_layer_attributes = MultipleInputLayerAttributes(2) + concat_node = graph.add_nncf_node('concat_node', 'concat', dummy_types.DummyConcatMetatype, + layer_attributes=concat_layer_attributes) + add_node = partial(graph.add_edge_between_nncf_nodes, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + + # conv_op_0 -> elementwise_node + add_node(from_node_id=conv_op_0.node_id, + to_node_id=elementwise_node.node_id, + tensor_shape=[10] * 4) + + # conv_op_1 -> elementwise_node + add_node(from_node_id=conv_op_1.node_id, + to_node_id=elementwise_node.node_id, + tensor_shape=[10] * 4) + + # elementwise_node -> concat_node + add_node(from_node_id=elementwise_node.node_id, + to_node_id=concat_node.node_id, + tensor_shape=[10] * 4) + + # conv_op_2 -> concat_node + add_node(from_node_id=conv_op_2.node_id, + to_node_id=concat_node.node_id, + tensor_shape=[10, 10, right_branch_output_channels, 10]) + + # Set masks + if not empty_mask_left_branch: + for conv_op in [conv_op_0, conv_op_1]: + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = NPNNCFTensor(np.ones(10)) + + if not empty_mask_right_branch: + conv_op = graph.get_node_by_id(conv_op_2.node_id) + conv_op.data['output_mask'] = NPNNCFTensor(np.ones(right_branch_output_channels)) + + # Propagate masks + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + # Check with masks + concat_node = graph.get_node_by_id(concat_node.node_id) + if empty_mask_left_branch and empty_mask_right_branch: + assert concat_node.data['output_mask'] is None + else: + reference_mask = np.ones((10 + right_branch_output_channels,)) + np.testing.assert_equal(concat_node.data['output_mask'].tensor, reference_mask) diff --git a/tests/tensorflow/test_model_converter.py b/tests/tensorflow/test_model_converter.py index e217b2df435..6402280ff9c 100644 --- a/tests/tensorflow/test_model_converter.py +++ b/tests/tensorflow/test_model_converter.py @@ -11,7 +11,9 @@ limitations under the License. """ +import pytest import tensorflow as tf +from functools import partial from tensorflow.python.keras import backend from tensorflow.python.keras import layers from tensorflow.python.keras import models @@ -19,11 +21,14 @@ from nncf.common.graph import INPUT_NOOP_METATYPES from nncf.common.graph import OUTPUT_NOOP_METATYPES +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.tensorflow.graph.converter import TFModelConverter from nncf.tensorflow.graph.converter import convert_keras_model_to_nncf_graph +from nncf.tensorflow.graph.metatypes.common import LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS from tests.tensorflow.helpers import get_basic_conv_test_model from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test from tests.tensorflow.quantization.test_algorithm_quantization import get_basic_quantization_config +from tests.tensorflow.pruning.helpers import get_concat_test_model def test_struct_auxiliary_nodes_nncf_graph(): @@ -84,3 +89,36 @@ def test_get_custom_layers(): assert len(custom_layers) == 1 assert CustomLayerForTest.CUSTOM_LAYER_NAME in custom_layers assert isinstance(custom_layers[CustomLayerForTest.CUSTOM_LAYER_NAME], CustomLayerForTest) + + +def get_model_with_reshapes_and_concats(batch_size=None): + inputs = layers.Input((64, ), batch_size=batch_size) + x = tf.reshape(inputs, (32, -1)) + x = layers.Reshape((16, -1))(x) + ones = tf.ones_like(x) + t1 = layers.concatenate([x, ones]) + # pylint: disable=E1120,E1123 + t2 = tf.concat([x, ones], axis=-1) + y = tf.concat([t1, t2], axis=-1) + return models.Model(inputs, y, name='ModelWithReshape') + + +CONCAT_MODELS = [partial(get_concat_test_model, input_shape=[1, 8, 8, 1]), + get_model_with_reshapes_and_concats] +REF_CONCAT_ATTRS = [{'tf_op_layer_tf_concat_1': {'axis': [-1, 3]}, + 'tf_op_layer_tf_concat_2': {'axis': [-1, 3]}}, + {'concatenate': {'axis': [-1, 2]}, + 'tf_op_layer_concat': {'axis': [-1, 2]}, + 'tf_op_layer_concat_1': {'axis': [-1, 2]}}] + + +@pytest.mark.parametrize('model, ref_attrs', list(zip(CONCAT_MODELS, REF_CONCAT_ATTRS))) +def test_model_with_reshape_and_concat(model, ref_attrs): + model = model() + graph = convert_keras_model_to_nncf_graph(model) + for node in graph.get_all_nodes(): + if node.metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS: + assert node.node_name in ref_attrs + assert node.layer_attributes is not None + assert isinstance(node.layer_attributes, MultipleInputLayerAttributes) + assert node.layer_attributes.axis in ref_attrs[node.node_name]['axis'] diff --git a/tests/torch/pruning/filter_pruning/test_algo.py b/tests/torch/pruning/filter_pruning/test_algo.py index edbee7afb7d..6d550d72423 100644 --- a/tests/torch/pruning/filter_pruning/test_algo.py +++ b/tests/torch/pruning/filter_pruning/test_algo.py @@ -267,7 +267,7 @@ def test_applying_masks_for_bn_after_concat(prune_bn): ] graph = pruned_model.get_original_graph() for i, node in enumerate(graph.get_nodes_by_types(['cat'])): - assert np.allclose(node.data['output_mask'].numpy(), ref_concat_masks[i]) + assert np.allclose(node.data['output_mask'].tensor.numpy(), ref_concat_masks[i]) @pytest.mark.parametrize('zero_grad', @@ -405,5 +405,5 @@ def test_disconnected_graph(): conv1 = graph.get_node_by_name('DisconectedGraphModel/NNCFConv2d[conv1]/conv2d_0') conv2 = graph.get_node_by_name('DisconectedGraphModel/NNCFConv2d[conv2]/conv2d_0') - assert sum(conv1.data['output_mask']) == 8 - assert sum(conv2.data['output_mask']) == 8 + assert sum(conv1.data['output_mask'].tensor) == 8 + assert sum(conv2.data['output_mask'].tensor) == 8 diff --git a/tests/torch/pruning/test_model_pruning_analysis.py b/tests/torch/pruning/test_model_pruning_analysis.py index a6a42f68277..a1b9d21ec3e 100644 --- a/tests/torch/pruning/test_model_pruning_analysis.py +++ b/tests/torch/pruning/test_model_pruning_analysis.py @@ -30,9 +30,9 @@ from nncf.torch.dynamic_graph.graph_tracer import ModelInputInfo from nncf.torch.layers import NNCF_PRUNING_MODULES_DICT from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.pruning.export_helpers import PTElementwise -from nncf.torch.pruning.export_helpers import PTIdentityMaskForwardOps -from nncf.torch.pruning.export_helpers import PT_PRUNING_OPERATOR_METATYPES +from nncf.torch.pruning.operations import PTElementwisePruningOp +from nncf.torch.pruning.operations import PTIdentityMaskForwardPruningOp +from nncf.torch.pruning.operations import PT_PRUNING_OPERATOR_METATYPES from nncf.common.pruning.utils import is_depthwise_conv from nncf.torch.pruning.filter_pruning.algo import FilterPruningBuilder from tests.torch.helpers import create_compressed_model_and_algo_for_test @@ -195,7 +195,7 @@ def test_pruning_node_selector(test_input_info_struct_: GroupPruningModulesTestS prune_first, prune_last, prune_downsample = test_input_info_struct_.prune_params pruning_operations = [v.op_func_name for v in NNCF_PRUNING_MODULES_DICT] - grouping_operations = PTElementwise.get_all_op_aliases() + grouping_operations = PTElementwisePruningOp.get_all_op_aliases() from nncf.common.pruning.pruning_node_selector import PruningNodeSelector pruning_node_selector = PruningNodeSelector(PT_PRUNING_OPERATOR_METATYPES, pruning_operations, @@ -264,7 +264,7 @@ def test_group_special_nodes(test_special_ops_struct: GroupSpecialModulesTestStr special_ops_clusterization = cluster_special_ops(nncf_model.get_original_graph(), algo_builder.get_types_of_grouping_ops(), - PTIdentityMaskForwardOps.get_all_op_aliases()) + PTIdentityMaskForwardPruningOp.get_all_op_aliases()) for ref_cluster in test_special_ops_struct.eltwise_clusters: cluster = special_ops_clusterization.get_cluster_containing_element(ref_cluster[0]) diff --git a/tests/torch/test_graph_building.py b/tests/torch/test_graph_building.py index 79db29de4fe..9f4dbf78ae1 100644 --- a/tests/torch/test_graph_building.py +++ b/tests/torch/test_graph_building.py @@ -16,6 +16,7 @@ from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME from nncf.common.graph.definitions import MODEL_OUTPUT_OP_NAME from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from typing import List from typing import Tuple @@ -34,6 +35,7 @@ from nncf.torch.dynamic_graph.context import no_nncf_trace from nncf.torch.dynamic_graph.context import TracingContext from nncf.torch.graph.graph_builder import GraphBuilder +from nncf.torch.graph.operator_metatypes import CatMetatype from tests.torch.helpers import create_compressed_model_and_algo_for_test from tests.torch.helpers import register_bn_adaptation_init_args from tests.torch.test_compressed_graph import get_basic_quantization_config @@ -202,6 +204,54 @@ def test_activation_shape_tracing(input_shape: Tuple): assert output_tensor_shapes == ref_output_shapes, "Failed for node ID: {}".format(node_id) +class ModelForTestWithReshapeFlattenAndConcat(ModelForTest): + def forward(self, x): + y = super().forward(x) + size = y.size() + y = y.view(size + (1, 1)) + + y_copy = torch.ones_like(y) + y = torch.stack([y, y_copy]) + + y_copy = torch.ones_like(y) + y = torch.cat([y, y_copy], -1) + + y = torch.flatten(y) + _ = y.view(-1) + + y_copy = torch.ones_like(y) + y = torch.stack([y, y_copy]) + + y_copy = torch.ones_like(y) + y = torch.cat([y, y_copy], -1) + return y + + +@pytest.mark.parametrize("input_shape", input_shapes) +def test_concat_attributes_saved_during_graph_building(input_shape): + model = ModelForTestWithReshapeFlattenAndConcat() + input_info = ModelInputInfo(input_shape) + graph_builder = GraphBuilder(create_dummy_forward_fn([input_info, ], with_input_tracing=True, + with_output_tracing=True)) + graph = graph_builder.build_graph(model) + cat_nodes_with_attributes = { + 'ModelForTestWithReshapeFlattenAndConcat/cat_0': {'axis': 1}, + 'ModelForTestWithReshapeFlattenAndConcat/cat_1': {'axis': 6}, + 'ModelForTestWithReshapeFlattenAndConcat/cat_2': {'axis': 1}, + 'ModelForTestWithReshapeFlattenAndConcat/stack_0': None, + 'ModelForTestWithReshapeFlattenAndConcat/stack_1': None + } + + for node in graph.get_all_nodes(): + if node.metatype is CatMetatype: + assert node.node_name in cat_nodes_with_attributes + if isinstance(node.layer_attributes, MultipleInputLayerAttributes): + assert node.layer_attributes.axis == cat_nodes_with_attributes[node.node_name]['axis'] + else: + assert node.layer_attributes is None + assert cat_nodes_with_attributes[node.node_name] is None + + TEST_KEYWORD_1 = "keyword1" TEST_KEYWORD_2 = "keyword2" INPUT_INFO_CONFIG_VS_FORWARD_ARGS = [