diff --git a/nncf/common/graph/graph.py b/nncf/common/graph/graph.py
index f4938df20de..891b50d6c8e 100644
--- a/nncf/common/graph/graph.py
+++ b/nncf/common/graph/graph.py
@@ -64,6 +64,10 @@ def layer_name(self) -> LayerName:
     def layer_attributes(self) -> BaseLayerAttributes:
         return self.data.get(NNCFGraph.LAYER_ATTRIBUTES)

+    @layer_attributes.setter
+    def layer_attributes(self, data: BaseLayerAttributes) -> None:
+        self.data[NNCFGraph.LAYER_ATTRIBUTES] = data
+
     @property
     def ignored_algorithms(self) -> List[str]:
         return self.data.get(NNCFGraph.IGNORED_ALGOS_ATTR, [])
diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py
index a46be241d41..f839a2946a2 100644
--- a/nncf/common/graph/layer_attributes.py
+++ b/nncf/common/graph/layer_attributes.py
@@ -29,6 +29,26 @@ class BaseLayerAttributes(ABC):
     """


+class MultipleInputLayerAttributes(BaseLayerAttributes):
+    """
+    Represents a layer with multiple inputs.
+
+    :param axis: The axis stored for the layer; the graph converters of this patch
+        fill it with the concatenation axis of the layer.
+    """
+    def __init__(self,
+                 axis: int):
+        self.axis = axis
+
+    def __eq__(self, other):
+        return isinstance(other, MultipleInputLayerAttributes) \
+               and self.axis == other.axis
+
+    def __hash__(self):
+        # Defining __eq__ alone sets __hash__ to None; keep instances hashable and consistent with __eq__
+        return hash((MultipleInputLayerAttributes, self.axis))
+
+
 class WeightedLayerAttributes(BaseLayerAttributes):
     """
     Represents a layer with weights.
diff --git a/nncf/common/graph/utils.py b/nncf/common/graph/utils.py
new file mode 100644
index 00000000000..23a8cee862a
--- /dev/null
+++ b/nncf/common/graph/utils.py
@@ -0,0 +1,43 @@
+"""
+ Copyright (c) 2021 Intel Corporation
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from typing import List
+
+from nncf.common.utils.logger import logger
+
+
+def get_concat_axis(input_shapes: List[List[int]], output_shapes: List[List[int]]) -> int:
+    """
+    Returns the concatenation axis: the first axis where the first input and the first output shapes differ.
+
+    :param input_shapes: Input shapes of given concat node (only the first shape is compared).
+    :param output_shapes: Output shapes of given concat node (only the first shape is compared).
+    :returns: Concatenation axis; if no dimension differs - the last `None` dimension, or -1 when there is none.
+    """
+    axis = None
+    none_dim = None
+    for idx, (dim_in, dim_out) in enumerate(zip(input_shapes[0], output_shapes[0])):
+        if dim_in != dim_out:
+            axis = idx
+            break
+        if dim_in is None:
+            none_dim = idx
+
+    if axis is None:
+        if none_dim is None:
+            axis = -1
+            logger.warning('Identity concat node detected')
+        else:
+            axis = none_dim
+
+    return axis
diff --git a/nncf/common/pruning/default_pruning_op.py b/nncf/common/pruning/default_pruning_op.py
new file mode 100644
index 00000000000..b0949771c36
--- /dev/null
+++ b/nncf/common/pruning/default_pruning_op.py
@@ -0,0 +1,53 @@
+"""
+ Copyright (c) 2021 Intel Corporation
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from nncf.common.graph.graph import NNCFNode
+from nncf.common.graph.graph import NNCFGraph
+
+
+class DefaultPruningOp:
+    """
+    Determines meta operations which aggregate operations having common
+    properties of interaction with pruning masks
+    """
+
+    subtypes = []  # classes providing `get_all_aliases()` whose aliases belong to this operation
+    additional_types = []  # extra operation type names, appended to the aliases as they are
+
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        """
+        :return: accept_pruned_input - can this operation work with pruned input or not
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        """
+        Propagates the pruning mask through a node using pruning masks of all inputs and the current node (if any).
+
+        :param node: The graph node to propagate mask through it
+        :param graph: The model graph to prune
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_all_op_aliases(cls):
+        """
+        Collects the aliases of all `subtypes` (without duplicates, in first-seen order)
+        followed by `additional_types`.
+
+        :return: list of all aliases of types in metatype
+        """
+        aliases = [alias for subtype in cls.subtypes for alias in subtype.get_all_aliases()]
+        return list(dict.fromkeys(aliases)) + cls.additional_types
diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py
index 1187def3e77..8dcaab9e2b3 100644
--- a/nncf/common/pruning/export_helpers.py
+++ b/nncf/common/pruning/export_helpers.py
@@ -10,43 +10,244 @@
  See the License for the specific language governing permissions and
  limitations under the License.
 """
+
+import numpy as np
+
+from typing import Union
+
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
+from nncf.common.pruning.utils import is_grouped_conv
+from nncf.common.pruning.utils import get_sources_of_node
+from nncf.common.pruning.utils import is_depthwise_conv
+from nncf.common.graph.layer_attributes import GroupNormLayerAttributes
+from nncf.common.pruning.mask_propagation import identity_mask_propagation
+from nncf.common.pruning.mask_propagation import get_input_masks
+from nncf.common.pruning.default_pruning_op import DefaultPruningOp


-class DefaultMetaOp:
-    """
-    Determines meta operations which aggregate operations having common
-    properties of interaction with pruning masks
-    """
+class InputPruningOp(DefaultPruningOp):
+    """Model input: does not accept pruned input and never produces an output mask."""
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        return False
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        node.data['output_mask'] = None

-    subtypes = []
-    additional_types = []

+class OutputPruningOp(DefaultPruningOp):
+    """Model output: accepts pruned input and resets the output mask to None."""
     @classmethod
     def accept_pruned_input(cls, node: NNCFNode):
-        """
-        :return: accept_pruned_input - can this operation work with pruned input or not
-        """
-        raise NotImplementedError
+        return True
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        node.data['output_mask'] = None
+
+
+class IdentityMaskForwardPruningOp(DefaultPruningOp):
+    """Operations that accept pruned input and delegate to `identity_mask_propagation`."""
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        return True
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        identity_mask_propagation(node, graph)
+
+
+class ConvolutionPruningOp(DefaultPruningOp):
+    """Convolutions: a grouped one drops its own mask, a depthwise one forwards the first input mask."""
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        accept_pruned_input = True
+        if is_grouped_conv(node):
+            if not is_depthwise_conv(node):
+                accept_pruned_input = False
+        return accept_pruned_input
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        input_masks = get_input_masks(node, graph)
+        output_mask = node.data.get('output_mask', None)
+
+        if is_grouped_conv(node):
+            output_mask = None
+            if is_depthwise_conv(node):
+                output_mask = input_masks[0]
+
+        node.data['output_mask'] = output_mask
+
+
+class TransposeConvolutionPruningOp(DefaultPruningOp):
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        accept_pruned_input = True
+        if is_grouped_conv(node):
+            if not is_depthwise_conv(node):
+                accept_pruned_input = False
+        return accept_pruned_input
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        input_masks = get_input_masks(node, graph)
+        output_mask = node.data.get('output_mask', None)
+
+        # In case of group convs we can't prune by output filters
+        if is_grouped_conv(node):
+            output_mask = None
+            if is_depthwise_conv(node):
+                output_mask = input_masks[0]
+
+        node.data['output_mask'] = output_mask
+
+
+class BatchNormPruningOp(DefaultPruningOp):
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        return True
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        identity_mask_propagation(node, graph)
+
+
+class GroupNormPruningOp(DefaultPruningOp):
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        # For Instance Normalization
+        return isinstance(node.layer_attributes, GroupNormLayerAttributes) \
+               and node.layer_attributes.num_groups == node.layer_attributes.num_channels

     @classmethod
     def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        identity_mask_propagation(node, graph)
+
+
+class ConcatPruningOp(DefaultPruningOp):
+    """
+    Concatenation: the output mask is the concatenation of the input masks (see `generate_output_mask`).
+    Backend-specific subclasses are expected to set the three op classes below and may override the mask helpers.
+    """
+    ConvolutionOp = None  # type: ConvolutionPruningOp
+    StopMaskForwardOp = None  # type: StopMaskForwardPruningOp
+    InputOp = None  # type: InputPruningOp
+
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        return True
+
+    @classmethod
+    def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool:
         """
-        Propagates the pruning mask through a node using pruning masks of all inputs and the current node (if any).
+        Checks that no masked input of the concat node has an operation stopping
+        the mask propagation among its sources.

-        :param node: The graph node to propagate mask through it
-        :param graph: The model graph to prune
+        :param node: Node to determine it's sources
+        :param graph: NNCF graph to work with
+        :return: False if such an input exists, True otherwise
         """
-        raise NotImplementedError
+
+        for input_node in graph.get_previous_nodes(node):
+            # Inputs without a mask are skipped here; `generate_output_mask` fills them with unit masks
+            if input_node.data.get('output_mask', None) is None:
+                continue
+
+            stop_op_types = cls.StopMaskForwardOp.get_all_op_aliases()
+            source_nodes = get_sources_of_node(input_node, graph, cls.ConvolutionOp.get_all_op_aliases() +
+                                               stop_op_types +
+                                               cls.InputOp.get_all_op_aliases())
+            sources_types = [source_node.node_type for source_node in source_nodes]
+            if any(t in sources_types for t in stop_op_types):
+                return False
+        return True
+
+    @classmethod
+    def _get_unit_mask(cls, dim, device):
+        """Returns a mask of ones of length `dim`; `device` is ignored by this NumPy implementation."""
+        return np.ones(dim)

     @classmethod
-    def get_all_op_aliases(cls):
+    def _get_masks_device(cls, input_masks):
+        """Returns the device passed to `_get_unit_mask`; always None in this NumPy implementation."""
+        return None
+
+    @classmethod
+    def _concat_masks(cls, filled_input_masks):
+        """Concatenates the given masks along axis 0."""
+        return np.concatenate(filled_input_masks, 0)
+
+    @classmethod
+    def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[np.ndarray, None]:
         """
-        :return: list of all aliases of types in metatype
+        Generate output mask from input masks with all None replaced by identity masks.
+        If all input masks is None return None.
+
+        :param node: Node to determine it's sources.
+        :param graph: NNCF graph to work with.
+        :return: Output mask (concatenation of the filled input masks) or None.
         """
-        op_types = []
-        for subtype in cls.subtypes:
-            op_types.extend(subtype.get_all_aliases())
-        op_types = list(set(op_types)) + cls.additional_types
-        return op_types
+        input_edges = graph.get_input_edges(node)
+        previous_nodes = [edge.from_node for edge in input_edges]
+        input_masks = [input_node.data['output_mask'] for input_node in previous_nodes]
+
+        if all(mask is None for mask in input_masks):
+            return None
+
+        filled_input_masks = []
+        for i, mask in enumerate(input_masks):
+            if mask is None:
+                # A missing mask is replaced by ones for the input size along the concat axis
+                concat_axis = node.layer_attributes.axis
+                concat_dim = input_edges[i].tensor_shape[concat_axis]
+                device = cls._get_masks_device(input_masks)
+                mask = cls._get_unit_mask(concat_dim, device)
+            filled_input_masks.append(mask)
+        result_mask = cls._concat_masks(filled_input_masks)
+        return result_mask
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        result_mask = None
+
+        # No mask is propagated (instead of failing) when `check_concat` rejects the inputs
+        if cls.check_concat(node, graph):
+            result_mask = cls.generate_output_mask(node, graph)
+
+        node.data['output_mask'] = result_mask
+
+
+class ElementwisePruningOp(DefaultPruningOp):
+    """Elementwise operations: the mask of the first input is propagated to the output."""
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        return True
+
+    @classmethod
+    def _assert_input_masks_close(cls, input_masks):
+        """Raises an AssertionError if any input mask is not close to the first one."""
+        for input_mask in input_masks[1:]:
+            np.testing.assert_allclose(input_masks[0], input_mask)
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        input_masks = get_input_masks(node, graph)
+
+        node.data['input_masks'] = input_masks
+        if input_masks[0] is not None:
+            cls._assert_input_masks_close(input_masks)
+        node.data['output_mask'] = input_masks[0]
+
+
+class StopMaskForwardPruningOp(DefaultPruningOp):
+    """Operations that do not accept pruned input and reset the output mask to None."""
+    @classmethod
+    def accept_pruned_input(cls, node: NNCFNode):
+        return False
+
+    @classmethod
+    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
+        node.data['output_mask'] = None
diff --git a/nncf/common/pruning/mask_propagation.py b/nncf/common/pruning/mask_propagation.py
index aa1a13fe879..3c44e18ec13 100644
--- a/nncf/common/pruning/mask_propagation.py
+++ b/nncf/common/pruning/mask_propagation.py
@@ -16,8 +16,9 @@

 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
-from nncf.common.pruning.export_helpers import DefaultMetaOp
 from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry
+from nncf.common.pruning.default_pruning_op import DefaultPruningOp
+

 TensorType = TypeVar('TensorType')

@@ -40,7 +41,7 @@ def __init__(self, graph: NNCFGraph, pruning_operator_metatypes: PruningOperatio
         self._graph = graph
         self._pruning_operator_metatypes = pruning_operator_metatypes

-    def get_meta_operation_by_type_name(self, type_name: str) -> DefaultMetaOp:
+    def get_meta_operation_by_type_name(self, type_name: str) -> DefaultPruningOp:
         """
         Returns class of metaop that corresponds to `type_name` type.

diff --git a/nncf/common/pruning/model_analysis.py b/nncf/common/pruning/model_analysis.py index 4df7f4870a4..2eeb3c309f2 100644 --- a/nncf/common/pruning/model_analysis.py +++ b/nncf/common/pruning/model_analysis.py @@ -17,7 +17,7 @@ from nncf.common.graph import NNCFNode from nncf.common.pruning.clusterization import Cluster from nncf.common.pruning.clusterization import Clusterization -from nncf.common.pruning.export_helpers import DefaultMetaOp +from nncf.common.pruning.export_helpers import DefaultPruningOp from nncf.common.pruning.utils import find_next_nodes_not_of_types from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry @@ -143,7 +143,7 @@ def node_accept_different_inputs(self, nncf_node: NNCFNode) -> bool: """ return nncf_node.node_type in self._concat_op_metatype.get_all_op_aliases() - def get_meta_operation_by_type_name(self, type_name: str) -> DefaultMetaOp: + def get_meta_operation_by_type_name(self, type_name: str) -> DefaultPruningOp: """ Returns class of metaop that corresponds to `type_name` type. 
diff --git a/nncf/tensorflow/graph/converter.py b/nncf/tensorflow/graph/converter.py index 4a57f09abb3..66d84bc41c7 100644 --- a/nncf/tensorflow/graph/converter.py +++ b/nncf/tensorflow/graph/converter.py @@ -27,10 +27,13 @@ from nncf.common.graph import NNCFNodeName from nncf.common.graph import OperatorMetatype from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.utils import get_concat_axis from nncf.common.utils.logger import logger as nncf_logger from nncf.tensorflow.graph.metatypes.common import DECONV_LAYER_METATYPES from nncf.tensorflow.graph.metatypes.common import DEPTHWISE_CONV_LAYER_METATYPES +from nncf.tensorflow.graph.metatypes.common import LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS from nncf.tensorflow.graph.metatypes.common import GENERAL_CONV_LAYER_METATYPES from nncf.tensorflow.graph.metatypes.matcher import get_keras_layer_metatype from nncf.tensorflow.graph.metatypes.matcher import get_op_metatype @@ -513,6 +516,8 @@ def convert(self) -> NNCFGraph: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=True) elif metatype in GENERAL_CONV_LAYER_METATYPES: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=False) + elif metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS: + layer_attributes = _get_multiple_input_layer_attributes(layer) else: layer_attributes = None is_shared = len(self._layer_name_to_node_names[layer_name]) > 1 @@ -603,6 +608,8 @@ def convert(self) -> NNCFGraph: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=True) elif layer_metatype in GENERAL_CONV_LAYER_METATYPES: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=False) + elif layer_metatype in 
LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS:
+                layer_attributes = _get_multiple_input_layer_attributes(model_layer)
             if layer_attributes is not None:
                 attrs.update({NNCFGraph.LAYER_ATTRIBUTES: layer_attributes})

@@ -650,6 +657,30 @@ def convert(self) -> NNCFGraph:
         return nncf_graph


+def _get_multiple_input_layer_attributes(layer: tf.keras.layers.Layer) -> MultipleInputLayerAttributes:
+    """
+    Builds layer attributes for a layer with multiple inputs.
+
+    The axis is taken from the layer `axis` attribute when the layer has one,
+    otherwise it is deduced from the layer input and output shapes.
+
+    :param layer: Keras layer to get attributes for.
+    :return: Attributes holding the axis of the layer.
+    """
+    if hasattr(layer, 'axis'):
+        return MultipleInputLayerAttributes(layer.axis)
+
+    def as_shape_list(shape):
+        # Keras reports a single shape (not a list of shapes) for a layer with one input/output,
+        # while `get_concat_axis` indexes its arguments as lists of shapes.
+        if shape and isinstance(shape[0], (list, tuple)):
+            return list(shape)
+        return [shape]
+
+    axis = get_concat_axis(as_shape_list(layer.input_shape), as_shape_list(layer.output_shape))
+    return MultipleInputLayerAttributes(axis)
+
+
 def _get_conv_layer_attributes(layer: tf.keras.layers.Layer, is_depthwise: bool = False) -> ConvolutionLayerAttributes:
     channel_axis = get_input_channel_axis(layer)
     layer_ = unwrap_layer(layer)
diff --git a/nncf/tensorflow/pruning/base_algorithm.py b/nncf/tensorflow/pruning/base_algorithm.py
index 2e354cb28da..d76df6ff65f 100644
--- a/nncf/tensorflow/pruning/base_algorithm.py
+++ b/nncf/tensorflow/pruning/base_algorithm.py
@@ -40,8 +40,8 @@
 from nncf.tensorflow.graph.transformations.layout import TFTransformationLayout
 from nncf.tensorflow.graph.utils import get_layer_identifier
 from nncf.tensorflow.graph.utils import collect_wrapped_layers
-from nncf.tensorflow.pruning.export_helpers import TFElementwise
-from nncf.tensorflow.pruning.export_helpers import TFIdentityMaskForwardOps
+from nncf.tensorflow.pruning.export_helpers import TFElementwisePruningOp
+from nncf.tensorflow.pruning.export_helpers import TFIdentityMaskForwardPruningOp
 from nncf.tensorflow.pruning.export_helpers import TF_PRUNING_OPERATOR_METATYPES
 from nncf.tensorflow.pruning.utils import get_filter_axis
 from nncf.tensorflow.pruning.utils import get_filters_num
@@ -207,7 +207,7 @@ def _get_insertion_command_binary_mask(self, layer_name: str,
     @staticmethod
     def _get_bn_for_node(node: NNCFNode, bn_nodes: List[NNCFNode]) -> Tuple[bool, List[NNCFNode]]:
         is_finished = False
-
propagating_ops = [op_name for meta_op in [TFIdentityMaskForwardOps, TFElementwise]
+        propagating_ops = [op_name for meta_op in [TFIdentityMaskForwardPruningOp, TFElementwisePruningOp]
                            for op_name in meta_op.get_all_op_aliases()]
         if node.node_type == 'BatchNormalization':
             is_finished = True
diff --git a/nncf/tensorflow/pruning/export_helpers.py b/nncf/tensorflow/pruning/export_helpers.py
index fac53e8fbda..6dde4065bfa 100644
--- a/nncf/tensorflow/pruning/export_helpers.py
+++ b/nncf/tensorflow/pruning/export_helpers.py
@@ -12,23 +12,25 @@
 """
 from typing import Dict
 from typing import List
-from typing import Union

 import tensorflow as tf

-from nncf.common.pruning.utils import is_depthwise_conv
 from nncf.tensorflow.graph.pattern_operations import KERAS_ACTIVATIONS_OPERATIONS
 from nncf.tensorflow.graph.pattern_operations import ELEMENTWISE_OPERATIONS
 from nncf.tensorflow.graph.pattern_operations import TF_ACTIVATIONS_OPERATIONS
 from nncf.common.graph.definitions import NNCFGraphNodeType
-from nncf.common.graph import NNCFGraph
-from nncf.common.graph import NNCFNode
-from nncf.common.pruning.export_helpers import DefaultMetaOp
-from nncf.common.pruning.mask_propagation import identity_mask_propagation
-from nncf.common.pruning.mask_propagation import get_input_masks
-from nncf.common.pruning.utils import get_sources_of_node
-from nncf.common.pruning.utils import is_grouped_conv
 from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry
+from nncf.common.pruning.export_helpers import (
+    InputPruningOp,
+    OutputPruningOp,
+    IdentityMaskForwardPruningOp,
+    ConvolutionPruningOp,
+    TransposeConvolutionPruningOp,
+    BatchNormPruningOp,
+    ConcatPruningOp,
+    ElementwisePruningOp,
+    StopMaskForwardPruningOp
+)

 TF_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes")

@@ -38,203 +40,76 @@ def _get_types(operations_dict: Dict) -> List[str]:


 @TF_PRUNING_OPERATOR_METATYPES.register('model_input')
-class TFInput(DefaultMetaOp):
+class TFInputPruningOp(InputPruningOp):
     additional_types = ['InputLayer', NNCFGraphNodeType.INPUT_NODE]

-    @classmethod
-    def accept_pruned_input(cls, node: NNCFNode):
-        return False
-
-    @classmethod
-    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
-        node.data['output_mask'] = None

 @TF_PRUNING_OPERATOR_METATYPES.register('model_output')
-class TFOutput(DefaultMetaOp):
+class TFOutputPruningOp(OutputPruningOp):
     additional_types = [NNCFGraphNodeType.OUTPUT_NODE]

-    @classmethod
-    def accept_pruned_input(cls, node: NNCFNode):
-        return True
-
-    @classmethod
-    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
-        node.data['output_mask'] = None

 @TF_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation')
-class TFIdentityMaskForwardOps(DefaultMetaOp):
+class TFIdentityMaskForwardPruningOp(IdentityMaskForwardPruningOp):
     additional_types = _get_types(KERAS_ACTIVATIONS_OPERATIONS) + _get_types(TF_ACTIVATIONS_OPERATIONS) \
                        + ['AvgPool2D', 'GlobalAvgPool2D', 'AveragePooling2D', 'GlobalAveragePooling2D'] \
                        + ['MaxPooling2D', 'GlobalMaxPooling2D', 'MaxPool2D', 'GlobalMaxPool2D'] \
                        + ['Dropout', 'ZeroPadding2D', 'Identity', 'Pad', 'UpSampling2D']

-    @classmethod
-    def accept_pruned_input(cls, node: NNCFNode):
-        return True
-
-    @classmethod
-    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
-        identity_mask_propagation(node, graph)
-

 @TF_PRUNING_OPERATOR_METATYPES.register('convolution')
-class TFConvolution(DefaultMetaOp):
+class TFConvolutionPruningOp(ConvolutionPruningOp):
     additional_types = ['Conv1D', 'Conv2D', 'Conv3D', 'DepthwiseConv2D']

-    @classmethod
-    def accept_pruned_input(cls, node: NNCFNode):
-        accept_pruned_input = True
-        if is_grouped_conv(node):
-            if not is_depthwise_conv(node):
-                accept_pruned_input = False
-        return accept_pruned_input
-
-    @classmethod
-    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
-        input_masks = get_input_masks(node, graph)
-        output_mask = node.data.get('output_mask', None)
-
-        if is_grouped_conv(node):
-            output_mask = None
-            if is_depthwise_conv(node):
-                output_mask = input_masks[0]
-
-        node.data['output_mask'] = output_mask
-

 @TF_PRUNING_OPERATOR_METATYPES.register('transpose_convolution')
-class TFTransposeConvolution(DefaultMetaOp):
+class TFTransposeConvolutionPruningOp(TransposeConvolutionPruningOp):
     additional_types = ['Conv1DTranspose', 'Conv2DTranspose', 'Conv3DTranspose']

-    @classmethod
-    def accept_pruned_input(cls, node: NNCFNode):
-        accept_pruned_input = True
-        if is_grouped_conv(node):
-            if not is_depthwise_conv(node):
-                accept_pruned_input = False
-        return accept_pruned_input
-
-    @classmethod
-    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
-        input_masks = get_input_masks(node, graph)
-        output_mask = node.data.get('output_mask', None)
-
-        # In case of group convs we can't prune by output filters
-        if is_grouped_conv(node):
-            output_mask = None
-            if is_depthwise_conv(node):
-                output_mask = input_masks[0]
-
-        node.data['output_mask'] = output_mask
-

 @TF_PRUNING_OPERATOR_METATYPES.register('batch_norm')
-class TFBatchNorm(DefaultMetaOp):
+class TFBatchNormPruningOp(BatchNormPruningOp):
     additional_types = ['BatchNormalization', 'SyncBatchNormalization']

-    @classmethod
-    def accept_pruned_input(cls, node: NNCFNode):
-        return True

-    @classmethod
-    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
-        identity_mask_propagation(node, graph)
-
-
-@TF_PRUNING_OPERATOR_METATYPES.register('concat')
-class TFConcat(DefaultMetaOp):
-    additional_types = ['Concatenate', 'ConcatV2']
-
-    @classmethod
-    def accept_pruned_input(cls, node: NNCFNode):
-        return True
-
-    @classmethod
-    def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool:
-        """
-        Return whether all input sources of node is convolutions or not.
-
-        :param node: Node to determine it's sources
-        :param graph: NNCF graph to work with
-        :return: True if all input sources of node is convolutions
-        """
-
-        for input_node in graph.get_previous_nodes(node):
-            # If input has mask -> it went from convolution (source of this node is a convolution)
-            if input_node.data.get('output_mask', None) is None:
-                continue
-
-            source_nodes = get_sources_of_node(input_node, graph, TFConvolution.get_all_op_aliases() +
-                                               TFStopMaskForwardOps.get_all_op_aliases() +
-                                               TFInput.get_all_op_aliases())
-            sources_types = [node.node_type for node in source_nodes]
-            if any(t in sources_types for t in TFStopMaskForwardOps.get_all_op_aliases()):
-                return False
-        return True
-
-    @classmethod
-    def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[tf.Tensor, None]:
-        """
-        Generate output mask from input masks with all None replaced by identity masks.
-        If all input masks is None return None.
-
-        :param node: Node to determine it's sources
-        :param graph: NNCF graph to work with
-        :return: Output mask
-        """
-        input_edges = graph.get_input_edges(node)
-        previous_nodes = [edge.from_node for edge in input_edges]
-        input_masks = [input_node.data['output_mask'] for input_node in previous_nodes]
-
-        if all(mask is None for mask in input_masks):
-            return None
-
-        device = [m for m in input_masks if m is not None][0].device
-
-        filled_input_masks = []
-        for i, mask in enumerate(input_masks):
-            if mask is None:
-                with tf.device(device):
-                    mask = tf.ones(input_edges[i].tensor_shape[-1])
-            filled_input_masks.append(mask)
-        result_mask = tf.concat(filled_input_masks, 0)
-        return result_mask
+@TF_PRUNING_OPERATOR_METATYPES.register('elementwise')
+class TFElementwisePruningOp(ElementwisePruningOp):
+    additional_types = _get_types(ELEMENTWISE_OPERATIONS)

     @classmethod
-    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
-        result_mask = None
+    def _assert_input_masks_close(cls, input_masks):
+        """Checks with `tf.debugging.assert_near` that every input mask is near the first one."""
+        for input_mask in input_masks[1:]:
+            tf.debugging.assert_near(input_masks[0], input_mask)

-        if cls.check_concat(node, graph):
-            result_mask = cls.generate_output_mask(node, graph)

-        node.data['output_mask'] = result_mask
+@TF_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops')
+class TFStopMaskForwardPruningOp(StopMaskForwardPruningOp):
+    additional_types = ['Dense', 'MatMul']


-@TF_PRUNING_OPERATOR_METATYPES.register('elementwise')
-class TFElementwise(DefaultMetaOp):
-    additional_types = _get_types(ELEMENTWISE_OPERATIONS)
+@TF_PRUNING_OPERATOR_METATYPES.register('concat')
+class TFConcatPruningOp(ConcatPruningOp):
+    additional_types = ['Concatenate', 'ConcatV2']

-    @classmethod
-    def accept_pruned_input(cls, node: NNCFNode):
-        return True
+    # TF counterparts of the operation classes declared as placeholders in `ConcatPruningOp`
+    ConvolutionOp = TFConvolutionPruningOp
+    StopMaskForwardOp = TFStopMaskForwardPruningOp
+    InputOp = TFInputPruningOp

     @classmethod
-    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
-        input_masks = get_input_masks(node, graph)
-        if input_masks[0] is not None:
-            for input_mask in input_masks[1:]:
-                tf.debugging.assert_near(input_masks[0], input_mask)
-        node.data['output_mask'] = input_masks[0]
-
-
-@TF_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops')
-class TFStopMaskForwardOps(DefaultMetaOp):
-    additional_types = ['Dense', 'MatMul']
+    def _get_unit_mask(cls, dim, device):
+        """Creates a mask of ones of length `dim` under `tf.device(device)`."""
+        with tf.device(device):
+            mask = tf.ones(dim)
+        return mask

     @classmethod
-    def accept_pruned_input(cls, node: NNCFNode):
-        return False
+    def _get_masks_device(cls, input_masks):
+        """Returns the device of the first input mask that is not None."""
+        return [m for m in input_masks if m is not None][0].device

     @classmethod
-    def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
-        node.data['output_mask'] = None
+    def _concat_masks(cls, filled_input_masks):
+        """Concatenates the masks along axis 0 with `tf.concat`."""
+        return tf.concat(filled_input_masks, 0)
diff --git a/nncf/tensorflow/pruning/filter_pruning/algorithm.py b/nncf/tensorflow/pruning/filter_pruning/algorithm.py
index 2146bb446b7..9fa2c15c5a1 100644
--- a/nncf/tensorflow/pruning/filter_pruning/algorithm.py
+++
b/nncf/tensorflow/pruning/filter_pruning/algorithm.py @@ -52,9 +52,9 @@ from nncf.tensorflow.pruning.base_algorithm import BasePruningAlgoController from nncf.tensorflow.pruning.base_algorithm import PrunedLayerInfo from nncf.tensorflow.pruning.export_helpers import TF_PRUNING_OPERATOR_METATYPES -from nncf.tensorflow.pruning.export_helpers import TFConvolution -from nncf.tensorflow.pruning.export_helpers import TFElementwise -from nncf.tensorflow.pruning.export_helpers import TFTransposeConvolution +from nncf.tensorflow.pruning.export_helpers import TFConvolutionPruningOp +from nncf.tensorflow.pruning.export_helpers import TFElementwisePruningOp +from nncf.tensorflow.pruning.export_helpers import TFTransposeConvolutionPruningOp from nncf.tensorflow.pruning.filter_pruning.functions import calculate_binary_mask from nncf.tensorflow.pruning.filter_pruning.functions import FILTER_IMPORTANCE_FUNCTIONS from nncf.tensorflow.pruning.filter_pruning.functions import tensor_l2_normalizer @@ -85,11 +85,11 @@ def _is_pruned_layer(self, layer: tf.keras.layers.Layer) -> bool: return layer.__class__.__name__ in self._prunable_types def _get_op_types_of_pruned_layers(self) -> List[str]: - return [op_name for meta_op in [TFConvolution, TFTransposeConvolution] + return [op_name for meta_op in [TFConvolutionPruningOp, TFTransposeConvolutionPruningOp] for op_name in meta_op.get_all_op_aliases()] def _get_types_of_grouping_ops(self) -> List[str]: - return TFElementwise.get_all_op_aliases() + return TFElementwisePruningOp.get_all_op_aliases() @ADAPTIVE_COMPRESSION_CONTROLLERS.register('tf_filter_pruning') diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index 238020122b6..1b0e02f2ea1 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -55,7 +55,7 @@ def ignore_scope(cls): return cls -OP_NAMES_REQUIRING_MODULE_ATTRS = [v.op_func_name for v in NNCF_MODULES_DICT] + ["group_norm"] 
+OP_NAMES_REQUIRING_MODULE_ATTRS = [v.op_func_name for v in NNCF_MODULES_DICT] + ['group_norm'] def wrap_operator(operator, operator_info: 'PatchedOperatorInfo'): diff --git a/nncf/torch/graph/graph_builder.py b/nncf/torch/graph/graph_builder.py index c64b5cbf3ed..b22310beee9 100644 --- a/nncf/torch/graph/graph_builder.py +++ b/nncf/torch/graph/graph_builder.py @@ -22,10 +22,13 @@ from nncf.common.graph import INPUT_NOOP_METATYPES from nncf.common.graph import LayerName from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes +from nncf.common.graph.utils import get_concat_axis from nncf.torch.dynamic_graph.graph import DynamicGraph from nncf.torch.dynamic_graph.graph_tracer import GraphTracer from nncf.torch.dynamic_graph.graph_tracer import ModelInputInfo from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.operator_metatypes import CatMetatype from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES @@ -91,4 +94,19 @@ def convert(dynamic_graph: DynamicGraph, input_infos: List[ModelInputInfo] = Non output_port_id=dynamic_graph_edge.output_port_id, dtype=Dtype.FLOAT ) + + for node in nncf_graph.get_all_nodes(): + if node.metatype is CatMetatype: + input_edges = nncf_graph.get_input_edges(node) + output_edges = nncf_graph.get_output_edges(node) + # Case of intermediate node + if input_edges and output_edges: + input_shapes = [edge.tensor_shape for edge in input_edges] + output_shapes = [edge.tensor_shape for edge in output_edges] + # Case node is stack + if len(input_shapes[0]) != len(output_shapes[0]): + continue + axis = get_concat_axis(input_shapes, output_shapes) + layer_attributes = MultipleInputLayerAttributes(axis) + node.layer_attributes = layer_attributes return nncf_graph diff --git a/nncf/torch/layers.py b/nncf/torch/layers.py index 47cb79de46a..fc27f93edb5 100644 --- a/nncf/torch/layers.py +++ b/nncf/torch/layers.py @@ -201,6 +201,7 @@ def 
from_module(module): dict_update(nncf_embedding_bag.__dict__, module.__dict__) return nncf_embedding_bag + NNCF_MODULES_DICT = { NNCFConv1d: nn.Conv1d, NNCFConv2d: nn.Conv2d, diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/export_helpers.py index 96047927374..2c19f387ac2 100644 --- a/nncf/torch/pruning/export_helpers.py +++ b/nncf/torch/pruning/export_helpers.py @@ -10,21 +10,13 @@ See the License for the specific language governing permissions and limitations under the License. """ -from typing import Union -from typing import List import torch from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode -from nncf.common.pruning.export_helpers import DefaultMetaOp -from nncf.common.pruning.utils import is_grouped_conv -from nncf.common.pruning.utils import get_sources_of_node from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry -from nncf.common.pruning.mask_propagation import identity_mask_propagation -from nncf.common.pruning.mask_propagation import get_input_masks from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm -from nncf.common.graph.layer_attributes import GroupNormLayerAttributes from nncf.torch.graph.operator_metatypes import ( AddMetatype, AvgPool2dMetatype, @@ -57,6 +49,18 @@ SubMetatype, TanhMetatype, ) +from nncf.common.pruning.export_helpers import ( + InputPruningOp, + OutputPruningOp, + IdentityMaskForwardPruningOp, + ConvolutionPruningOp, + TransposeConvolutionPruningOp, + BatchNormPruningOp, + GroupNormPruningOp, + ConcatPruningOp, + ElementwisePruningOp, + StopMaskForwardPruningOp +) from nncf.common.utils.logger import logger as nncf_logger from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT @@ -65,20 +69,7 @@ PT_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") -class PTDefaultMetaOp(DefaultMetaOp): - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: 
NNCFGraph): - """ - Propagate mask through a node using masks of all inputs and pruning mask of current node (if any). - Should set the following attributes: - input_masks - list of masks of input nodes (None if there is no mask in some input); - output_mask - resulting mask of node operation. - - :param node: Node from NNCF graph to propagate mask through it. - :param graph: Graph of model to prune. - """ - raise NotImplementedError - +class PTPruner: @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): """ @@ -101,74 +92,26 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('model_input') -class PTInput(PTDefaultMetaOp): +class PTInputPruningOp(InputPruningOp, PTPruner): subtypes = [PTInputNoopMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return False - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['input_masks'] = [] - node.data['output_mask'] = None - @PT_PRUNING_OPERATOR_METATYPES.register('model_output') -class PTOutput(PTDefaultMetaOp): +class PTOutputPruningOp(OutputPruningOp, PTPruner): subtypes = [PTOutputNoopMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['input_masks'] = [] - node.data['output_mask'] = None - @PT_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class PTIdentityMaskForwardOps(PTDefaultMetaOp): +class PTIdentityMaskForwardPruningOp(IdentityMaskForwardPruningOp, PTPruner): subtypes = [HardTanhMetatype, TanhMetatype, RELUMetatype, PRELUMetatype, ELUMetatype, GELUMetatype, SigmoidMetatype, SoftmaxMetatype, AvgPool2dMetatype, MaxPool2dMetatype, DropoutMetatype] additional_types = ['h_sigmoid', 'h_swish', 'RELU'] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def 
mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - @PT_PRUNING_OPERATOR_METATYPES.register('convolution') -class PTConvolution(PTDefaultMetaOp): +class PTConvolutionPruningOp(ConvolutionPruningOp, PTPruner): subtypes = [Conv1dMetatype, Conv2dMetatype, Conv3dMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) - - # In case of group convs we can't prune by output filters - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = input_masks[0] - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask - @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): input_mask = node.data['input_masks'][0] @@ -220,31 +163,9 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') -class PTTransposeConvolution(PTDefaultMetaOp): +class PTTransposeConvolutionPruningOp(TransposeConvolutionPruningOp, PTPruner): subtypes = [ConvTranspose2dMetatype, ConvTranspose3dMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) - - # In case of group convs we can't prune by output filters - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = 
input_masks[0] - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask - @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): input_mask = node.data['input_masks'][0] @@ -289,17 +210,9 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('batch_norm') -class PTBatchNorm(PTDefaultMetaOp): +class PTBatchNormPruningOp(BatchNormPruningOp, PTPruner): subtypes = [BatchNormMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): input_mask = node.data['input_masks'][0] @@ -323,19 +236,9 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('group_norm') -class GroupNorm(PTDefaultMetaOp): +class PTGroupNormPruningOp(GroupNormPruningOp, PTPruner): subtypes = [GroupNormMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - # For Instance Normalization - return isinstance(node.layer_attributes, GroupNormLayerAttributes) \ - and node.layer_attributes.num_groups == node.layer_attributes.num_channels - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): input_mask = node.data['input_masks'][0] @@ -358,93 +261,14 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): ' {}.'.format(node.data['key'], old_num_clannels, new_num_channels)) -@PT_PRUNING_OPERATOR_METATYPES.register('concat') -class PTConcat(PTDefaultMetaOp): - subtypes = [CatMetatype] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def 
check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: - """ - Return whether all input sources of node is convolutions or not. - - :param node: Node to determine it's sources. - :param graph: NNCF graph to work with. - :return: True If all input sources of node is convolutions. - """ - - for input_node in graph.get_previous_nodes(node): - # If input has mask -> it went from convolution (source of this node is a convolution) - if input_node.data.get('output_mask', None) is None: - continue - - source_nodes = get_sources_of_node(input_node, graph, PTConvolution.get_all_op_aliases() + - PTStopMaskForwardOps.get_all_op_aliases() + - PTInput.get_all_op_aliases()) - sources_types = [node.node_type for node in source_nodes] - if any(t in sources_types for t in PTStopMaskForwardOps.get_all_op_aliases()): - return False - return True - - @classmethod - def fill_input_masks(cls, node: NNCFNode, graph: NNCFGraph) -> Union[List[torch.Tensor], None]: - """ - Fill input masks with all None replaced by identity masks. - If all input masks is None return None. - - :param node: Node to determine it's sources. - :param graph: NNCF graph to work with. - :return: Filled input masks. 
- """ - input_edges = graph.get_input_edges(node) - previous_nodes = [edge.from_node for edge in input_edges] - input_masks = [input_node.data['output_mask'] for input_node in previous_nodes] - - if all(mask is None for mask in input_masks): - return None - - device = [m for m in input_masks if m is not None][0].device - - filled_input_masks = [] - for i, mask in enumerate(input_masks): - if mask is None: - mask = torch.ones(input_edges[i].tensor_shape[1], device=device) - filled_input_masks.append(mask) - return filled_input_masks - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = None - output_mask = None - - if cls.check_concat(node, graph): - input_masks = cls.fill_input_masks(node, graph) - if input_masks: - output_mask = torch.cat(input_masks) - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask - - @PT_PRUNING_OPERATOR_METATYPES.register('elementwise') -class PTElementwise(PTDefaultMetaOp): +class PTElementwisePruningOp(ElementwisePruningOp, PTPruner): subtypes = [AddMetatype, SubMetatype, DivMetatype, MulMetatype] @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True + def _assert_input_masks_close(cls, input_masks): + assert all(torch.allclose(input_masks[0], mask) for mask in input_masks) - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - - node.data['input_masks'] = input_masks - if input_masks[0] is not None: - assert all(torch.allclose(input_masks[0], mask) for mask in input_masks) - node.data['output_mask'] = input_masks[0] @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @@ -468,19 +292,29 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') -class PTStopMaskForwardOps(PTDefaultMetaOp): +class PTStopMaskForwardPruningOp(StopMaskForwardPruningOp, 
PTPruner): subtypes = [MeanMetatype, MaxMetatype, MinMetatype, LinearMetatype, MatMulMetatype] + +@PT_PRUNING_OPERATOR_METATYPES.register('concat') +class PTConcatPruningOp(ConcatPruningOp, PTPruner): + subtypes = [CatMetatype] + + ConvolutionOp = PTConvolutionPruningOp + StopMaskForwardOp = PTStopMaskForwardPruningOp + InputOp = PTInputPruningOp + @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return False + def _get_unit_mask(cls, dim, device): + return torch.ones(dim, device=device) @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) + def _get_masks_device(cls, input_masks): + return [m for m in input_masks if m is not None][0].device - node.data['input_masks'] = input_masks - node.data['output_mask'] = None + @classmethod + def _concat_masks(cls, filled_input_masks): + return torch.cat(filled_input_masks, 0) class ModelPruner(MaskPropagationAlgorithm): @@ -495,7 +329,7 @@ def apply_mask(self): 1. running input_prune method for this node 2. 
running output_prune method for this node """ - pruned_node_modules = list() + pruned_node_modules = [] with torch.no_grad(): for node in self._graph.topological_sort(): node_cls = self.get_meta_operation_by_type_name(node.node_type) diff --git a/nncf/torch/pruning/filter_pruning/algo.py b/nncf/torch/pruning/filter_pruning/algo.py index 95c53a24fa7..54131acfff3 100644 --- a/nncf/torch/pruning/filter_pruning/algo.py +++ b/nncf/torch/pruning/filter_pruning/algo.py @@ -61,7 +61,7 @@ from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.pruning.base_algo import BasePruningAlgoBuilder from nncf.torch.pruning.base_algo import BasePruningAlgoController -from nncf.torch.pruning.export_helpers import PTElementwise +from nncf.torch.pruning.export_helpers import PTElementwisePruningOp from nncf.torch.pruning.export_helpers import PT_PRUNING_OPERATOR_METATYPES from nncf.torch.pruning.filter_pruning.functions import FILTER_IMPORTANCE_FUNCTIONS from nncf.torch.pruning.filter_pruning.functions import calculate_binary_mask @@ -109,7 +109,7 @@ def get_op_types_of_pruned_modules(self) -> List[str]: return types def get_types_of_grouping_ops(self) -> List[str]: - return PTElementwise.get_all_op_aliases() + return PTElementwisePruningOp.get_all_op_aliases() @ADAPTIVE_COMPRESSION_CONTROLLERS.register('pt_filter_pruning') diff --git a/tests/common/graph/test_utils.py b/tests/common/graph/test_utils.py new file mode 100644 index 00000000000..0f1b401cd76 --- /dev/null +++ b/tests/common/graph/test_utils.py @@ -0,0 +1,19 @@ +import pytest + +from nncf.common.graph.utils import get_concat_axis + + +TEST_CASES = [ + ([(1, 1), (1, 1)], [(2, 1)], [0]), + ([(None, 1, 1, 5)], [(None, 1, 1, 7)], [3, -1]), + ([(None, 1, 1, 5), (None, 1, 1, 5)], [(None, 1, 1, 10)], [3, -1]), + ([(1, 1, None), (1, 1, None)], [(1, 1, None)], [2, -1]), + ([(1, 1, 32, 1), (1, 1, 32, 1)], [(1, 1, 64, 1)], [2, -1]), + ([(1, 1, 5), (1, 1, 5)], [(1, 1, 5)], [-1]), +] + + 
+@pytest.mark.parametrize('input_shape,output_shape,possible_axes', TEST_CASES) +def test_get_concat_axis(input_shape, output_shape, possible_axes): + axis = get_concat_axis(input_shape, output_shape) + assert axis in possible_axes diff --git a/tests/common/pruning/dummy_types.py b/tests/common/pruning/dummy_types.py new file mode 100644 index 00000000000..faeb26a4584 --- /dev/null +++ b/tests/common/pruning/dummy_types.py @@ -0,0 +1,120 @@ +from typing import List + +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry +from nncf.common.pruning.export_helpers import ( + InputPruningOp, + OutputPruningOp, + IdentityMaskForwardPruningOp, + ConvolutionPruningOp, + TransposeConvolutionPruningOp, + BatchNormPruningOp, + GroupNormPruningOp, + ConcatPruningOp, + ElementwisePruningOp, + StopMaskForwardPruningOp, +) + + +class DummyDefaultMetatype(OperatorMetatype): + name = None + + @classmethod + def get_all_aliases(cls) -> List[str]: + return [cls.name] + + +class DummyInputMetatype(OperatorMetatype): + name = 'input' + + +class DummyOutputMetatype(OperatorMetatype): + name = 'output' + + +class DymmyIdentityMaskForwardMetatype(OperatorMetatype): + name = 'identity_mask_forward' + + +class DummyElementwiseMetatype(OperatorMetatype): + name = 'elementwise' + + +class DummyConvMetatype(OperatorMetatype): + name = 'conv' + + +class DummyTransposeConvolutionMetatype(OperatorMetatype): + name = 'transpose_conv' + + +class DummyBatchNormMetatype(OperatorMetatype): + name = 'batch_norm' + + +class DummyGroupNormMetatype(OperatorMetatype): + name = 'group_norm' + + +class DummyConcatMetatype(OperatorMetatype): + name = 'concat' + + +class DummyStopPropoagtionMetatype(OperatorMetatype): + name = 'stop_propagation_ops' + + +DUMMY_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyInputMetatype.name) +class 
DummyInputPruningOp(InputPruningOp): + additional_types = [DummyInputMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyOutputMetatype.name) +class DummyOutputPruningOp(OutputPruningOp): + additional_types = [DummyOutputMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DymmyIdentityMaskForwardMetatype.name) +class DummyIdentityMaskForward(IdentityMaskForwardPruningOp): + additional_types = [DymmyIdentityMaskForwardMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyStopPropoagtionMetatype.name) +class DummyStopMaskForward(StopMaskForwardPruningOp): + additional_types = [DummyStopPropoagtionMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyConvMetatype.name) +class DummyConvPruningOp(ConvolutionPruningOp): + additional_types = [DummyConvMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyTransposeConvolutionMetatype.name) +class DummyTransposeConvPruningOp(TransposeConvolutionPruningOp): + additional_types = [DummyTransposeConvolutionMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyBatchNormMetatype.name) +class DummyBatchNormPruningOp(BatchNormPruningOp): + additional_types = [DummyBatchNormMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyGroupNormMetatype.name) +class DummyGroupNormPruningOp(GroupNormPruningOp): + additional_types = [DummyGroupNormMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyElementwiseMetatype.name) +class DummyElementwisePruningOp(ElementwisePruningOp): + additional_types = [DummyElementwiseMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyConcatMetatype.name) +class DummyConcatPruningOp(ConcatPruningOp): + ConvolutionOp = DummyConvPruningOp + StopMaskForwardOp = DummyStopMaskForward + InputOp = DummyInputPruningOp + additional_types = [DummyConcatMetatype.name] diff --git a/tests/common/pruning/test_export_helpers.py b/tests/common/pruning/test_export_helpers.py new file mode 100644 index 
00000000000..1b437cce286 --- /dev/null +++ b/tests/common/pruning/test_export_helpers.py @@ -0,0 +1,298 @@ +import numpy as np +import pytest + +from functools import partial + + +from tests.common.pruning import dummy_types +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes +from nncf.common.graph.layer_attributes import GroupNormLayerAttributes +from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes +from nncf.common.pruning.export_helpers import DefaultPruningOp +from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm + + +@pytest.mark.parametrize('dummy_op_class,accept_pruned_input', [(dummy_types.DummyInputPruningOp, False), + (dummy_types.DummyOutputPruningOp, True), + (dummy_types.DummyStopMaskForward, False)]) +def test_stop_propagate_ops(dummy_op_class, accept_pruned_input): + node = NNCFNode(0, 'dummy_node') + assert dummy_op_class.accept_pruned_input(node) == accept_pruned_input + dummy_op_class.mask_propagation(node, None) + assert node.data['output_mask'] is None + + +@pytest.mark.parametrize('dummy_op_class', [dummy_types.DummyIdentityMaskForward, dummy_types.DummyBatchNormPruningOp]) +def test_identity_mask_propogation_prune_ops(dummy_op_class): + assert dummy_op_class.accept_pruned_input(None) + graph = NNCFGraph() + conv_op = graph.add_nncf_node('conv_op', 'conv', dummy_types.DummyConvMetatype) + identity_ops = [] + for alias in dummy_op_class.get_all_op_aliases(): + identity_op = graph.add_nncf_node('identity', alias, dummy_types.DymmyIdentityMaskForwardMetatype) + graph.add_edge_between_nncf_nodes( + from_node_id=conv_op.node_id, + to_node_id=identity_op.node_id, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + identity_ops.append(identity_op) + # Check with and without masks + for output_mask in [None, 
np.ones((10,))]: + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = output_mask + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + for identity_op in identity_ops: + identity_op = graph.get_node_by_id(identity_op.node_id) + assert np.all(identity_op.data['output_mask'] == output_mask) + + +@pytest.mark.parametrize('valid_masks', [None, True, False]) +def test_elementwise_prune_ops(valid_masks): + graph = NNCFGraph() + conv_op_0 = graph.add_nncf_node('conv_op_0', dummy_types.DummyConvMetatype.name, dummy_types.DummyConvMetatype) + conv_op_1 = graph.add_nncf_node('conv_op_1', dummy_types.DummyConvMetatype.name, dummy_types.DummyConvMetatype) + elementwise_op = graph.add_nncf_node('elementwise', dummy_types.DummyElementwiseMetatype.name, + dummy_types.DummyElementwiseMetatype) + add_node = partial(graph.add_edge_between_nncf_nodes, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # conv_op_0 -> elementwise + add_node(from_node_id=conv_op_0.node_id, + to_node_id=elementwise_op.node_id) + + # conv_op_1 -> elementwise + add_node(from_node_id=conv_op_1.node_id, + to_node_id=elementwise_op.node_id) + + masks = [np.ones((10,)), np.ones((10,))] if valid_masks is not None else None + + def set_masks(masks, ops): + for conv_op, mask in zip(ops, masks): + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = mask + + if valid_masks is None or valid_masks: + if valid_masks: + set_masks(masks, [conv_op_0, conv_op_1]) + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + elementwise_op = graph.get_node_by_id(elementwise_op.node_id) + assert np.all(elementwise_op.data['output_mask'] == masks) + else: + def check_wrong_masks(masks): + with pytest.raises(AssertionError): + set_masks(masks, [conv_op_0, conv_op_1]) + MaskPropagationAlgorithm(graph, 
dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + + masks[0][0] = 0 + check_wrong_masks(masks) + masks[0] = np.concatenate([masks[1], np.array([1])]) + check_wrong_masks(masks) + + +@pytest.mark.parametrize('num_channels,num_groups,accept_pruned_input_ref', [(10, 10, True), + (10, 5, False), + (10, 1, False)]) +def test_group_norm_pruning_ops(num_channels, num_groups, accept_pruned_input_ref): + graph = NNCFGraph() + conv_op = graph.add_nncf_node('conv_op', 'conv', dummy_types.DummyConvMetatype) + group_norm_layer_attributes = GroupNormLayerAttributes(True, num_channels=num_channels, + num_groups=num_groups) + group_norm_op = graph.add_nncf_node('identity', dummy_types.DummyGroupNormMetatype.name, + dummy_types.DummyGroupNormMetatype, + layer_attributes=group_norm_layer_attributes) + assert dummy_types.DummyGroupNormPruningOp.accept_pruned_input(group_norm_op) == accept_pruned_input_ref + graph.add_edge_between_nncf_nodes( + from_node_id=conv_op.node_id, + to_node_id=group_norm_op.node_id, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # Check with and without masks + for output_mask in [None, np.ones((10,))]: + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = output_mask + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + identity_op = graph.get_node_by_id(group_norm_op.node_id) + assert np.all(identity_op.data['output_mask'] == output_mask) + + +class DummyMaskProducerMetatype(dummy_types.DummyDefaultMetatype): + name = 'mask_producer' + + +@dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyMaskProducerMetatype.name) +class MockOpMaskProducer(DefaultPruningOp): + additional_types = [DummyMaskProducerMetatype.name] + + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + pass + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + pass + + +@pytest.mark.parametrize('transpose', [True, 
False], ids=['transpose', 'not_transpose']) +@pytest.mark.parametrize('layer_attributes,ref_accept_pruned_input,conv_type', [ + ({'in_channels': 5, + 'out_channels': 10, + 'groups': 1}, True, 'usual_conv'), + ({'in_channels': 10, + 'out_channels': 20, + 'groups': 5}, False, 'grouped_conv_no_depthwise'), + ({'in_channels': 10, + 'out_channels': 20, + 'groups': 10}, True, 'depthwise_conv') +], + ids=['usual_conv', + 'grouped_conv_no_depthwise', + 'depthwise_conv'] +) +def test_conv_pruning_ops(transpose, layer_attributes, ref_accept_pruned_input, conv_type): + default_conv_params = { + 'weight_requires_grad': True, + 'kernel_size': (2, 2), + 'stride': (1, 1), + 'padding_values': [0, 0] + } + graph = NNCFGraph() + dummy_op_before = graph.add_nncf_node('dummy_op_before', DummyMaskProducerMetatype.name, + DummyMaskProducerMetatype) + target_conv_attributes = ConvolutionLayerAttributes(transpose=transpose, **layer_attributes, **default_conv_params) + conv_op_target = graph.add_nncf_node('conv_op_target', dummy_types.DummyConvMetatype.name, + dummy_types.DummyConvMetatype, + layer_attributes=target_conv_attributes) + graph.add_edge_between_nncf_nodes(from_node_id=dummy_op_before.node_id, + to_node_id=conv_op_target.node_id, + tensor_shape=[layer_attributes['in_channels']] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + pruning_op_class = dummy_types.DummyTransposeConvPruningOp if transpose else dummy_types.DummyConvPruningOp + assert pruning_op_class.accept_pruned_input(conv_op_target) == ref_accept_pruned_input + ones_input_mask = np.ones((layer_attributes['in_channels'],)) + ones_output_mask = np.ones((layer_attributes['out_channels'],)) + # Check all combinations of masks + for input_mask in [None, ones_input_mask]: + for output_mask in [None, ones_output_mask]: + dummy_op_before = graph.get_node_by_id(dummy_op_before.node_id) + conv_op_target = graph.get_node_by_id(conv_op_target.node_id) + dummy_op_before.data['output_mask'] = input_mask + 
conv_op_target.data['output_mask'] = output_mask + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + dummy_op_before = graph.get_node_by_id(dummy_op_before.node_id) + conv_op_target = graph.get_node_by_id(conv_op_target.node_id) + if conv_type == 'usual_conv': + assert np.all(conv_op_target.data['output_mask'] == output_mask) + elif conv_type == 'grouped_conv_no_depthwise': + assert conv_op_target.data['output_mask'] is None + else: + assert np.all(conv_op_target.data['output_mask'] == input_mask) + + +@pytest.mark.parametrize('with_elementwise', [False, True]) +def test_stop_ops_elementwise_source_before_concat(with_elementwise): + graph = NNCFGraph() + stop_op_0 = graph.add_nncf_node('stop_op_0', 'stop_propagation_ops', dummy_types.DummyStopPropoagtionMetatype) + stop_op_1 = graph.add_nncf_node('stop_op_1', 'stop_propagation_ops', dummy_types.DummyStopPropoagtionMetatype) + concat_layer_attributes = MultipleInputLayerAttributes(-1) + concat_node = graph.add_nncf_node('concat_node', 'concat', dummy_types.DummyConcatMetatype, + layer_attributes=concat_layer_attributes) + add_node = partial(graph.add_edge_between_nncf_nodes, + tensor_shape=[10, 10], + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + + if not with_elementwise: + # stop_op_0 -> concat_node + add_node(from_node_id=stop_op_0.node_id, + to_node_id=concat_node.node_id) + + # stop_op_1 -> concat_node + add_node(from_node_id=stop_op_1.node_id, + to_node_id=concat_node.node_id) + else: + elementwise_op = graph.add_nncf_node('elementwise', 'elementwise', dummy_types.DummyElementwiseMetatype) + + # stop_op_0 -> elementwise + add_node(from_node_id=stop_op_0.node_id, + to_node_id=elementwise_op.node_id) + + # stop_op_1 -> elementwise + add_node(from_node_id=stop_op_1.node_id, + to_node_id=elementwise_op.node_id) + + # elementwise -> concat + add_node(from_node_id=elementwise_op.node_id, + to_node_id=concat_node.node_id) + + assert not 
dummy_types.DummyConcatPruningOp.check_concat(concat_node, graph) + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + concat_node = graph.get_node_by_id(concat_node.node_id) + assert concat_node.data['output_mask'] is None + + +@pytest.mark.parametrize('empty_mask_branch', [False, True]) +@pytest.mark.parametrize('right_branch_output_channels', [5, 10]) +def test_convs_elementwise_source_before_concat(empty_mask_branch, right_branch_output_channels): + graph = NNCFGraph() + conv_op_0 = graph.add_nncf_node('conv_op_0', 'conv', dummy_types.DummyConvMetatype) + conv_op_1 = graph.add_nncf_node('conv_op_1', 'conv', dummy_types.DummyConvMetatype) + conv_op_2 = graph.add_nncf_node('conv_op_2', 'conv', dummy_types.DummyConvMetatype) + elementwise_node = graph.add_nncf_node('elementwise_node', 'elementwise', dummy_types.DummyElementwiseMetatype) + concat_layer_attributes = MultipleInputLayerAttributes(2) + concat_node = graph.add_nncf_node('concat_node', 'concat', dummy_types.DummyConcatMetatype, + layer_attributes=concat_layer_attributes) + add_node = partial(graph.add_edge_between_nncf_nodes, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + + # conv_op_0 -> elementwise_node + add_node(from_node_id=conv_op_0.node_id, + to_node_id=elementwise_node.node_id, + tensor_shape=[10] * 4) + + # conv_op_1 -> elementwise_node + add_node(from_node_id=conv_op_1.node_id, + to_node_id=elementwise_node.node_id, + tensor_shape=[10] * 4) + + # elementwise_node -> concat_node + add_node(from_node_id=elementwise_node.node_id, + to_node_id=concat_node.node_id, + tensor_shape=[10] * 4) + + # conv_op_2 -> concat_node + add_node(from_node_id=conv_op_2.node_id, + to_node_id=concat_node.node_id, + tensor_shape=[10, 10, right_branch_output_channels, 10]) + + # Check without masks + assert dummy_types.DummyConcatPruningOp.check_concat(concat_node, graph) + # Set masks + for conv_op in [conv_op_0, conv_op_1]: + conv_op = 
graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = np.ones(10) + + if not empty_mask_branch: + conv_op = graph.get_node_by_id(conv_op_2.node_id) + conv_op.data['output_mask'] = np.ones(right_branch_output_channels) + + # Propagate masks + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + # Check with masks + concat_node = graph.get_node_by_id(concat_node.node_id) + assert dummy_types.DummyConcatPruningOp.check_concat(concat_node, graph) + reference_mask = np.ones((10 + right_branch_output_channels,)) + np.testing.assert_equal(concat_node.data['output_mask'], reference_mask) diff --git a/tests/tensorflow/pruning/test_concat.py b/tests/tensorflow/pruning/test_concat.py new file mode 100644 index 00000000000..291793983dc --- /dev/null +++ b/tests/tensorflow/pruning/test_concat.py @@ -0,0 +1,25 @@ +import pytest + +from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test +from tests.tensorflow.pruning.helpers import get_basic_pruning_config +from tests.tensorflow import test_models + + +MODELS = [ + {'model': test_models.InceptionV3, + 'input_shape': (75, 75, 3)}, + {'model': test_models.InceptionResNetV2, + 'input_shape': (75, 75, 3)}, + {'model': test_models.NASNetMobile, + 'input_shape': (32, 32, 3)}, + {'model': test_models.DenseNet121, + 'input_shape': (32, 32, 3)}, +] + + +@pytest.mark.parametrize('model,input_shape', [list(elem.values()) for elem in MODELS]) +def test_concat(model, input_shape): + config = get_basic_pruning_config(input_shape[1]) + model = model(list(input_shape)) + + model, _ = create_compressed_model_and_algo_for_test(model, config) diff --git a/tests/tensorflow/test_model_converter.py b/tests/tensorflow/test_model_converter.py index e217b2df435..6402280ff9c 100644 --- a/tests/tensorflow/test_model_converter.py +++ b/tests/tensorflow/test_model_converter.py @@ -11,7 +11,9 @@ limitations under the License. 
""" +import pytest import tensorflow as tf +from functools import partial from tensorflow.python.keras import backend from tensorflow.python.keras import layers from tensorflow.python.keras import models @@ -19,11 +21,14 @@ from nncf.common.graph import INPUT_NOOP_METATYPES from nncf.common.graph import OUTPUT_NOOP_METATYPES +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.tensorflow.graph.converter import TFModelConverter from nncf.tensorflow.graph.converter import convert_keras_model_to_nncf_graph +from nncf.tensorflow.graph.metatypes.common import LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS from tests.tensorflow.helpers import get_basic_conv_test_model from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test from tests.tensorflow.quantization.test_algorithm_quantization import get_basic_quantization_config +from tests.tensorflow.pruning.helpers import get_concat_test_model def test_struct_auxiliary_nodes_nncf_graph(): @@ -84,3 +89,36 @@ def test_get_custom_layers(): assert len(custom_layers) == 1 assert CustomLayerForTest.CUSTOM_LAYER_NAME in custom_layers assert isinstance(custom_layers[CustomLayerForTest.CUSTOM_LAYER_NAME], CustomLayerForTest) + + +def get_model_with_reshapes_and_concats(batch_size=None): + inputs = layers.Input((64, ), batch_size=batch_size) + x = tf.reshape(inputs, (32, -1)) + x = layers.Reshape((16, -1))(x) + ones = tf.ones_like(x) + t1 = layers.concatenate([x, ones]) + # pylint: disable=E1120,E1123 + t2 = tf.concat([x, ones], axis=-1) + y = tf.concat([t1, t2], axis=-1) + return models.Model(inputs, y, name='ModelWithReshape') + + +CONCAT_MODELS = [partial(get_concat_test_model, input_shape=[1, 8, 8, 1]), + get_model_with_reshapes_and_concats] +REF_CONCAT_ATTRS = [{'tf_op_layer_tf_concat_1': {'axis': [-1, 3]}, + 'tf_op_layer_tf_concat_2': {'axis': [-1, 3]}}, + {'concatenate': {'axis': [-1, 2]}, + 'tf_op_layer_concat': {'axis': [-1, 2]}, + 'tf_op_layer_concat_1': 
{'axis': [-1, 2]}}] + + +@pytest.mark.parametrize('model, ref_attrs', list(zip(CONCAT_MODELS, REF_CONCAT_ATTRS))) +def test_model_with_reshape_and_concat(model, ref_attrs): + model = model() + graph = convert_keras_model_to_nncf_graph(model) + for node in graph.get_all_nodes(): + if node.metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS: + assert node.node_name in ref_attrs + assert node.layer_attributes is not None + assert isinstance(node.layer_attributes, MultipleInputLayerAttributes) + assert node.layer_attributes.axis in ref_attrs[node.node_name]['axis'] diff --git a/tests/torch/pruning/test_concat.py b/tests/torch/pruning/test_concat.py new file mode 100644 index 00000000000..b549f21fe9a --- /dev/null +++ b/tests/torch/pruning/test_concat.py @@ -0,0 +1,59 @@ +import numpy as np +import pytest + +from functools import partial + +from nncf.common.pruning.schedulers import BaselinePruningScheduler, ExponentialWithBiasPruningScheduler +from tests.torch.pruning.helpers import get_pruning_baseline_config, PruningTestModel, get_pruning_exponential_config +from tests.torch.helpers import create_compressed_model_and_algo_for_test +from tests.torch import test_models +from tests.torch.test_models.synthetic import EmbeddingCatLinearModel +from tests.torch.test_models.googlenet import GoogLeNet +from tests.torch.test_models.sr_small_model import SmallModel + + +MODELS = [ + {'model': EmbeddingCatLinearModel, + 'input_shape': (1, 10)}, + {'model': test_models.densenet_cifar, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.DPN26, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.DPN92, + 'input_shape': (1, 3, 32, 32)}, + {'model': GoogLeNet, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.inception_v3, + 'input_shape': (1, 3, 229, 229)}, + {'model': test_models.PNASNetA, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.PNASNetB, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.ShuffleNetG3, + 
'input_shape': (1, 3, 32, 32)}, + {'model': test_models.ShuffleNetG2, + 'input_shape': (1, 3, 32, 32)}, + {'model': partial(test_models.ShuffleNetV2, net_size=1), + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.squeezenet1_0, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.squeezenet1_1, + 'input_shape': (1, 3, 32, 32)}, + {'model': SmallModel, + 'input_shape': ()}, + {'model': test_models.UNet, + 'input_shape': (1, 3, 360, 480)}, +] + +SKIP_LIST = [SmallModel, EmbeddingCatLinearModel] + + +@pytest.mark.parametrize('model,input_shape', [list(elem.values()) for elem in MODELS]) +def test_models_with_concat(model, input_shape): + if model in SKIP_LIST: + pytest.skip() + + config = get_pruning_baseline_config(list(input_shape)) + config['compression']['algorithm'] = 'filter_pruning' + model = model() + _, compression_ctrl = create_compressed_model_and_algo_for_test(model, config) diff --git a/tests/torch/pruning/test_model_pruning_analysis.py b/tests/torch/pruning/test_model_pruning_analysis.py index a6a42f68277..03cd8fcdff1 100644 --- a/tests/torch/pruning/test_model_pruning_analysis.py +++ b/tests/torch/pruning/test_model_pruning_analysis.py @@ -30,8 +30,8 @@ from nncf.torch.dynamic_graph.graph_tracer import ModelInputInfo from nncf.torch.layers import NNCF_PRUNING_MODULES_DICT from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.pruning.export_helpers import PTElementwise -from nncf.torch.pruning.export_helpers import PTIdentityMaskForwardOps +from nncf.torch.pruning.export_helpers import PTElementwisePruningOp +from nncf.torch.pruning.export_helpers import PTIdentityMaskForwardPruningOp from nncf.torch.pruning.export_helpers import PT_PRUNING_OPERATOR_METATYPES from nncf.common.pruning.utils import is_depthwise_conv from nncf.torch.pruning.filter_pruning.algo import FilterPruningBuilder @@ -195,7 +195,7 @@ def test_pruning_node_selector(test_input_info_struct_: GroupPruningModulesTestS prune_first, prune_last, prune_downsample = 
test_input_info_struct_.prune_params pruning_operations = [v.op_func_name for v in NNCF_PRUNING_MODULES_DICT] - grouping_operations = PTElementwise.get_all_op_aliases() + grouping_operations = PTElementwisePruningOp.get_all_op_aliases() from nncf.common.pruning.pruning_node_selector import PruningNodeSelector pruning_node_selector = PruningNodeSelector(PT_PRUNING_OPERATOR_METATYPES, pruning_operations, @@ -264,7 +264,7 @@ def test_group_special_nodes(test_special_ops_struct: GroupSpecialModulesTestStr special_ops_clusterization = cluster_special_ops(nncf_model.get_original_graph(), algo_builder.get_types_of_grouping_ops(), - PTIdentityMaskForwardOps.get_all_op_aliases()) + PTIdentityMaskForwardPruningOp.get_all_op_aliases()) for ref_cluster in test_special_ops_struct.eltwise_clusters: cluster = special_ops_clusterization.get_cluster_containing_element(ref_cluster[0]) diff --git a/tests/torch/test_graph_building.py b/tests/torch/test_graph_building.py index 4eff40b434f..51e401f9e83 100644 --- a/tests/torch/test_graph_building.py +++ b/tests/torch/test_graph_building.py @@ -15,6 +15,7 @@ from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME from nncf.common.graph.definitions import MODEL_OUTPUT_OP_NAME from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from typing import List from typing import Tuple @@ -201,6 +202,38 @@ def test_activation_shape_tracing(input_shape: Tuple): assert output_tensor_shapes == ref_output_shapes, "Failed for node ID: {}".format(node_id) +class ModelForTestWithReshapeFlattenAndConcat(ModelForTest): + def forward(self, x): + y = super().forward(x) + size = y.size() + y = y.view(size + (1, 1)) + y_copy = torch.ones_like(y) + y = torch.cat([y, y_copy], -1) + y = torch.flatten(y) + _ = y.view(-1) + y_copy = torch.ones_like(y) + y = torch.cat([y, y_copy], -1) + return y + + +@pytest.mark.parametrize("input_shape", input_shapes) +def 
test_concat_attributes_saved_during_graph_building(input_shape): + model = ModelForTestWithReshapeFlattenAndConcat() + input_info = ModelInputInfo(input_shape) + graph_builder = GraphBuilder(create_dummy_forward_fn([input_info, ], with_input_tracing=True, + with_output_tracing=True)) + graph = graph_builder.build_graph(model) + reshape_nodes_with_attributes = { + 'ModelForTestWithReshapeFlattenAndConcat/cat_0': {'axis': 1}, + 'ModelForTestWithReshapeFlattenAndConcat/cat_1': {'axis': 5}, + 'ModelForTestWithReshapeFlattenAndConcat/cat_2': {'axis': 0} + } + for name, shapes in reshape_nodes_with_attributes.items(): + node = graph.get_node_by_name(name) + assert isinstance(node.layer_attributes, MultipleInputLayerAttributes) + assert node.layer_attributes.axis == shapes['axis'] + + TEST_KEYWORD_1 = "keyword1" TEST_KEYWORD_2 = "keyword2" INPUT_INFO_CONFIG_VS_FORWARD_ARGS = [