From db939dc3274aefaed4f7539f9cb9b5634f8941f3 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Thu, 30 Sep 2021 17:02:30 +0300 Subject: [PATCH 01/19] Refactor export helpers --- nncf/common/pruning/export_helpers.py | 243 +++++++++++++++++++ nncf/common/pruning/mask_propagation.py | 3 +- nncf/common/pruning/model_analysis.py | 21 ++ nncf/common/pruning/pruning_node_selector.py | 2 +- nncf/tensorflow/pruning/export_helpers.py | 174 ++++--------- nncf/torch/pruning/export_helpers.py | 207 +++------------- 6 files changed, 345 insertions(+), 305 deletions(-) diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py index 1187def3e77..5feb6168901 100644 --- a/nncf/common/pruning/export_helpers.py +++ b/nncf/common/pruning/export_helpers.py @@ -10,8 +10,19 @@ See the License for the specific language governing permissions and limitations under the License. """ + +import numpy as np + +from typing import Union + from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode +from nncf.common.pruning.utils import is_grouped_conv +from nncf.common.pruning.utils import get_sources_of_node +from nncf.common.pruning.utils import is_depthwise_conv +from nncf.common.graph.layer_attributes import GroupNormLayerAttributes +from nncf.common.pruning.mask_propagation import identity_mask_propagation +from nncf.common.pruning.mask_propagation import get_input_masks class DefaultMetaOp: @@ -50,3 +61,235 @@ def get_all_op_aliases(cls): op_types.extend(subtype.get_all_aliases()) op_types = list(set(op_types)) + cls.additional_types return op_types + + +class OpInput(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + return False + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + node.data['output_mask'] = None + + +class OpOutput(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + return True + + @classmethod + def mask_propagation(cls, node: 
NNCFNode, graph: NNCFGraph): + node.data['output_mask'] = None + + +class OpIdentityMaskForwardOps(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + return True + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + identity_mask_propagation(node, graph) + + +class OpConvolution(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + accept_pruned_input = True + if is_grouped_conv(node): + if not is_depthwise_conv(node): + accept_pruned_input = False + return accept_pruned_input + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + input_masks = get_input_masks(node, graph) + output_mask = node.data.get('output_mask', None) + + if is_grouped_conv(node): + output_mask = None + if is_depthwise_conv(node): + output_mask = input_masks[0] + + node.data['output_mask'] = output_mask + + +class OpTransposeConvolution(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + accept_pruned_input = True + if is_grouped_conv(node): + if not is_depthwise_conv(node): + accept_pruned_input = False + return accept_pruned_input + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + input_masks = get_input_masks(node, graph) + output_mask = node.data.get('output_mask', None) + + # In case of group convs we can't prune by output filters + if is_grouped_conv(node): + output_mask = None + if is_depthwise_conv(node): + output_mask = input_masks[0] + + node.data['output_mask'] = output_mask + + +class OpBatchNorm(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + return True + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + identity_mask_propagation(node, graph) + + +class OpGroupNorm(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + # For Instance Normalization + return isinstance(node.layer_attributes, GroupNormLayerAttributes) \ + and 
node.layer_attributes.num_groups == node.layer_attributes.num_channels + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + identity_mask_propagation(node, graph) + + +class OpConcat(DefaultMetaOp): + ConvolutionOp = None # type: OpConvolution + StopMaskForwardOp = None # type: OpStopMaskForwardOps + InputOp = None # type: OpInput + + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + return True + + @classmethod + def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: + """ + Return whether all input sources of node is convolutions or not. + + :param node: Node to determine it's sources + :param graph: NNCF graph to work with + :return: True if all input sources of node is convolutions + """ + + for input_node in graph.get_previous_nodes(node): + # If input has mask -> it went from convolution (source of this node is a convolution) + if input_node.data.get('output_mask', None) is None: + continue + + source_nodes = get_sources_of_node(input_node, graph, cls.ConvolutionOp.get_all_op_aliases() + + cls.StopMaskForwardOp.get_all_op_aliases() + + cls.InputOp.get_all_op_aliases()) + sources_types = [node.node_type for node in source_nodes] + if any(t in sources_types for t in cls.StopMaskForwardOp.get_all_op_aliases()): + return False + return True + + @classmethod + def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[np.array, None]: + """ + Generate output mask from input masks with all None replaced by identity masks. + If all input masks is None return None. 
+ + :param node: Node to determine it's sources + :param graph: NNCF graph to work with + :return: Output mask + """ + input_edges = graph.get_input_edges(node) + previous_nodes = [edge.from_node for edge in input_edges] + input_masks = [input_node.data['output_mask'] for input_node in previous_nodes] + + if all(mask is None for mask in input_masks): + return None + + + filled_input_masks = [] + for i, mask in enumerate(input_masks): + if mask is None: + mask = np.ones(input_edges[i].tensor_shape[-1]) + filled_input_masks.append(mask) + result_mask = np.concatenate(filled_input_masks, 0) + return result_mask + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + result_mask = None + + if cls.check_concat(node, graph): + result_mask = cls.generate_output_mask(node, graph) + + node.data['output_mask'] = result_mask + + +class OpElementwise(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + return True + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + input_masks = get_input_masks(node, graph) + if input_masks[0] is not None: + for input_mask in input_masks[1:]: + np.testing.assert_allclose(input_masks[0], input_mask) + node.data['output_mask'] = input_masks[0] + + +class OpReshape(DefaultMetaOp): + @staticmethod + def _is_flatten(node: NNCFNode): + return sum([dim for dim in node.layer_attributes.output_shape if dim]) == 1 + + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + if node.layer_attributes is None: + return False + return True + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + if cls.accept_pruned_input(node): + if cls._is_flatten(node): + OpFlatten.mask_propagation(node, graph) + else: + identity_mask_propagation(node, graph) + else: + node.data['output_mask'] = None + + +class OpFlatten(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + if node.layer_attributes is not None: + return True + 
return False + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + output_mask = None + input_masks = get_input_masks(node, graph) + assert len(input_masks) == 1 + input_mask = input_masks[0] + if input_mask is not None and node.layer_attributes is not None: + flatten_channels = int(np.prod(node.layer_attributes.input_shape)) + mask_len = input_mask.shape[0] + assert flatten_channels % mask_len == 0 + output_mask = np.repeat(input_mask, repeats=flatten_channels // mask_len) + node.data['output_mask'] = output_mask + + +class OpStopMaskForwardOps(DefaultMetaOp): + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + return False + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + node.data['output_mask'] = None diff --git a/nncf/common/pruning/mask_propagation.py b/nncf/common/pruning/mask_propagation.py index aa1a13fe879..7fdb3da5abe 100644 --- a/nncf/common/pruning/mask_propagation.py +++ b/nncf/common/pruning/mask_propagation.py @@ -16,7 +16,6 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode -from nncf.common.pruning.export_helpers import DefaultMetaOp from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry TensorType = TypeVar('TensorType') @@ -40,7 +39,7 @@ def __init__(self, graph: NNCFGraph, pruning_operator_metatypes: PruningOperatio self._graph = graph self._pruning_operator_metatypes = pruning_operator_metatypes - def get_meta_operation_by_type_name(self, type_name: str) -> DefaultMetaOp: + def get_meta_operation_by_type_name(self, type_name: str) -> 'DefaultMetaOp': """ Returns class of metaop that corresponds to `type_name` type. diff --git a/nncf/common/pruning/model_analysis.py b/nncf/common/pruning/model_analysis.py index 4df7f4870a4..7ee65b2971a 100644 --- a/nncf/common/pruning/model_analysis.py +++ b/nncf/common/pruning/model_analysis.py @@ -11,6 +11,8 @@ limitations under the License. 
""" +import numpy as np + from typing import Callable, List from nncf.common.graph import NNCFGraph @@ -22,6 +24,16 @@ from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry +class SymbolicMask(np.ndarray): + def __init__(self, mask_producer: int, *args, **kwargs): + super().__init__(*args, **kwargs) + self._mask_producer = mask_producer + + @property + def mask_producer(self): + return self._mask_producer + + def get_position(nodes_list: List[NNCFNode], idx: int): for i, node in enumerate(nodes_list): if node.node_id == idx: @@ -108,9 +120,11 @@ class ModelAnalyzer: """ def __init__(self, graph: NNCFGraph, + prune_operations: List[str], pruning_operator_metatypes: PruningOperationsMetatypeRegistry, is_depthwise_conv_fn: Callable[[NNCFNode], bool]): self.graph = graph + self._prune_operations = prune_operations self._pruning_operator_metatypes = pruning_operator_metatypes pruning_op_metatypes_dict = self._pruning_operator_metatypes.registry_dict @@ -194,8 +208,15 @@ def set_accept_pruned_input_attr(self): cls = self.get_meta_operation_by_type_name(nncf_node.node_type) self.accept_pruned_input[nncf_node.node_id] = cls.accept_pruned_input(nncf_node) + def check_pruned_dimentions(self): + # Init output_masks for each prunable layer + for node in self.graph.get_all_nodes(): + if node.node_type in self._prune_operations and self.can_prune[node.node_id]: + node.data['output_mask'] = None + def analyse_model_before_pruning(self): self.set_accept_pruned_input_attr() self.propagate_can_prune_attr_up() self.propagate_can_prune_attr_down() + self.check_pruned_dimentions() return self.can_prune diff --git a/nncf/common/pruning/pruning_node_selector.py b/nncf/common/pruning/pruning_node_selector.py index 6c546f546ff..1b8727e8868 100644 --- a/nncf/common/pruning/pruning_node_selector.py +++ b/nncf/common/pruning/pruning_node_selector.py @@ -168,7 +168,7 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]: 
pruned_nodes_clusterization.merge_list_of_clusters(previous_clusters) # 6. Checks for groups (all nodes in group can be pruned or all group can't be pruned). - model_analyser = ModelAnalyzer(graph, self._pruning_operator_metatypes, is_depthwise_conv) + model_analyser = ModelAnalyzer(graph, self._prune_operations, self._pruning_operator_metatypes, is_depthwise_conv) can_prune_analysis = model_analyser.analyse_model_before_pruning() self._check_pruning_groups(graph, pruned_nodes_clusterization, can_prune_analysis) return pruned_nodes_clusterization diff --git a/nncf/tensorflow/pruning/export_helpers.py b/nncf/tensorflow/pruning/export_helpers.py index fac53e8fbda..408fd3f9898 100644 --- a/nncf/tensorflow/pruning/export_helpers.py +++ b/nncf/tensorflow/pruning/export_helpers.py @@ -16,19 +16,27 @@ import tensorflow as tf -from nncf.common.pruning.utils import is_depthwise_conv from nncf.tensorflow.graph.pattern_operations import KERAS_ACTIVATIONS_OPERATIONS from nncf.tensorflow.graph.pattern_operations import ELEMENTWISE_OPERATIONS from nncf.tensorflow.graph.pattern_operations import TF_ACTIVATIONS_OPERATIONS from nncf.common.graph.definitions import NNCFGraphNodeType from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode -from nncf.common.pruning.export_helpers import DefaultMetaOp -from nncf.common.pruning.mask_propagation import identity_mask_propagation from nncf.common.pruning.mask_propagation import get_input_masks -from nncf.common.pruning.utils import get_sources_of_node -from nncf.common.pruning.utils import is_grouped_conv from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry +from nncf.common.pruning.export_helpers import ( + OpInput, + OpOutput, + OpIdentityMaskForwardOps, + OpConvolution, + OpTransposeConvolution, + OpBatchNorm, + OpConcat, + OpElementwise, + OpReshape, + OpFlatten, + OpStopMaskForwardOps +) TF_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") @@ -38,139 +46,73 
@@ def _get_types(operations_dict: Dict) -> List[str]: @TF_PRUNING_OPERATOR_METATYPES.register('model_input') -class TFInput(DefaultMetaOp): +class TFInput(OpInput): additional_types = ['InputLayer', NNCFGraphNodeType.INPUT_NODE] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return False - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['output_mask'] = None @TF_PRUNING_OPERATOR_METATYPES.register('model_output') -class TFOutput(DefaultMetaOp): +class TFOutput(OpOutput): additional_types = [NNCFGraphNodeType.OUTPUT_NODE] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['output_mask'] = None @TF_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class TFIdentityMaskForwardOps(DefaultMetaOp): +class TFIdentityMaskForwardOps(OpIdentityMaskForwardOps): additional_types = _get_types(KERAS_ACTIVATIONS_OPERATIONS) + _get_types(TF_ACTIVATIONS_OPERATIONS) \ + ['AvgPool2D', 'GlobalAvgPool2D', 'AveragePooling2D', 'GlobalAveragePooling2D'] \ + ['MaxPooling2D', 'GlobalMaxPooling2D', 'MaxPool2D', 'GlobalMaxPool2D'] \ + ['Dropout', 'ZeroPadding2D', 'Identity', 'Pad', 'UpSampling2D'] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - @TF_PRUNING_OPERATOR_METATYPES.register('convolution') -class TFConvolution(DefaultMetaOp): +class TFConvolution(OpConvolution): additional_types = ['Conv1D', 'Conv2D', 'Conv3D', 'DepthwiseConv2D'] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = 
get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = input_masks[0] +@TF_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') +class TFTransposeConvolution(OpTransposeConvolution): + additional_types = ['Conv1DTranspose', 'Conv2DTranspose', 'Conv3DTranspose'] - node.data['output_mask'] = output_mask +@TF_PRUNING_OPERATOR_METATYPES.register('batch_norm') +class TFBatchNorm(OpBatchNorm): + additional_types = ['BatchNormalization', 'SyncBatchNormalization'] -@TF_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') -class TFTransposeConvolution(DefaultMetaOp): - additional_types = ['Conv1DTranspose', 'Conv2DTranspose', 'Conv3DTranspose'] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input +@TF_PRUNING_OPERATOR_METATYPES.register('elementwise') +class TFElementwise(OpElementwise): + additional_types = _get_types(ELEMENTWISE_OPERATIONS) @classmethod def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): input_masks = get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) + if input_masks[0] is not None: + for input_mask in input_masks[1:]: + tf.debugging.assert_near(input_masks[0], input_mask) + node.data['output_mask'] = input_masks[0] - # In case of group convs we can't prune by output filters - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = input_masks[0] - node.data['output_mask'] = output_mask +@TF_PRUNING_OPERATOR_METATYPES.register('reshape') +class TFReshapeOps(OpReshape): + additional_types = ['Reshape'] -@TF_PRUNING_OPERATOR_METATYPES.register('batch_norm') -class TFBatchNorm(DefaultMetaOp): - additional_types = ['BatchNormalization', 'SyncBatchNormalization'] 
+@TF_PRUNING_OPERATOR_METATYPES.register('flatten') +class TFFlattenOps(OpFlatten): + additional_types = ['Flatten'] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) +@TF_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') +class TFStopMaskForwardOps(OpStopMaskForwardOps): + additional_types = ['Dense', 'MatMul'] @TF_PRUNING_OPERATOR_METATYPES.register('concat') -class TFConcat(DefaultMetaOp): +class TFConcat(OpConcat): additional_types = ['Concatenate', 'ConcatV2'] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: - """ - Return whether all input sources of node is convolutions or not. - - :param node: Node to determine it's sources - :param graph: NNCF graph to work with - :return: True if all input sources of node is convolutions - """ - - for input_node in graph.get_previous_nodes(node): - # If input has mask -> it went from convolution (source of this node is a convolution) - if input_node.data.get('output_mask', None) is None: - continue - - source_nodes = get_sources_of_node(input_node, graph, TFConvolution.get_all_op_aliases() + - TFStopMaskForwardOps.get_all_op_aliases() + - TFInput.get_all_op_aliases()) - sources_types = [node.node_type for node in source_nodes] - if any(t in sources_types for t in TFStopMaskForwardOps.get_all_op_aliases()): - return False - return True + ConvolutionOp = TFConvolution + StopMaskForwardOp = TFStopMaskForwardOps + InputOp = TFInput @classmethod def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[tf.Tensor, None]: @@ -208,33 +150,3 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): result_mask = cls.generate_output_mask(node, graph) node.data['output_mask'] = result_mask - - -@TF_PRUNING_OPERATOR_METATYPES.register('elementwise') 
-class TFElementwise(DefaultMetaOp): - additional_types = _get_types(ELEMENTWISE_OPERATIONS) - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - if input_masks[0] is not None: - for input_mask in input_masks[1:]: - tf.debugging.assert_near(input_masks[0], input_mask) - node.data['output_mask'] = input_masks[0] - - -@TF_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') -class TFStopMaskForwardOps(DefaultMetaOp): - additional_types = ['Dense', 'MatMul'] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return False - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['output_mask'] = None diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/export_helpers.py index 96047927374..26dd7ac5183 100644 --- a/nncf/torch/pruning/export_helpers.py +++ b/nncf/torch/pruning/export_helpers.py @@ -12,8 +12,10 @@ """ from typing import Union from typing import List +from collections import Counter import torch +import numpy as np from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode @@ -56,6 +58,20 @@ SoftmaxMetatype, SubMetatype, TanhMetatype, + ReshapeMetatype, +) +from nncf.common.pruning.export_helpers import ( + OpInput, + OpOutput, + OpIdentityMaskForwardOps, + OpConvolution, + OpTransposeConvolution, + OpBatchNorm, + OpGroupNorm, + OpConcat, + OpElementwise, + OpReshape, + OpStopMaskForwardOps ) from nncf.common.utils.logger import logger as nncf_logger from nncf.torch.nncf_network import NNCFNetwork @@ -65,110 +81,32 @@ PT_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") -class PTDefaultMetaOp(DefaultMetaOp): - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - """ - Propagate mask through a node using masks of all inputs and pruning mask of current node 
(if any). - Should set the following attributes: - input_masks - list of masks of input nodes (None if there is no mask in some input); - output_mask - resulting mask of node operation. - - :param node: Node from NNCF graph to propagate mask through it. - :param graph: Graph of model to prune. - """ - raise NotImplementedError - - @classmethod - def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): - """ - Prune node by input_masks (if masks is not none and operation support it). - - :param model: NNCF network. - :param node: Node from NNCF graph that will be prune. - :param graph: Graph of model. - """ - - @classmethod - def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): - """ - Prune node by output_mask (if mask is not none and operation support it). - - :param model: NNCF network. - :param node: Node from NNCF graph that will be prune. - :param graph: Graph of model. - """ - - @PT_PRUNING_OPERATOR_METATYPES.register('model_input') -class PTInput(PTDefaultMetaOp): +class PTInput(OpInput): subtypes = [PTInputNoopMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return False - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['input_masks'] = [] - node.data['output_mask'] = None - @PT_PRUNING_OPERATOR_METATYPES.register('model_output') -class PTOutput(PTDefaultMetaOp): +class PTOutput(OpOutput): subtypes = [PTOutputNoopMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - node.data['input_masks'] = [] - node.data['output_mask'] = None - @PT_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class PTIdentityMaskForwardOps(PTDefaultMetaOp): +class PTIdentityMaskForwardOps(OpIdentityMaskForwardOps): subtypes = [HardTanhMetatype, TanhMetatype, RELUMetatype, PRELUMetatype, ELUMetatype, GELUMetatype, SigmoidMetatype, 
SoftmaxMetatype, AvgPool2dMetatype, MaxPool2dMetatype, DropoutMetatype] additional_types = ['h_sigmoid', 'h_swish', 'RELU'] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) +@PT_PRUNING_OPERATOR_METATYPES.register('reshape') +class PTReshape(OpReshape): + subtypes = [ReshapeMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('convolution') -class PTConvolution(PTDefaultMetaOp): +class PTConvolution(OpConvolution): subtypes = [Conv1dMetatype, Conv2dMetatype, Conv3dMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) - - # In case of group convs we can't prune by output filters - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = input_masks[0] - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask - @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): input_mask = node.data['input_masks'][0] @@ -220,31 +158,9 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') -class PTTransposeConvolution(PTDefaultMetaOp): +class PTTransposeConvolution(OpTransposeConvolution): subtypes = [ConvTranspose2dMetatype, ConvTranspose3dMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - accept_pruned_input = True - if is_grouped_conv(node): - if not is_depthwise_conv(node): - accept_pruned_input = False - return accept_pruned_input - - @classmethod - def mask_propagation(cls, node: 
NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - output_mask = node.data.get('output_mask', None) - - # In case of group convs we can't prune by output filters - if is_grouped_conv(node): - output_mask = None - if is_depthwise_conv(node): - output_mask = input_masks[0] - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask - @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): input_mask = node.data['input_masks'][0] @@ -289,17 +205,9 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('batch_norm') -class PTBatchNorm(PTDefaultMetaOp): +class PTBatchNorm(OpBatchNorm): subtypes = [BatchNormMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): input_mask = node.data['input_masks'][0] @@ -323,19 +231,9 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('group_norm') -class GroupNorm(PTDefaultMetaOp): +class GroupNorm(OpGroupNorm): subtypes = [GroupNormMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - # For Instance Normalization - return isinstance(node.layer_attributes, GroupNormLayerAttributes) \ - and node.layer_attributes.num_groups == node.layer_attributes.num_channels - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - identity_mask_propagation(node, graph) - @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): input_mask = node.data['input_masks'][0] @@ -359,36 +257,9 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('concat') -class 
PTConcat(PTDefaultMetaOp): +class PTConcat(OpConcat): subtypes = [CatMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - - @classmethod - def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: - """ - Return whether all input sources of node is convolutions or not. - - :param node: Node to determine it's sources. - :param graph: NNCF graph to work with. - :return: True If all input sources of node is convolutions. - """ - - for input_node in graph.get_previous_nodes(node): - # If input has mask -> it went from convolution (source of this node is a convolution) - if input_node.data.get('output_mask', None) is None: - continue - - source_nodes = get_sources_of_node(input_node, graph, PTConvolution.get_all_op_aliases() + - PTStopMaskForwardOps.get_all_op_aliases() + - PTInput.get_all_op_aliases()) - sources_types = [node.node_type for node in source_nodes] - if any(t in sources_types for t in PTStopMaskForwardOps.get_all_op_aliases()): - return False - return True - @classmethod def fill_input_masks(cls, node: NNCFNode, graph: NNCFGraph) -> Union[List[torch.Tensor], None]: """ @@ -430,13 +301,9 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('elementwise') -class PTElementwise(PTDefaultMetaOp): +class PTElementwise(OpElementwise): subtypes = [AddMetatype, SubMetatype, DivMetatype, MulMetatype] - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - return True - @classmethod def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): input_masks = get_input_masks(node, graph) @@ -468,19 +335,17 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') -class PTStopMaskForwardOps(PTDefaultMetaOp): +class PTStopMaskForwardOps(OpStopMaskForwardOps): subtypes = [MeanMetatype, MaxMetatype, MinMetatype, LinearMetatype, MatMulMetatype] - @classmethod - def accept_pruned_input(cls, 
node: NNCFNode): - return False - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) +@PT_PRUNING_OPERATOR_METATYPES.register('concat') +class PTConcat(OpConcat): + subtypes = [CatMetatype] - node.data['input_masks'] = input_masks - node.data['output_mask'] = None + ConvolutionOp = PTConvolution + StopMaskForwardOp = PTStopMaskForwardOps + InputOp = PTInput class ModelPruner(MaskPropagationAlgorithm): @@ -495,7 +360,7 @@ def apply_mask(self): 1. running input_prune method for this node 2. running output_prune method for this node """ - pruned_node_modules = list() + pruned_node_modules = [] with torch.no_grad(): for node in self._graph.topological_sort(): node_cls = self.get_meta_operation_by_type_name(node.node_type) From 4e617a949692fd6ec43e2f8263fdb7f67fe90feb Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Fri, 1 Oct 2021 10:37:19 +0300 Subject: [PATCH 02/19] Remove flatten and reshape ops to clean up pr --- nncf/common/pruning/export_helpers.py | 43 --------- nncf/common/pruning/model_analysis.py | 21 ----- nncf/common/pruning/pruning_node_selector.py | 2 +- nncf/tensorflow/pruning/export_helpers.py | 12 --- nncf/torch/pruning/export_helpers.py | 93 +++++++++----------- 5 files changed, 43 insertions(+), 128 deletions(-) diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py index 5feb6168901..88c490b01b5 100644 --- a/nncf/common/pruning/export_helpers.py +++ b/nncf/common/pruning/export_helpers.py @@ -242,49 +242,6 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): node.data['output_mask'] = input_masks[0] -class OpReshape(DefaultMetaOp): - @staticmethod - def _is_flatten(node: NNCFNode): - return sum([dim for dim in node.layer_attributes.output_shape if dim]) == 1 - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - if node.layer_attributes is None: - return False - return True - - @classmethod - def mask_propagation(cls, 
node: NNCFNode, graph: NNCFGraph): - if cls.accept_pruned_input(node): - if cls._is_flatten(node): - OpFlatten.mask_propagation(node, graph) - else: - identity_mask_propagation(node, graph) - else: - node.data['output_mask'] = None - - -class OpFlatten(DefaultMetaOp): - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - if node.layer_attributes is not None: - return True - return False - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - output_mask = None - input_masks = get_input_masks(node, graph) - assert len(input_masks) == 1 - input_mask = input_masks[0] - if input_mask is not None and node.layer_attributes is not None: - flatten_channels = int(np.prod(node.layer_attributes.input_shape)) - mask_len = input_mask.shape[0] - assert flatten_channels % mask_len == 0 - output_mask = np.repeat(input_mask, repeats=flatten_channels // mask_len) - node.data['output_mask'] = output_mask - - class OpStopMaskForwardOps(DefaultMetaOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): diff --git a/nncf/common/pruning/model_analysis.py b/nncf/common/pruning/model_analysis.py index 7ee65b2971a..4df7f4870a4 100644 --- a/nncf/common/pruning/model_analysis.py +++ b/nncf/common/pruning/model_analysis.py @@ -11,8 +11,6 @@ limitations under the License. 
""" -import numpy as np - from typing import Callable, List from nncf.common.graph import NNCFGraph @@ -24,16 +22,6 @@ from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry -class SymbolicMask(np.ndarray): - def __init__(self, mask_producer: int, *args, **kwargs): - super().__init__(*args, **kwargs) - self._mask_producer = mask_producer - - @property - def mask_producer(self): - return self._mask_producer - - def get_position(nodes_list: List[NNCFNode], idx: int): for i, node in enumerate(nodes_list): if node.node_id == idx: @@ -120,11 +108,9 @@ class ModelAnalyzer: """ def __init__(self, graph: NNCFGraph, - prune_operations: List[str], pruning_operator_metatypes: PruningOperationsMetatypeRegistry, is_depthwise_conv_fn: Callable[[NNCFNode], bool]): self.graph = graph - self._prune_operations = prune_operations self._pruning_operator_metatypes = pruning_operator_metatypes pruning_op_metatypes_dict = self._pruning_operator_metatypes.registry_dict @@ -208,15 +194,8 @@ def set_accept_pruned_input_attr(self): cls = self.get_meta_operation_by_type_name(nncf_node.node_type) self.accept_pruned_input[nncf_node.node_id] = cls.accept_pruned_input(nncf_node) - def check_pruned_dimentions(self): - # Init output_masks for each prunable layer - for node in self.graph.get_all_nodes(): - if node.node_type in self._prune_operations and self.can_prune[node.node_id]: - node.data['output_mask'] = None - def analyse_model_before_pruning(self): self.set_accept_pruned_input_attr() self.propagate_can_prune_attr_up() self.propagate_can_prune_attr_down() - self.check_pruned_dimentions() return self.can_prune diff --git a/nncf/common/pruning/pruning_node_selector.py b/nncf/common/pruning/pruning_node_selector.py index 1b8727e8868..6c546f546ff 100644 --- a/nncf/common/pruning/pruning_node_selector.py +++ b/nncf/common/pruning/pruning_node_selector.py @@ -168,7 +168,7 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]: 
pruned_nodes_clusterization.merge_list_of_clusters(previous_clusters) # 6. Checks for groups (all nodes in group can be pruned or all group can't be pruned). - model_analyser = ModelAnalyzer(graph, self._prune_operations, self._pruning_operator_metatypes, is_depthwise_conv) + model_analyser = ModelAnalyzer(graph, self._pruning_operator_metatypes, is_depthwise_conv) can_prune_analysis = model_analyser.analyse_model_before_pruning() self._check_pruning_groups(graph, pruned_nodes_clusterization, can_prune_analysis) return pruned_nodes_clusterization diff --git a/nncf/tensorflow/pruning/export_helpers.py b/nncf/tensorflow/pruning/export_helpers.py index 408fd3f9898..5d3387be54c 100644 --- a/nncf/tensorflow/pruning/export_helpers.py +++ b/nncf/tensorflow/pruning/export_helpers.py @@ -33,8 +33,6 @@ OpBatchNorm, OpConcat, OpElementwise, - OpReshape, - OpFlatten, OpStopMaskForwardOps ) @@ -91,16 +89,6 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): node.data['output_mask'] = input_masks[0] -@TF_PRUNING_OPERATOR_METATYPES.register('reshape') -class TFReshapeOps(OpReshape): - additional_types = ['Reshape'] - - -@TF_PRUNING_OPERATOR_METATYPES.register('flatten') -class TFFlattenOps(OpFlatten): - additional_types = ['Flatten'] - - @TF_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') class TFStopMaskForwardOps(OpStopMaskForwardOps): additional_types = ['Dense', 'MatMul'] diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/export_helpers.py index 26dd7ac5183..f470e8e60aa 100644 --- a/nncf/torch/pruning/export_helpers.py +++ b/nncf/torch/pruning/export_helpers.py @@ -58,7 +58,6 @@ SoftmaxMetatype, SubMetatype, TanhMetatype, - ReshapeMetatype, ) from nncf.common.pruning.export_helpers import ( OpInput, @@ -70,7 +69,6 @@ OpGroupNorm, OpConcat, OpElementwise, - OpReshape, OpStopMaskForwardOps ) from nncf.common.utils.logger import logger as nncf_logger @@ -98,11 +96,6 @@ class PTIdentityMaskForwardOps(OpIdentityMaskForwardOps): 
additional_types = ['h_sigmoid', 'h_swish', 'RELU'] -@PT_PRUNING_OPERATOR_METATYPES.register('reshape') -class PTReshape(OpReshape): - subtypes = [ReshapeMetatype] - - @PT_PRUNING_OPERATOR_METATYPES.register('convolution') class PTConvolution(OpConvolution): subtypes = [Conv1dMetatype, Conv2dMetatype, Conv3dMetatype] @@ -256,50 +249,6 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): ' {}.'.format(node.data['key'], old_num_clannels, new_num_channels)) -@PT_PRUNING_OPERATOR_METATYPES.register('concat') -class PTConcat(OpConcat): - subtypes = [CatMetatype] - - @classmethod - def fill_input_masks(cls, node: NNCFNode, graph: NNCFGraph) -> Union[List[torch.Tensor], None]: - """ - Fill input masks with all None replaced by identity masks. - If all input masks is None return None. - - :param node: Node to determine it's sources. - :param graph: NNCF graph to work with. - :return: Filled input masks. - """ - input_edges = graph.get_input_edges(node) - previous_nodes = [edge.from_node for edge in input_edges] - input_masks = [input_node.data['output_mask'] for input_node in previous_nodes] - - if all(mask is None for mask in input_masks): - return None - - device = [m for m in input_masks if m is not None][0].device - - filled_input_masks = [] - for i, mask in enumerate(input_masks): - if mask is None: - mask = torch.ones(input_edges[i].tensor_shape[1], device=device) - filled_input_masks.append(mask) - return filled_input_masks - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = None - output_mask = None - - if cls.check_concat(node, graph): - input_masks = cls.fill_input_masks(node, graph) - if input_masks: - output_mask = torch.cat(input_masks) - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask - - @PT_PRUNING_OPERATOR_METATYPES.register('elementwise') class PTElementwise(OpElementwise): subtypes = [AddMetatype, SubMetatype, DivMetatype, MulMetatype] @@ -347,6 
+296,48 @@ class PTConcat(OpConcat): StopMaskForwardOp = PTStopMaskForwardOps InputOp = PTInput + @classmethod + def fill_input_masks(cls, node: NNCFNode, graph: NNCFGraph) -> Union[List[torch.Tensor], None]: + """ + Fill input masks with all None replaced by identity masks. + If all input masks is None return None. + + :param node: Node to determine it's sources. + :param graph: NNCF graph to work with. + :return: Filled input masks. + """ + input_edges = graph.get_input_edges(node) + previous_nodes = [edge.from_node for edge in input_edges] + input_masks = [input_node.data['output_mask'] for input_node in previous_nodes] + + if all(mask is None for mask in input_masks): + return None + + device = [m for m in input_masks if m is not None][0].device + + filled_input_masks = [] + for i, mask in enumerate(input_masks): + if mask is None: + mask = torch.ones(input_edges[i].tensor_shape[1], device=device) + filled_input_masks.append(mask) + return filled_input_masks + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + input_masks = None + output_mask = None + + if cls.check_concat(node, graph): + input_masks = cls.fill_input_masks(node, graph) + if input_masks: + output_mask = torch.cat(input_masks) + + node.data['input_masks'] = input_masks + node.data['output_mask'] = output_mask + ConvolutionOp = PTConvolution + StopMaskForwardOp = PTStopMaskForwardOps + InputOp = PTInput + class ModelPruner(MaskPropagationAlgorithm): def __init__(self, model: NNCFNetwork, graph: NNCFGraph, From e9375828d7bb215e76dd510b75ce8c85e23fe5ca Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Fri, 1 Oct 2021 11:15:45 +0300 Subject: [PATCH 03/19] Revert PTDefaultMetaOp --- nncf/torch/pruning/export_helpers.py | 53 +++++++++++++++++----------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/export_helpers.py index f470e8e60aa..d283ef41e05 100644 --- a/nncf/torch/pruning/export_helpers.py 
+++ b/nncf/torch/pruning/export_helpers.py @@ -12,21 +12,14 @@ """ from typing import Union from typing import List -from collections import Counter import torch -import numpy as np from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode -from nncf.common.pruning.export_helpers import DefaultMetaOp -from nncf.common.pruning.utils import is_grouped_conv -from nncf.common.pruning.utils import get_sources_of_node from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry -from nncf.common.pruning.mask_propagation import identity_mask_propagation from nncf.common.pruning.mask_propagation import get_input_masks from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm -from nncf.common.graph.layer_attributes import GroupNormLayerAttributes from nncf.torch.graph.operator_metatypes import ( AddMetatype, AvgPool2dMetatype, @@ -60,6 +53,7 @@ TanhMetatype, ) from nncf.common.pruning.export_helpers import ( + DefaultMetaOp, OpInput, OpOutput, OpIdentityMaskForwardOps, @@ -79,25 +73,47 @@ PT_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") +class PTDefaultMetaOp(DefaultMetaOp): + @classmethod + def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + """ + Prune node by input_masks (if masks is not none and operation support it). + + :param model: NNCF network. + :param node: Node from NNCF graph that will be prune. + :param graph: Graph of model. + """ + + @classmethod + def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): + """ + Prune node by output_mask (if mask is not none and operation support it). + + :param model: NNCF network. + :param node: Node from NNCF graph that will be prune. + :param graph: Graph of model. 
+ """ + + @PT_PRUNING_OPERATOR_METATYPES.register('model_input') -class PTInput(OpInput): +class PTInput(PTDefaultMetaOp, OpInput): subtypes = [PTInputNoopMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('model_output') -class PTOutput(OpOutput): +class PTOutput(PTDefaultMetaOp, OpOutput): subtypes = [PTOutputNoopMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class PTIdentityMaskForwardOps(OpIdentityMaskForwardOps): +class PTIdentityMaskForwardOps(PTDefaultMetaOp, OpIdentityMaskForwardOps): subtypes = [HardTanhMetatype, TanhMetatype, RELUMetatype, PRELUMetatype, ELUMetatype, GELUMetatype, SigmoidMetatype, SoftmaxMetatype, AvgPool2dMetatype, MaxPool2dMetatype, DropoutMetatype] additional_types = ['h_sigmoid', 'h_swish', 'RELU'] @PT_PRUNING_OPERATOR_METATYPES.register('convolution') -class PTConvolution(OpConvolution): +class PTConvolution(PTDefaultMetaOp, OpConvolution): subtypes = [Conv1dMetatype, Conv2dMetatype, Conv3dMetatype] @classmethod @@ -151,7 +167,7 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') -class PTTransposeConvolution(OpTransposeConvolution): +class PTTransposeConvolution(PTDefaultMetaOp, OpTransposeConvolution): subtypes = [ConvTranspose2dMetatype, ConvTranspose3dMetatype] @classmethod @@ -198,7 +214,7 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('batch_norm') -class PTBatchNorm(OpBatchNorm): +class PTBatchNorm(PTDefaultMetaOp, OpBatchNorm): subtypes = [BatchNormMetatype] @classmethod @@ -224,7 +240,7 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('group_norm') -class GroupNorm(OpGroupNorm): +class GroupNorm(PTDefaultMetaOp, OpGroupNorm): subtypes = [GroupNormMetatype] @classmethod @@ -250,7 +266,7 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: 
NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('elementwise') -class PTElementwise(OpElementwise): +class PTElementwise(PTDefaultMetaOp, OpElementwise): subtypes = [AddMetatype, SubMetatype, DivMetatype, MulMetatype] @classmethod @@ -284,12 +300,12 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') -class PTStopMaskForwardOps(OpStopMaskForwardOps): +class PTStopMaskForwardOps(PTDefaultMetaOp, OpStopMaskForwardOps): subtypes = [MeanMetatype, MaxMetatype, MinMetatype, LinearMetatype, MatMulMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('concat') -class PTConcat(OpConcat): +class PTConcat(PTDefaultMetaOp, OpConcat): subtypes = [CatMetatype] ConvolutionOp = PTConvolution @@ -334,9 +350,6 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): node.data['input_masks'] = input_masks node.data['output_mask'] = output_mask - ConvolutionOp = PTConvolution - StopMaskForwardOp = PTStopMaskForwardOps - InputOp = PTInput class ModelPruner(MaskPropagationAlgorithm): From 309689f7b0c019be6ae67903ba5e6725ab195da2 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Fri, 1 Oct 2021 17:36:44 +0300 Subject: [PATCH 04/19] WIP export helper tests --- nncf/common/graph/layer_attributes.py | 13 ++ nncf/common/pruning/export_helpers.py | 2 +- nncf/tensorflow/graph/converter.py | 15 ++ nncf/torch/pruning/export_helpers.py | 13 ++ tests/common/pruning/test_export_helpers.py | 198 ++++++++++++++++++++ tests/tensorflow/test_model_converter.py | 29 +++ 6 files changed, 269 insertions(+), 1 deletion(-) create mode 100644 tests/common/pruning/test_export_helpers.py diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py index a46be241d41..f839a2946a2 100644 --- a/nncf/common/graph/layer_attributes.py +++ b/nncf/common/graph/layer_attributes.py @@ -29,6 +29,19 @@ class BaseLayerAttributes(ABC): """ +class MultipleInputLayerAttributes(BaseLayerAttributes): + 
""" + Represents a layer with multiple inputs. + """ + def __init__(self, + axis: int): + self.axis = axis + + def __eq__(self, other): + return isinstance(other, MultipleInputLayerAttributes) \ + and self.axis == other.axis + + class WeightedLayerAttributes(BaseLayerAttributes): """ Represents a layer with weights. diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py index 88c490b01b5..0a1cfd05c23 100644 --- a/nncf/common/pruning/export_helpers.py +++ b/nncf/common/pruning/export_helpers.py @@ -181,7 +181,7 @@ def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: for input_node in graph.get_previous_nodes(node): # If input has mask -> it went from convolution (source of this node is a convolution) - if input_node.data.get('output_mask', None) is None: + if input_node.data.get('output_mask', None) is not None: continue source_nodes = get_sources_of_node(input_node, graph, cls.ConvolutionOp.get_all_op_aliases() + diff --git a/nncf/tensorflow/graph/converter.py b/nncf/tensorflow/graph/converter.py index 4a57f09abb3..25fc7c770f3 100644 --- a/nncf/tensorflow/graph/converter.py +++ b/nncf/tensorflow/graph/converter.py @@ -27,10 +27,12 @@ from nncf.common.graph import NNCFNodeName from nncf.common.graph import OperatorMetatype from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.common.graph.layer_attributes import Dtype from nncf.common.utils.logger import logger as nncf_logger from nncf.tensorflow.graph.metatypes.common import DECONV_LAYER_METATYPES from nncf.tensorflow.graph.metatypes.common import DEPTHWISE_CONV_LAYER_METATYPES +from nncf.tensorflow.graph.metatypes.common import LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS from nncf.tensorflow.graph.metatypes.common import GENERAL_CONV_LAYER_METATYPES from nncf.tensorflow.graph.metatypes.matcher import get_keras_layer_metatype from 
nncf.tensorflow.graph.metatypes.matcher import get_op_metatype @@ -513,6 +515,8 @@ def convert(self) -> NNCFGraph: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=True) elif metatype in GENERAL_CONV_LAYER_METATYPES: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=False) + elif metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS: + layer_attributes = _get_multiple_input_layer_attributes(layer) else: layer_attributes = None is_shared = len(self._layer_name_to_node_names[layer_name]) > 1 @@ -603,6 +607,8 @@ def convert(self) -> NNCFGraph: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=True) elif layer_metatype in GENERAL_CONV_LAYER_METATYPES: layer_attributes = _get_conv_layer_attributes(self._get_layer(layer_name), is_depthwise=False) + elif layer_metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS: + layer_attributes = _get_multiple_input_layer_attributes(model_layer) if layer_attributes is not None: attrs.update({NNCFGraph.LAYER_ATTRIBUTES: layer_attributes}) @@ -650,6 +656,15 @@ def convert(self) -> NNCFGraph: return nncf_graph +def _get_multiple_input_layer_attributes(layer: tf.keras.layers.Layer) -> MultipleInputLayerAttributes: + if hasattr(layer, 'axis'): + axis = layer.axis + else: + #TODO + axis = -1 + return MultipleInputLayerAttributes() + + def _get_conv_layer_attributes(layer: tf.keras.layers.Layer, is_depthwise: bool = False) -> ConvolutionLayerAttributes: channel_axis = get_input_channel_axis(layer) layer_ = unwrap_layer(layer) diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/export_helpers.py index d283ef41e05..2f661f6f944 100644 --- a/nncf/torch/pruning/export_helpers.py +++ b/nncf/torch/pruning/export_helpers.py @@ -74,6 +74,19 @@ class PTDefaultMetaOp(DefaultMetaOp): + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + """ + Propagate 
mask through a node using masks of all inputs and pruning mask of current node (if any). + Should set the following attributes: + input_masks - list of masks of input nodes (None if there is no mask in some input); + output_mask - resulting mask of node operation. + + :param node: Node from NNCF graph to propagate mask through it. + :param graph: Graph of model to prune. + """ + raise NotImplementedError + @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): """ diff --git a/tests/common/pruning/test_export_helpers.py b/tests/common/pruning/test_export_helpers.py new file mode 100644 index 00000000000..de092795ee0 --- /dev/null +++ b/tests/common/pruning/test_export_helpers.py @@ -0,0 +1,198 @@ +import numpy as np +import pytest + +from typing import List + +from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.graph.graph import NNCFGraph +from nncf.common.pruning.export_helpers import( +OpElementwise, +OpConvolution, + OpConcat, +OpStopMaskForwardOps, + +) + + +TEST_CASES = [ + ['flatten', (1, 1, 64), (1,)], + ['flatten', (1, 32, 64), (1,)], + ['reshape', (1, 32, 64), (1,)], # Flatten + ['reshape', (1, 1, 64), (1, 1, 1, 64)], # Expand + ['reshape', (1, 1, 1, 64), (1, 64)], # Squeeze + ['reshape', (1, 1, 1, 64), (1, 1, 64, 1)],# Transpose + ['reshape', (1, 1, 32, 64), (1, 64, 32)],# Transpose + ['reshape', (1, 1, 32, 64), (1, 64, 16, 16)], +] + +REF_ACCEPT_PRUNED = [ + True, + True, + True, + True, + True, + True, + True, + False +] + + +class DummyInputMetatype(OperatorMetatype): + @classmethod + def get_all_aliases(cls) -> List[str]: + return ['input'] + + +class DummyElementwise(OperatorMetatype): + @classmethod + def get_all_aliases(cls) -> List[str]: + return ['elementwise'] + + +class DummyStopPropOp(OperatorMetatype): + @classmethod + def get_all_aliases(cls) -> List[str]: + return ['stop_prop_op'] + + +class 
DummyConvMetatype(OperatorMetatype): + @classmethod + def get_all_aliases(cls) -> List[str]: + return ['conv'] + + +class DummyConcatMetatype(OperatorMetatype): + @classmethod + def get_all_aliases(cls) -> List[str]: + return ['concat'] + + +class DummyOpInput(OpConcat): + additional_types = ['input'] + + +class DummyOpStopMaskForward(OpStopMaskForwardOps): + additional_types = ['stop_prop_op'] + + +class DummyOpConv(OpConvolution): + additional_types = ['conv'] + + +class DummyOpElementwise(OpElementwise): + additional_types = ['elementwise'] + + +class DummyOpConcat(OpConcat): + ConvolutionOp = DummyOpConv + StopMaskForwardOp = DummyOpStopMaskForward + InputOp = DummyOpInput + additional_types = ['concat'] + + +def test_stop_ops_elementwise_source_before_concat(): + graph = NNCFGraph() + stop_op_0 = graph.add_nncf_node('stop_op_0', 'stop_prop_op', DummyStopPropOp) + stop_op_1 = graph.add_nncf_node('stop_op_1', 'stop_prop_op', DummyStopPropOp) + elementwise_node = graph.add_nncf_node('elementwise_node', 'elementwise', DummyElementwise) + concat_node = graph.add_nncf_node('concat_node', 'concat', DummyConcatMetatype) + + # stop_op_0 -> elementwise_node + graph.add_edge_between_nncf_nodes(from_node_id=stop_op_0.node_id, + to_node_id=elementwise_node.node_id, + tensor_shape=[10, 10], + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # stop_op_1 -> elementwise_node + graph.add_edge_between_nncf_nodes(from_node_id=stop_op_1.node_id, + to_node_id=elementwise_node.node_id, + tensor_shape=[10, 10], + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # elementwise_node -> concat_node + graph.add_edge_between_nncf_nodes(from_node_id=elementwise_node.node_id, + to_node_id=concat_node.node_id, + tensor_shape=[10, 10], + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + + assert not DummyOpConcat.check_concat(concat_node, graph) + DummyOpConcat.mask_propagation(concat_node, graph) + assert concat_node.data['output_mask'] is None + + +def 
test_convs_elementwise_source_before_concat(): + graph = NNCFGraph() + conv_op_0 = graph.add_nncf_node('conv_op_0', 'conv', DummyConvMetatype) + conv_op_1 = graph.add_nncf_node('conv_op_1', 'conv', DummyConvMetatype) + elementwise_node = graph.add_nncf_node('elementwise_node', 'elementwise', DummyElementwise) + concat_node = graph.add_nncf_node('concat_node', 'concat', DummyConcatMetatype) + + # conv_op_0 -> elementwise_node + graph.add_edge_between_nncf_nodes(from_node_id=conv_op_0.node_id, + to_node_id=elementwise_node.node_id, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # conv_op_1 -> elementwise_node + graph.add_edge_between_nncf_nodes(from_node_id=conv_op_1.node_id, + to_node_id=elementwise_node.node_id, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # elementwise_node -> concat_node + graph.add_edge_between_nncf_nodes(from_node_id=elementwise_node.node_id, + to_node_id=concat_node.node_id, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + + # Check without masks + assert DummyOpConcat.check_concat(concat_node, graph) + # Set masks + conv_op_0 = graph.get_node_by_id(conv_op_0.node_id) + conv_op_1 = graph.get_node_by_id(conv_op_1.node_id) + elementwise_node = graph.get_node_by_id(elementwise_node.node_id) + conv_op_0.data['output_mask'] = np.ones(10) + conv_op_1.data['output_mask'] = np.ones(10) + # Propagate masks + DummyOpElementwise.mask_propagation(elementwise_node, graph) + # Check with masks + assert DummyOpConcat.check_concat(concat_node, graph) + DummyOpConcat.mask_propagation(concat_node, graph) + reference_mask = [] + assert concat_node.data['output_mask'] is None +#@pytest.mark.parametrize(('node_type', 'input_shape', 'output_shape', 'output_mask', 'output_mask_ref'), +# [input + ref for input, ref in zip(TEST_CASES, REF_OUTPUT_MASK)]) +#def test_reshape_metatype_mask_prop(node_type, input_shape, output_shape, output_mask, 
output_mask_ref): +# node_name = 'dummy_reshape' +# layer_attributes = ReshapeLayerAttributes(input_shape, output_shape) +# +# graph = NNCFGraph() +# prev_node = graph.add_nncf_node('prev_node', 'linear', DummyLinearMetatype) +# reshape_node = graph.add_nncf_node(node_name, node_type, ReshapeMetatype, layer_attributes=layer_attributes) +# +# graph.add_edge_between_nncf_nodes(from_node_id=prev_node.node_id, +# to_node_id=reshape_node.node_id, +# tensor_shape=output_shape, +# input_port_id=0, +# output_port_id=0, +# dtype=Dtype.FLOAT) +# # Get reference to graph node +# prev_node = graph.get_node_by_id(prev_node.node_id) +# reshape_node = graph.get_node_by_id(reshape_node.node_id) +# prev_node.data['output_mask'] = output_mask +# if output_mask_ref == 'error': +# with pytest.raises(AssertionError): +# PTReshape.mask_propagation(reshape_node, graph) +# else: +# PTReshape.mask_propagation(reshape_node, graph) +# assert torch.all(reshape_node.data['output_mask'] == output_mask_ref) +# \ No newline at end of file diff --git a/tests/tensorflow/test_model_converter.py b/tests/tensorflow/test_model_converter.py index e217b2df435..a0b47f1eaac 100644 --- a/tests/tensorflow/test_model_converter.py +++ b/tests/tensorflow/test_model_converter.py @@ -84,3 +84,32 @@ def test_get_custom_layers(): assert len(custom_layers) == 1 assert CustomLayerForTest.CUSTOM_LAYER_NAME in custom_layers assert isinstance(custom_layers[CustomLayerForTest.CUSTOM_LAYER_NAME], CustomLayerForTest) + + +def ModelWithReshapes(): + input =layers.Input((64, )) + x = tf.reshape(input, (32, -1)) + x = layers.Reshape((16, -1))(x) + ones = tf.ones_like(x) + t1 = layers.concatenate([x, ones]) + t2 = tf.concat([x, ones], axis=-1) + y = tf.concat([t1, t2], axis=-1) + return models.Model(input, y, name='ModelWithReshape') + + +def test_model_with_reshape_and_concat(): + model = ModelWithReshapes() + model.build((64,)) + graph = convert_keras_model_to_nncf_graph(model) + ref_reshape_nodes = {'tf_op_layer_Reshape': 
{'input_shape': (None, 64), + 'output_shape': (32, None)}, + 'reshape': {'input_shape': (32, None), + 'output_shape': (32, 8, 8, None)}, + 'flatten': {'input_shape': (32, 8, 8, None), + 'output_shape': (32, None)}} + for node in graph.get_all_nodes(): + if node.metatype in RESHAPE_METATYPES: + assert node.node_name in ref_reshape_nodes + assert node.layer_attributes is not None + assert node.layer_attributes.input_shape == ref_reshape_nodes[node.node_name]['input_shape'] + assert node.layer_attributes.output_shape == ref_reshape_nodes[node.node_name]['output_shape'] From e14a3ca938e4e80263ff00735698eb59214f5b97 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Mon, 4 Oct 2021 12:32:29 +0300 Subject: [PATCH 05/19] Add axis attribute for concat layers --- nncf/common/graph/graph.py | 4 +++ nncf/tensorflow/graph/converter.py | 16 +++++++++--- nncf/torch/dynamic_graph/wrappers.py | 2 +- nncf/torch/graph/graph_builder.py | 22 ++++++++++++++++ nncf/torch/layers.py | 1 + tests/tensorflow/test_model_converter.py | 29 +++++++++++---------- tests/torch/test_graph_building.py | 33 ++++++++++++++++++++++++ 7 files changed, 89 insertions(+), 18 deletions(-) diff --git a/nncf/common/graph/graph.py b/nncf/common/graph/graph.py index f4938df20de..891b50d6c8e 100644 --- a/nncf/common/graph/graph.py +++ b/nncf/common/graph/graph.py @@ -64,6 +64,10 @@ def layer_name(self) -> LayerName: def layer_attributes(self) -> BaseLayerAttributes: return self.data.get(NNCFGraph.LAYER_ATTRIBUTES) + @layer_attributes.setter + def layer_attributes(self, data) -> None: + self.data[NNCFGraph.LAYER_ATTRIBUTES] = data + @property def ignored_algorithms(self) -> List[str]: return self.data.get(NNCFGraph.IGNORED_ALGOS_ATTR, []) diff --git a/nncf/tensorflow/graph/converter.py b/nncf/tensorflow/graph/converter.py index 25fc7c770f3..145b93d8246 100644 --- a/nncf/tensorflow/graph/converter.py +++ b/nncf/tensorflow/graph/converter.py @@ -660,9 +660,19 @@ def _get_multiple_input_layer_attributes(layer: 
tf.keras.layers.Layer) -> Multip if hasattr(layer, 'axis'): axis = layer.axis else: - #TODO - axis = -1 - return MultipleInputLayerAttributes() + input_shape = layer.input_shape + output_shape = layer.output_shape + axis = None + # If it's dummy concat of one tensor + if len(input_shape) == 1: + axis = -1 + for idx, (dim_in, dim_out) in enumerate(zip(input_shape[0], output_shape[0])): + if dim_in is None or dim_in != dim_out: + axis = idx + break + if axis is None: + raise RuntimeError('Unexpected behaviour for concat op') + return MultipleInputLayerAttributes(axis) def _get_conv_layer_attributes(layer: tf.keras.layers.Layer, is_depthwise: bool = False) -> ConvolutionLayerAttributes: diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index 238020122b6..1b0e02f2ea1 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -55,7 +55,7 @@ def ignore_scope(cls): return cls -OP_NAMES_REQUIRING_MODULE_ATTRS = [v.op_func_name for v in NNCF_MODULES_DICT] + ["group_norm"] +OP_NAMES_REQUIRING_MODULE_ATTRS = [v.op_func_name for v in NNCF_MODULES_DICT] + ['group_norm'] def wrap_operator(operator, operator_info: 'PatchedOperatorInfo'): diff --git a/nncf/torch/graph/graph_builder.py b/nncf/torch/graph/graph_builder.py index c64b5cbf3ed..852a31f96b6 100644 --- a/nncf/torch/graph/graph_builder.py +++ b/nncf/torch/graph/graph_builder.py @@ -22,10 +22,12 @@ from nncf.common.graph import INPUT_NOOP_METATYPES from nncf.common.graph import LayerName from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.torch.dynamic_graph.graph import DynamicGraph from nncf.torch.dynamic_graph.graph_tracer import GraphTracer from nncf.torch.dynamic_graph.graph_tracer import ModelInputInfo from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.operator_metatypes import CatMetatype from nncf.torch.graph.operator_metatypes import 
PT_OPERATOR_METATYPES @@ -91,4 +93,24 @@ def convert(dynamic_graph: DynamicGraph, input_infos: List[ModelInputInfo] = Non output_port_id=dynamic_graph_edge.output_port_id, dtype=Dtype.FLOAT ) + + for node in nncf_graph.get_all_nodes(): + if node.metatype is CatMetatype: + input_edges = nncf_graph.get_input_edges(node) + output_edges = nncf_graph.get_output_edges(node) + # In case is intermediate node + if input_edges and output_edges: + axis = None + if len(input_edges) == 1: + axis = -1 + input_shape = input_edges[0].tensor_shape + output_shape = output_edges[0].tensor_shape + for idx, (dim_in, dim_out) in enumerate(zip(input_shape, output_shape)): + if dim_in is None or dim_in != dim_out: + axis = idx + break + if axis is None: + raise RuntimeError('Unexpected behaviour for concat op') + layer_attributes = MultipleInputLayerAttributes(axis) + node.layer_attributes = layer_attributes return nncf_graph diff --git a/nncf/torch/layers.py b/nncf/torch/layers.py index 47cb79de46a..fc27f93edb5 100644 --- a/nncf/torch/layers.py +++ b/nncf/torch/layers.py @@ -201,6 +201,7 @@ def from_module(module): dict_update(nncf_embedding_bag.__dict__, module.__dict__) return nncf_embedding_bag + NNCF_MODULES_DICT = { NNCFConv1d: nn.Conv1d, NNCFConv2d: nn.Conv2d, diff --git a/tests/tensorflow/test_model_converter.py b/tests/tensorflow/test_model_converter.py index a0b47f1eaac..2264c180ee7 100644 --- a/tests/tensorflow/test_model_converter.py +++ b/tests/tensorflow/test_model_converter.py @@ -11,6 +11,7 @@ limitations under the License. 
""" +import pytest import tensorflow as tf from tensorflow.python.keras import backend from tensorflow.python.keras import layers @@ -19,8 +20,10 @@ from nncf.common.graph import INPUT_NOOP_METATYPES from nncf.common.graph import OUTPUT_NOOP_METATYPES +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.tensorflow.graph.converter import TFModelConverter from nncf.tensorflow.graph.converter import convert_keras_model_to_nncf_graph +from nncf.tensorflow.graph.metatypes.common import LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS from tests.tensorflow.helpers import get_basic_conv_test_model from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test from tests.tensorflow.quantization.test_algorithm_quantization import get_basic_quantization_config @@ -86,8 +89,8 @@ def test_get_custom_layers(): assert isinstance(custom_layers[CustomLayerForTest.CUSTOM_LAYER_NAME], CustomLayerForTest) -def ModelWithReshapes(): - input =layers.Input((64, )) +def ModelWithReshapesAndConcats(batch_size=None): + input =layers.Input((64, ), batch_size=batch_size) x = tf.reshape(input, (32, -1)) x = layers.Reshape((16, -1))(x) ones = tf.ones_like(x) @@ -97,19 +100,17 @@ def ModelWithReshapes(): return models.Model(input, y, name='ModelWithReshape') -def test_model_with_reshape_and_concat(): - model = ModelWithReshapes() +@pytest.mark.parametrize('batch_size', [None, 8], ids=['no_batch_size', 'with_batch_size']) +def test_model_with_reshape_and_concat(batch_size): + model = ModelWithReshapesAndConcats(batch_size) model.build((64,)) graph = convert_keras_model_to_nncf_graph(model) - ref_reshape_nodes = {'tf_op_layer_Reshape': {'input_shape': (None, 64), - 'output_shape': (32, None)}, - 'reshape': {'input_shape': (32, None), - 'output_shape': (32, 8, 8, None)}, - 'flatten': {'input_shape': (32, 8, 8, None), - 'output_shape': (32, None)}} + ref_concat_nodes = {'concatenate': {'axis': [-1, 2]}, + 'tf_op_layer_concat': {'axis': 
[-1, 2]}, + 'tf_op_layer_concat_1': {'axis': [-1, 2]}} for node in graph.get_all_nodes(): - if node.metatype in RESHAPE_METATYPES: - assert node.node_name in ref_reshape_nodes + if node.metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS: + assert node.node_name in ref_concat_nodes assert node.layer_attributes is not None - assert node.layer_attributes.input_shape == ref_reshape_nodes[node.node_name]['input_shape'] - assert node.layer_attributes.output_shape == ref_reshape_nodes[node.node_name]['output_shape'] + assert isinstance(node.layer_attributes, MultipleInputLayerAttributes) + assert node.layer_attributes.axis in ref_concat_nodes[node.node_name]['axis'] diff --git a/tests/torch/test_graph_building.py b/tests/torch/test_graph_building.py index 4eff40b434f..51e401f9e83 100644 --- a/tests/torch/test_graph_building.py +++ b/tests/torch/test_graph_building.py @@ -15,6 +15,7 @@ from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME from nncf.common.graph.definitions import MODEL_OUTPUT_OP_NAME from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from typing import List from typing import Tuple @@ -201,6 +202,38 @@ def test_activation_shape_tracing(input_shape: Tuple): assert output_tensor_shapes == ref_output_shapes, "Failed for node ID: {}".format(node_id) +class ModelForTestWithReshapeFlattenAndConcat(ModelForTest): + def forward(self, x): + y = super().forward(x) + size = y.size() + y = y.view(size + (1, 1)) + y_copy = torch.ones_like(y) + y = torch.cat([y, y_copy], -1) + y = torch.flatten(y) + _ = y.view(-1) + y_copy = torch.ones_like(y) + y = torch.cat([y, y_copy], -1) + return y + + +@pytest.mark.parametrize("input_shape", input_shapes) +def test_concat_attributes_saved_during_graph_building(input_shape): + model = ModelForTestWithReshapeFlattenAndConcat() + input_info = ModelInputInfo(input_shape) + graph_builder = 
GraphBuilder(create_dummy_forward_fn([input_info, ], with_input_tracing=True, + with_output_tracing=True)) + graph = graph_builder.build_graph(model) + reshape_nodes_with_attributes = { + 'ModelForTestWithReshapeFlattenAndConcat/cat_0': {'axis': 1}, + 'ModelForTestWithReshapeFlattenAndConcat/cat_1': {'axis': 5}, + 'ModelForTestWithReshapeFlattenAndConcat/cat_2': {'axis': 0} + } + for name, shapes in reshape_nodes_with_attributes.items(): + node = graph.get_node_by_name(name) + assert isinstance(node.layer_attributes, MultipleInputLayerAttributes) + assert node.layer_attributes.axis == shapes['axis'] + + TEST_KEYWORD_1 = "keyword1" TEST_KEYWORD_2 = "keyword2" INPUT_INFO_CONFIG_VS_FORWARD_ARGS = [ From c0459af26c9c8268fd401bc5ad1bdfdf92879239 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Mon, 4 Oct 2021 13:36:56 +0300 Subject: [PATCH 06/19] Add concat tests for common implementation --- nncf/common/pruning/export_helpers.py | 5 +- tests/common/pruning/test_export_helpers.py | 198 +++++++++----------- 2 files changed, 90 insertions(+), 113 deletions(-) diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py index 0a1cfd05c23..ca87511f944 100644 --- a/nncf/common/pruning/export_helpers.py +++ b/nncf/common/pruning/export_helpers.py @@ -187,7 +187,7 @@ def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: source_nodes = get_sources_of_node(input_node, graph, cls.ConvolutionOp.get_all_op_aliases() + cls.StopMaskForwardOp.get_all_op_aliases() + cls.InputOp.get_all_op_aliases()) - sources_types = [node.node_type for node in source_nodes] + sources_types = [node.node_type for node in source_nodes] + [input_node.node_type] if any(t in sources_types for t in cls.StopMaskForwardOp.get_all_op_aliases()): return False return True @@ -213,7 +213,8 @@ def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[np.arra filled_input_masks = [] for i, mask in enumerate(input_masks): if mask is None: - mask = 
np.ones(input_edges[i].tensor_shape[-1]) + concat_axis = node.layer_attributes.axis + mask = np.ones(input_edges[i].tensor_shape[concat_axis]) filled_input_masks.append(mask) result_mask = np.concatenate(filled_input_masks, 0) return result_mask diff --git a/tests/common/pruning/test_export_helpers.py b/tests/common/pruning/test_export_helpers.py index de092795ee0..46566903cb4 100644 --- a/tests/common/pruning/test_export_helpers.py +++ b/tests/common/pruning/test_export_helpers.py @@ -2,42 +2,22 @@ import pytest from typing import List +from functools import partial from nncf.common.graph.layer_attributes import Dtype from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.common.graph.graph import NNCFGraph -from nncf.common.pruning.export_helpers import( -OpElementwise, -OpConvolution, - OpConcat, -OpStopMaskForwardOps, - +from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm +from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry +from nncf.common.pruning.export_helpers import ( + OpElementwise, + OpConvolution, + OpConcat, + OpStopMaskForwardOps, ) -TEST_CASES = [ - ['flatten', (1, 1, 64), (1,)], - ['flatten', (1, 32, 64), (1,)], - ['reshape', (1, 32, 64), (1,)], # Flatten - ['reshape', (1, 1, 64), (1, 1, 1, 64)], # Expand - ['reshape', (1, 1, 1, 64), (1, 64)], # Squeeze - ['reshape', (1, 1, 1, 64), (1, 1, 64, 1)],# Transpose - ['reshape', (1, 1, 32, 64), (1, 64, 32)],# Transpose - ['reshape', (1, 1, 32, 64), (1, 64, 16, 16)], -] - -REF_ACCEPT_PRUNED = [ - True, - True, - True, - True, - True, - True, - True, - False -] - - class DummyInputMetatype(OperatorMetatype): @classmethod def get_all_aliases(cls) -> List[str]: @@ -68,22 +48,30 @@ def get_all_aliases(cls) -> List[str]: return ['concat'] +DUMMY_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") + + 
+@DUMMY_PRUNING_OPERATOR_METATYPES.register('input') class DummyOpInput(OpConcat): additional_types = ['input'] +@DUMMY_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') class DummyOpStopMaskForward(OpStopMaskForwardOps): additional_types = ['stop_prop_op'] +@DUMMY_PRUNING_OPERATOR_METATYPES.register('conv') class DummyOpConv(OpConvolution): additional_types = ['conv'] +@DUMMY_PRUNING_OPERATOR_METATYPES.register('elementwise') class DummyOpElementwise(OpElementwise): additional_types = ['elementwise'] +@DUMMY_PRUNING_OPERATOR_METATYPES.register('concat') class DummyOpConcat(OpConcat): ConvolutionOp = DummyOpConv StopMaskForwardOp = DummyOpStopMaskForward @@ -91,108 +79,96 @@ class DummyOpConcat(OpConcat): additional_types = ['concat'] -def test_stop_ops_elementwise_source_before_concat(): +@pytest.mark.parametrize('with_elementwise', [False, True]) +def test_stop_ops_elementwise_source_before_concat(with_elementwise): graph = NNCFGraph() stop_op_0 = graph.add_nncf_node('stop_op_0', 'stop_prop_op', DummyStopPropOp) stop_op_1 = graph.add_nncf_node('stop_op_1', 'stop_prop_op', DummyStopPropOp) - elementwise_node = graph.add_nncf_node('elementwise_node', 'elementwise', DummyElementwise) - concat_node = graph.add_nncf_node('concat_node', 'concat', DummyConcatMetatype) - - # stop_op_0 -> elementwise_node - graph.add_edge_between_nncf_nodes(from_node_id=stop_op_0.node_id, - to_node_id=elementwise_node.node_id, - tensor_shape=[10, 10], - input_port_id=0, - output_port_id=0, - dtype=Dtype.FLOAT) - # stop_op_1 -> elementwise_node - graph.add_edge_between_nncf_nodes(from_node_id=stop_op_1.node_id, - to_node_id=elementwise_node.node_id, - tensor_shape=[10, 10], - input_port_id=0, - output_port_id=0, - dtype=Dtype.FLOAT) - # elementwise_node -> concat_node - graph.add_edge_between_nncf_nodes(from_node_id=elementwise_node.node_id, - to_node_id=concat_node.node_id, - tensor_shape=[10, 10], - input_port_id=0, - output_port_id=0, - dtype=Dtype.FLOAT) + 
concat_layer_attributes = MultipleInputLayerAttributes(-1) + concat_node = graph.add_nncf_node('concat_node', 'concat', DummyConcatMetatype, + layer_attributes=concat_layer_attributes) + add_node = partial(graph.add_edge_between_nncf_nodes, + tensor_shape=[10, 10], + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + + if not with_elementwise: + # stop_op_0 -> concat_node + add_node(from_node_id=stop_op_0.node_id, + to_node_id=concat_node.node_id) + + # stop_op_1 -> concat_node + add_node(from_node_id=stop_op_1.node_id, + to_node_id=concat_node.node_id) + else: + elementwise_op = graph.add_nncf_node('elementwise', 'elementwise', DummyElementwise) + + # stop_op_0 -> elementwise + add_node(from_node_id=stop_op_0.node_id, + to_node_id=elementwise_op.node_id) + + # stop_op_1 -> elementwise + add_node(from_node_id=stop_op_1.node_id, + to_node_id=elementwise_op.node_id) + + # elementwise -> concat + add_node(from_node_id=elementwise_op.node_id, + to_node_id=concat_node.node_id) assert not DummyOpConcat.check_concat(concat_node, graph) - DummyOpConcat.mask_propagation(concat_node, graph) + MaskPropagationAlgorithm(graph, DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + concat_node = graph.get_node_by_id(concat_node.node_id) assert concat_node.data['output_mask'] is None -def test_convs_elementwise_source_before_concat(): +@pytest.mark.parametrize('empty_mask_branch', [False, True]) +def test_convs_elementwise_source_before_concat(empty_mask_branch): graph = NNCFGraph() conv_op_0 = graph.add_nncf_node('conv_op_0', 'conv', DummyConvMetatype) conv_op_1 = graph.add_nncf_node('conv_op_1', 'conv', DummyConvMetatype) + conv_op_2 = graph.add_nncf_node('conv_op_2', 'conv', DummyConvMetatype) elementwise_node = graph.add_nncf_node('elementwise_node', 'elementwise', DummyElementwise) - concat_node = graph.add_nncf_node('concat_node', 'concat', DummyConcatMetatype) + concat_layer_attributes = MultipleInputLayerAttributes(2) + concat_node = 
graph.add_nncf_node('concat_node', 'concat', DummyConcatMetatype, + layer_attributes=concat_layer_attributes) + add_node = partial(graph.add_edge_between_nncf_nodes, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) # conv_op_0 -> elementwise_node - graph.add_edge_between_nncf_nodes(from_node_id=conv_op_0.node_id, - to_node_id=elementwise_node.node_id, - tensor_shape=[10] * 4, - input_port_id=0, - output_port_id=0, - dtype=Dtype.FLOAT) + add_node(from_node_id=conv_op_0.node_id, + to_node_id=elementwise_node.node_id) + # conv_op_1 -> elementwise_node - graph.add_edge_between_nncf_nodes(from_node_id=conv_op_1.node_id, - to_node_id=elementwise_node.node_id, - tensor_shape=[10] * 4, - input_port_id=0, - output_port_id=0, - dtype=Dtype.FLOAT) + add_node(from_node_id=conv_op_1.node_id, + to_node_id=elementwise_node.node_id) + # elementwise_node -> concat_node - graph.add_edge_between_nncf_nodes(from_node_id=elementwise_node.node_id, - to_node_id=concat_node.node_id, - tensor_shape=[10] * 4, - input_port_id=0, - output_port_id=0, - dtype=Dtype.FLOAT) + add_node(from_node_id=elementwise_node.node_id, + to_node_id=concat_node.node_id) + + # conv_op_2 -> concat_node + add_node(from_node_id=conv_op_2.node_id, + to_node_id=concat_node.node_id) # Check without masks assert DummyOpConcat.check_concat(concat_node, graph) # Set masks - conv_op_0 = graph.get_node_by_id(conv_op_0.node_id) - conv_op_1 = graph.get_node_by_id(conv_op_1.node_id) - elementwise_node = graph.get_node_by_id(elementwise_node.node_id) - conv_op_0.data['output_mask'] = np.ones(10) - conv_op_1.data['output_mask'] = np.ones(10) + masked_convs = [conv_op_0, conv_op_1] + if not empty_mask_branch: + masked_convs.append(conv_op_2) + + for conv_op in masked_convs: + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = np.ones(10) + # Propagate masks - DummyOpElementwise.mask_propagation(elementwise_node, graph) + MaskPropagationAlgorithm(graph, 
DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() # Check with masks + concat_node = graph.get_node_by_id(concat_node.node_id) assert DummyOpConcat.check_concat(concat_node, graph) - DummyOpConcat.mask_propagation(concat_node, graph) - reference_mask = [] - assert concat_node.data['output_mask'] is None -#@pytest.mark.parametrize(('node_type', 'input_shape', 'output_shape', 'output_mask', 'output_mask_ref'), -# [input + ref for input, ref in zip(TEST_CASES, REF_OUTPUT_MASK)]) -#def test_reshape_metatype_mask_prop(node_type, input_shape, output_shape, output_mask, output_mask_ref): -# node_name = 'dummy_reshape' -# layer_attributes = ReshapeLayerAttributes(input_shape, output_shape) -# -# graph = NNCFGraph() -# prev_node = graph.add_nncf_node('prev_node', 'linear', DummyLinearMetatype) -# reshape_node = graph.add_nncf_node(node_name, node_type, ReshapeMetatype, layer_attributes=layer_attributes) -# -# graph.add_edge_between_nncf_nodes(from_node_id=prev_node.node_id, -# to_node_id=reshape_node.node_id, -# tensor_shape=output_shape, -# input_port_id=0, -# output_port_id=0, -# dtype=Dtype.FLOAT) -# # Get reference to graph node -# prev_node = graph.get_node_by_id(prev_node.node_id) -# reshape_node = graph.get_node_by_id(reshape_node.node_id) -# prev_node.data['output_mask'] = output_mask -# if output_mask_ref == 'error': -# with pytest.raises(AssertionError): -# PTReshape.mask_propagation(reshape_node, graph) -# else: -# PTReshape.mask_propagation(reshape_node, graph) -# assert torch.all(reshape_node.data['output_mask'] == output_mask_ref) -# \ No newline at end of file + reference_mask = np.ones((20,)) + np.testing.assert_equal(concat_node.data['output_mask'], reference_mask) From 27c3ba29066f4c42b8feb0d96aaa7943f4f9c9c2 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Mon, 4 Oct 2021 19:07:16 +0300 Subject: [PATCH 07/19] common export_helpers tests --- tests/common/pruning/dummy_types.py | 120 ++++++++++ tests/common/pruning/test_export_helpers.py | 249 
++++++++++++++------ 2 files changed, 302 insertions(+), 67 deletions(-) create mode 100644 tests/common/pruning/dummy_types.py diff --git a/tests/common/pruning/dummy_types.py b/tests/common/pruning/dummy_types.py new file mode 100644 index 00000000000..2eff563a64a --- /dev/null +++ b/tests/common/pruning/dummy_types.py @@ -0,0 +1,120 @@ +from typing import List + +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry +from nncf.common.pruning.export_helpers import ( + OpInput, + OpOutput, + OpIdentityMaskForwardOps, + OpConvolution, + OpTransposeConvolution, + OpBatchNorm, + OpGroupNorm, + OpConcat, + OpElementwise, + OpStopMaskForwardOps, +) + + +class DummyDefaultMetatype(OperatorMetatype): + name = None + + @classmethod + def get_all_aliases(cls) -> List[str]: + return [cls.name] + + +class DummyInputMetatype(OperatorMetatype): + name = 'input' + + +class DummyOutputMetatype(OperatorMetatype): + name = 'output' + + +class DymmyIdentityMaskForwardMetatype(OperatorMetatype): + name = 'identity_mask_forward' + + +class DummyElementwiseMetatype(OperatorMetatype): + name = 'elementwise' + + +class DummyConvMetatype(OperatorMetatype): + name = 'conv' + + +class DummyTransposeConvolutionMetatype(OperatorMetatype): + name = 'transpose_conv' + + +class DummyBatchNormMetatype(OperatorMetatype): + name = 'batch_norm' + + +class DummyGroupNormMetatype(OperatorMetatype): + name = 'group_norm' + + +class DummyConcatMetatype(OperatorMetatype): + name = 'concat' + + +class DummyStopPropoagtionMetatype(OperatorMetatype): + name = 'stop_propagation_ops' + + +DUMMY_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyInputMetatype.name) +class DummyOpInput(OpInput): + additional_types = [DummyInputMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyOutputMetatype.name) +class 
DummyOpOutput(OpOutput): + additional_types = [DummyOutputMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DymmyIdentityMaskForwardMetatype.name) +class DummyOpIdentityMaskForward(OpIdentityMaskForwardOps): + additional_types = [DymmyIdentityMaskForwardMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyStopPropoagtionMetatype.name) +class DummyOpStopMaskForward(OpStopMaskForwardOps): + additional_types = [DummyStopPropoagtionMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyConvMetatype.name) +class DummyOpConv(OpConvolution): + additional_types = [DummyConvMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyTransposeConvolutionMetatype.name) +class DummyOpTransposeConv(OpTransposeConvolution): + additional_types = [DummyTransposeConvolutionMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyBatchNormMetatype.name) +class DummyOpBatchNorm(OpBatchNorm): + additional_types = [DummyBatchNormMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyGroupNormMetatype.name) +class DummyOpGroupNorm(OpGroupNorm): + additional_types = [DummyGroupNormMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyElementwiseMetatype.name) +class DummyOpElementwise(OpElementwise): + additional_types = [DummyElementwiseMetatype.name] + + +@DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyConcatMetatype.name) +class DummyOpConcat(OpConcat): + ConvolutionOp = DummyOpConv + StopMaskForwardOp = DummyOpStopMaskForward + InputOp = DummyOpInput + additional_types = [DummyConcatMetatype.name] diff --git a/tests/common/pruning/test_export_helpers.py b/tests/common/pruning/test_export_helpers.py index 46566903cb4..ddb0f6a8380 100644 --- a/tests/common/pruning/test_export_helpers.py +++ b/tests/common/pruning/test_export_helpers.py @@ -1,91 +1,206 @@ import numpy as np import pytest -from typing import List from functools import partial +import tests.common.pruning.dummy_types as dummy_types + 
+from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.layer_attributes import Dtype -from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes -from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.layer_attributes import GroupNormLayerAttributes +from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes +from nncf.common.pruning.export_helpers import DefaultMetaOp from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm -from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry -from nncf.common.pruning.export_helpers import ( - OpElementwise, - OpConvolution, - OpConcat, - OpStopMaskForwardOps, -) -class DummyInputMetatype(OperatorMetatype): - @classmethod - def get_all_aliases(cls) -> List[str]: - return ['input'] - - -class DummyElementwise(OperatorMetatype): - @classmethod - def get_all_aliases(cls) -> List[str]: - return ['elementwise'] +@pytest.mark.parametrize('dummy_op_class,accept_pruned_input', [(dummy_types.DummyOpInput, False), + (dummy_types.DummyOpOutput, True), + (dummy_types.DummyOpStopMaskForward, False)]) +def test_stop_propagate_ops(dummy_op_class, accept_pruned_input): + node = NNCFNode(0, 'dummy_node') + assert dummy_op_class.accept_pruned_input(node) == accept_pruned_input + dummy_op_class.mask_propagation(node, None) + assert node.data['output_mask'] is None -class DummyStopPropOp(OperatorMetatype): - @classmethod - def get_all_aliases(cls) -> List[str]: - return ['stop_prop_op'] - - -class DummyConvMetatype(OperatorMetatype): - @classmethod - def get_all_aliases(cls) -> List[str]: - return ['conv'] +@pytest.mark.parametrize('dummy_op_class', [dummy_types.DummyOpIdentityMaskForward, dummy_types.DummyOpBatchNorm]) +def test_identity_mask_propogation_prune_ops(dummy_op_class): + assert dummy_op_class.accept_pruned_input(None) + graph = 
NNCFGraph() + conv_op = graph.add_nncf_node('conv_op', 'conv', dummy_types.DummyConvMetatype) + identity_ops = [] + for alias in dummy_op_class.get_all_op_aliases(): + identity_op = graph.add_nncf_node('identity', alias, dummy_types.DymmyIdentityMaskForwardMetatype) + graph.add_edge_between_nncf_nodes( + from_node_id=conv_op.node_id, + to_node_id=identity_op.node_id, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + identity_ops.append(identity_op) + # Check with and without masks + for output_mask in [None, np.ones((10,))]: + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = output_mask + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + for identity_op in identity_ops: + identity_op = graph.get_node_by_id(identity_op.node_id) + assert np.all(identity_op.data['output_mask'] == output_mask) -class DummyConcatMetatype(OperatorMetatype): - @classmethod - def get_all_aliases(cls) -> List[str]: - return ['concat'] +@pytest.mark.parametrize('valid_masks', [None, True, False]) +def test_elementwise_prune_ops(valid_masks): + graph = NNCFGraph() + conv_op_0 = graph.add_nncf_node('conv_op_0', dummy_types.DummyConvMetatype.name, dummy_types.DummyConvMetatype) + conv_op_1 = graph.add_nncf_node('conv_op_1', dummy_types.DummyConvMetatype.name, dummy_types.DummyConvMetatype) + elementwise_op = graph.add_nncf_node('elementwise', dummy_types.DummyElementwiseMetatype.name, + dummy_types.DummyElementwiseMetatype) + add_node = partial(graph.add_edge_between_nncf_nodes, + tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # conv_op_0 -> elementwise + add_node(from_node_id=conv_op_0.node_id, + to_node_id=elementwise_op.node_id) + # conv_op_1 -> elementwise + add_node(from_node_id=conv_op_1.node_id, + to_node_id=elementwise_op.node_id) -DUMMY_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") + masks = 
[np.ones((10,)), np.ones((10,))] if valid_masks is not None else None + def set_masks(masks, ops): + for conv_op, mask in zip(ops, masks): + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = mask -@DUMMY_PRUNING_OPERATOR_METATYPES.register('input') -class DummyOpInput(OpConcat): - additional_types = ['input'] + if valid_masks is None or valid_masks: + if valid_masks: + set_masks(masks, [conv_op_0, conv_op_1]) + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + elementwise_op = graph.get_node_by_id(elementwise_op.node_id) + assert np.all(elementwise_op.data['output_mask'] == masks) + else: + def check_wrong_masks(masks): + with pytest.raises(AssertionError): + set_masks(masks, [conv_op_0, conv_op_1]) + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + masks[0][0] = 0 + check_wrong_masks(masks) + masks[0] = np.concatenate([masks[1], np.array([1])]) + check_wrong_masks(masks) -@DUMMY_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class DummyOpStopMaskForward(OpStopMaskForwardOps): - additional_types = ['stop_prop_op'] +@pytest.mark.parametrize('num_channels,num_groups,accept_pruned_input_ref', [(10, 10, True), + (10, 5, False), + (10, 1, False)]) +def test_group_norm_pruning_ops(num_channels, num_groups, accept_pruned_input_ref): + graph = NNCFGraph() + conv_op = graph.add_nncf_node('conv_op', 'conv', dummy_types.DummyConvMetatype) + group_norm_layer_attributes = GroupNormLayerAttributes(True, num_channels=num_channels, + num_groups=num_groups) + group_norm_op = graph.add_nncf_node('identity', dummy_types.DummyGroupNormMetatype.name, + dummy_types.DummyGroupNormMetatype, + layer_attributes=group_norm_layer_attributes) + assert dummy_types.DummyOpGroupNorm.accept_pruned_input(group_norm_op) == accept_pruned_input_ref + graph.add_edge_between_nncf_nodes( + from_node_id=conv_op.node_id, + to_node_id=group_norm_op.node_id, + 
tensor_shape=[10] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + # Check with and without masks + for output_mask in [None, np.ones((10,))]: + conv_op = graph.get_node_by_id(conv_op.node_id) + conv_op.data['output_mask'] = output_mask + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + identity_op = graph.get_node_by_id(group_norm_op.node_id) + assert np.all(identity_op.data['output_mask'] == output_mask) -@DUMMY_PRUNING_OPERATOR_METATYPES.register('conv') -class DummyOpConv(OpConvolution): - additional_types = ['conv'] +class DummyMaskProducerMetatype(dummy_types.DummyDefaultMetatype): + name = 'mask_producer' -@DUMMY_PRUNING_OPERATOR_METATYPES.register('elementwise') -class DummyOpElementwise(OpElementwise): - additional_types = ['elementwise'] +@dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyMaskProducerMetatype.name) +class MockOpMaskProducer(DefaultMetaOp): + additional_types = [DummyMaskProducerMetatype.name] + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + pass + + +@pytest.mark.parametrize('transpose', [True, False], ids=['transpose', 'not_transpose']) +@pytest.mark.parametrize('layer_attributes,ref_accept_pruned_input,conv_type', [ + ({'in_channels': 5, + 'out_channels': 10, + 'groups': 1}, True, 'usual_conv'), + ({'in_channels': 10, + 'out_channels': 20, + 'groups': 5}, False, 'grouped_conv_no_depthwise'), + ({'in_channels': 10, + 'out_channels': 20, + 'groups': 10}, True, 'depthwise_conv') +], + ids=['usual_conv', + 'grouped_conv_no_depthwise', + 'depthwise_conv'] +) +def test_conv_pruning_ops(transpose, layer_attributes, ref_accept_pruned_input, conv_type): + default_conv_params = { + 'weight_requires_grad': True, + 'kernel_size': (2, 2), + 'stride': (1, 1), + 'padding_values': [0, 0] + } + graph = NNCFGraph() + dummy_op_before = graph.add_nncf_node('dummy_op_before', DummyMaskProducerMetatype.name, + DummyMaskProducerMetatype) + 
target_conv_attributes = ConvolutionLayerAttributes(transpose=transpose, **layer_attributes, **default_conv_params) + conv_op_target = graph.add_nncf_node('conv_op_target', dummy_types.DummyConvMetatype.name, + dummy_types.DummyConvMetatype, + layer_attributes=target_conv_attributes) + graph.add_edge_between_nncf_nodes(from_node_id=dummy_op_before.node_id, + to_node_id=conv_op_target.node_id, + tensor_shape=[layer_attributes['in_channels']] * 4, + input_port_id=0, + output_port_id=0, + dtype=Dtype.FLOAT) + pruning_op_class = dummy_types.DummyOpTransposeConv if transpose else dummy_types.DummyOpConv + assert pruning_op_class.accept_pruned_input(conv_op_target) == ref_accept_pruned_input + ones_input_mask = np.ones((layer_attributes['in_channels'],)) + ones_output_mask = np.ones((layer_attributes['out_channels'],)) + for input_mask in [None, ones_input_mask]: + for output_mask in [None, ones_output_mask]: + dummy_op_before = graph.get_node_by_id(dummy_op_before.node_id) + conv_op_target = graph.get_node_by_id(conv_op_target.node_id) + dummy_op_before.data['output_mask'] = input_mask + conv_op_target.data['output_mask'] = output_mask + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + dummy_op_before = graph.get_node_by_id(dummy_op_before.node_id) + conv_op_target = graph.get_node_by_id(conv_op_target.node_id) + if conv_type == 'usual_conv': + assert np.all(conv_op_target.data['output_mask'] == output_mask) + elif conv_type == 'grouped_conv_no_depthwise': + assert conv_op_target.data['output_mask'] is None + else: + assert np.all(conv_op_target.data['output_mask'] == input_mask) -@DUMMY_PRUNING_OPERATOR_METATYPES.register('concat') -class DummyOpConcat(OpConcat): - ConvolutionOp = DummyOpConv - StopMaskForwardOp = DummyOpStopMaskForward - InputOp = DummyOpInput - additional_types = ['concat'] @pytest.mark.parametrize('with_elementwise', [False, True]) def test_stop_ops_elementwise_source_before_concat(with_elementwise): 
graph = NNCFGraph() - stop_op_0 = graph.add_nncf_node('stop_op_0', 'stop_prop_op', DummyStopPropOp) - stop_op_1 = graph.add_nncf_node('stop_op_1', 'stop_prop_op', DummyStopPropOp) + stop_op_0 = graph.add_nncf_node('stop_op_0', 'stop_propagation_ops', dummy_types.DummyStopPropoagtionMetatype) + stop_op_1 = graph.add_nncf_node('stop_op_1', 'stop_propagation_ops', dummy_types.DummyStopPropoagtionMetatype) concat_layer_attributes = MultipleInputLayerAttributes(-1) - concat_node = graph.add_nncf_node('concat_node', 'concat', DummyConcatMetatype, + concat_node = graph.add_nncf_node('concat_node', 'concat', dummy_types.DummyConcatMetatype, layer_attributes=concat_layer_attributes) add_node = partial(graph.add_edge_between_nncf_nodes, tensor_shape=[10, 10], @@ -102,7 +217,7 @@ def test_stop_ops_elementwise_source_before_concat(with_elementwise): add_node(from_node_id=stop_op_1.node_id, to_node_id=concat_node.node_id) else: - elementwise_op = graph.add_nncf_node('elementwise', 'elementwise', DummyElementwise) + elementwise_op = graph.add_nncf_node('elementwise', 'elementwise', dummy_types.DummyElementwiseMetatype) # stop_op_0 -> elementwise add_node(from_node_id=stop_op_0.node_id, @@ -116,8 +231,8 @@ def test_stop_ops_elementwise_source_before_concat(with_elementwise): add_node(from_node_id=elementwise_op.node_id, to_node_id=concat_node.node_id) - assert not DummyOpConcat.check_concat(concat_node, graph) - MaskPropagationAlgorithm(graph, DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + assert not dummy_types.DummyOpConcat.check_concat(concat_node, graph) + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() concat_node = graph.get_node_by_id(concat_node.node_id) assert concat_node.data['output_mask'] is None @@ -125,12 +240,12 @@ def test_stop_ops_elementwise_source_before_concat(with_elementwise): @pytest.mark.parametrize('empty_mask_branch', [False, True]) def 
test_convs_elementwise_source_before_concat(empty_mask_branch): graph = NNCFGraph() - conv_op_0 = graph.add_nncf_node('conv_op_0', 'conv', DummyConvMetatype) - conv_op_1 = graph.add_nncf_node('conv_op_1', 'conv', DummyConvMetatype) - conv_op_2 = graph.add_nncf_node('conv_op_2', 'conv', DummyConvMetatype) - elementwise_node = graph.add_nncf_node('elementwise_node', 'elementwise', DummyElementwise) + conv_op_0 = graph.add_nncf_node('conv_op_0', 'conv', dummy_types.DummyConvMetatype) + conv_op_1 = graph.add_nncf_node('conv_op_1', 'conv', dummy_types.DummyConvMetatype) + conv_op_2 = graph.add_nncf_node('conv_op_2', 'conv', dummy_types.DummyConvMetatype) + elementwise_node = graph.add_nncf_node('elementwise_node', 'elementwise', dummy_types.DummyElementwiseMetatype) concat_layer_attributes = MultipleInputLayerAttributes(2) - concat_node = graph.add_nncf_node('concat_node', 'concat', DummyConcatMetatype, + concat_node = graph.add_nncf_node('concat_node', 'concat', dummy_types.DummyConcatMetatype, layer_attributes=concat_layer_attributes) add_node = partial(graph.add_edge_between_nncf_nodes, tensor_shape=[10] * 4, @@ -155,7 +270,7 @@ def test_convs_elementwise_source_before_concat(empty_mask_branch): to_node_id=concat_node.node_id) # Check without masks - assert DummyOpConcat.check_concat(concat_node, graph) + assert dummy_types.DummyOpConcat.check_concat(concat_node, graph) # Set masks masked_convs = [conv_op_0, conv_op_1] if not empty_mask_branch: @@ -166,9 +281,9 @@ def test_convs_elementwise_source_before_concat(empty_mask_branch): conv_op.data['output_mask'] = np.ones(10) # Propagate masks - MaskPropagationAlgorithm(graph, DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() + MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() # Check with masks concat_node = graph.get_node_by_id(concat_node.node_id) - assert DummyOpConcat.check_concat(concat_node, graph) + assert dummy_types.DummyOpConcat.check_concat(concat_node, graph) 
reference_mask = np.ones((20,)) np.testing.assert_equal(concat_node.data['output_mask'], reference_mask) From 13bc58c86c6d49043a817df0ea75a4cb73c8292a Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Tue, 5 Oct 2021 09:27:28 +0300 Subject: [PATCH 08/19] Fix diamond inheritance --- nncf/torch/pruning/export_helpers.py | 35 +++++++++------------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/export_helpers.py index 2f661f6f944..590d6e920a1 100644 --- a/nncf/torch/pruning/export_helpers.py +++ b/nncf/torch/pruning/export_helpers.py @@ -73,20 +73,7 @@ PT_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") -class PTDefaultMetaOp(DefaultMetaOp): - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - """ - Propagate mask through a node using masks of all inputs and pruning mask of current node (if any). - Should set the following attributes: - input_masks - list of masks of input nodes (None if there is no mask in some input); - output_mask - resulting mask of node operation. - - :param node: Node from NNCF graph to propagate mask through it. - :param graph: Graph of model to prune. 
- """ - raise NotImplementedError - +class PTPruner(object): @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): """ @@ -109,24 +96,24 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('model_input') -class PTInput(PTDefaultMetaOp, OpInput): +class PTInput(OpInput, PTPruner): subtypes = [PTInputNoopMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('model_output') -class PTOutput(PTDefaultMetaOp, OpOutput): +class PTOutput(OpOutput, PTPruner): subtypes = [PTOutputNoopMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class PTIdentityMaskForwardOps(PTDefaultMetaOp, OpIdentityMaskForwardOps): +class PTIdentityMaskForwardOps(OpIdentityMaskForwardOps, PTPruner): subtypes = [HardTanhMetatype, TanhMetatype, RELUMetatype, PRELUMetatype, ELUMetatype, GELUMetatype, SigmoidMetatype, SoftmaxMetatype, AvgPool2dMetatype, MaxPool2dMetatype, DropoutMetatype] additional_types = ['h_sigmoid', 'h_swish', 'RELU'] @PT_PRUNING_OPERATOR_METATYPES.register('convolution') -class PTConvolution(PTDefaultMetaOp, OpConvolution): +class PTConvolution(OpConvolution, PTPruner): subtypes = [Conv1dMetatype, Conv2dMetatype, Conv3dMetatype] @classmethod @@ -180,7 +167,7 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') -class PTTransposeConvolution(PTDefaultMetaOp, OpTransposeConvolution): +class PTTransposeConvolution(OpTransposeConvolution, PTPruner): subtypes = [ConvTranspose2dMetatype, ConvTranspose3dMetatype] @classmethod @@ -227,7 +214,7 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('batch_norm') -class PTBatchNorm(PTDefaultMetaOp, OpBatchNorm): +class PTBatchNorm(OpBatchNorm, PTPruner): subtypes = [BatchNormMetatype] @classmethod @@ -253,7 +240,7 @@ def input_prune(cls, model: NNCFNetwork, 
node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('group_norm') -class GroupNorm(PTDefaultMetaOp, OpGroupNorm): +class GroupNorm(OpGroupNorm, PTPruner): subtypes = [GroupNormMetatype] @classmethod @@ -279,7 +266,7 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('elementwise') -class PTElementwise(PTDefaultMetaOp, OpElementwise): +class PTElementwise(OpElementwise, PTPruner): subtypes = [AddMetatype, SubMetatype, DivMetatype, MulMetatype] @classmethod @@ -313,12 +300,12 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') -class PTStopMaskForwardOps(PTDefaultMetaOp, OpStopMaskForwardOps): +class PTStopMaskForwardOps(OpStopMaskForwardOps, PTPruner): subtypes = [MeanMetatype, MaxMetatype, MinMetatype, LinearMetatype, MatMulMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('concat') -class PTConcat(PTDefaultMetaOp, OpConcat): +class PTConcat(OpConcat, PTPruner): subtypes = [CatMetatype] ConvolutionOp = PTConvolution From 8e621024d41ebb168bd2c7bc613274a1142778db Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Tue, 5 Oct 2021 10:50:43 +0300 Subject: [PATCH 09/19] Unify export helpers --- nncf/common/pruning/export_helpers.py | 39 +++++++++++----- nncf/tensorflow/pruning/export_helpers.py | 54 +++++------------------ nncf/torch/pruning/export_helpers.py | 53 ++++------------------ 3 files changed, 50 insertions(+), 96 deletions(-) diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py index ca87511f944..866ad4860e9 100644 --- a/nncf/common/pruning/export_helpers.py +++ b/nncf/common/pruning/export_helpers.py @@ -13,7 +13,7 @@ import numpy as np -from typing import Union +from typing import Union, List from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode @@ -193,14 +193,26 @@ def check_concat(cls, node: NNCFNode, graph: 
NNCFGraph) -> bool: return True @classmethod - def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[np.array, None]: + def _get_unit_mask(cls, dim, device): + return np.ones(dim) + + @classmethod + def _get_masks_device(cls, input_masks): + return None + + @classmethod + def _concat_masks(cls, filled_input_masks): + return np.concatenate(filled_input_masks, 0) + + @classmethod + def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[List[np.array], None]: """ Generate output mask from input masks with all None replaced by identity masks. If all input masks is None return None. - :param node: Node to determine it's sources - :param graph: NNCF graph to work with - :return: Output mask + :param node: Node to determine it's sources. + :param graph: NNCF graph to work with. + :return: Filled input masks. """ input_edges = graph.get_input_edges(node) previous_nodes = [edge.from_node for edge in input_edges] @@ -209,14 +221,15 @@ def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[np.arra if all(mask is None for mask in input_masks): return None - filled_input_masks = [] for i, mask in enumerate(input_masks): if mask is None: concat_axis = node.layer_attributes.axis - mask = np.ones(input_edges[i].tensor_shape[concat_axis]) + concat_dim = input_edges[i].tensor_shape[concat_axis] + device = cls._get_masks_device(input_masks) + mask = cls._get_unit_mask(concat_dim, device) filled_input_masks.append(mask) - result_mask = np.concatenate(filled_input_masks, 0) + result_mask = cls._concat_masks(filled_input_masks) return result_mask @classmethod @@ -234,12 +247,18 @@ class OpElementwise(DefaultMetaOp): def accept_pruned_input(cls, node: NNCFNode): return True + @classmethod + def _assert_input_masks_close(cls, input_masks): + for input_mask in input_masks[1:]: + np.testing.assert_allclose(input_masks[0], input_mask) + @classmethod def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): input_masks = 
get_input_masks(node, graph) + + node.data['input_masks'] = input_masks if input_masks[0] is not None: - for input_mask in input_masks[1:]: - np.testing.assert_allclose(input_masks[0], input_mask) + cls._assert_input_masks_close(input_masks) node.data['output_mask'] = input_masks[0] diff --git a/nncf/tensorflow/pruning/export_helpers.py b/nncf/tensorflow/pruning/export_helpers.py index 5d3387be54c..9bab26ffece 100644 --- a/nncf/tensorflow/pruning/export_helpers.py +++ b/nncf/tensorflow/pruning/export_helpers.py @@ -12,7 +12,6 @@ """ from typing import Dict from typing import List -from typing import Union import tensorflow as tf @@ -20,9 +19,6 @@ from nncf.tensorflow.graph.pattern_operations import ELEMENTWISE_OPERATIONS from nncf.tensorflow.graph.pattern_operations import TF_ACTIVATIONS_OPERATIONS from nncf.common.graph.definitions import NNCFGraphNodeType -from nncf.common.graph import NNCFGraph -from nncf.common.graph import NNCFNode -from nncf.common.pruning.mask_propagation import get_input_masks from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry from nncf.common.pruning.export_helpers import ( OpInput, @@ -81,12 +77,9 @@ class TFElementwise(OpElementwise): additional_types = _get_types(ELEMENTWISE_OPERATIONS) @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) - if input_masks[0] is not None: - for input_mask in input_masks[1:]: - tf.debugging.assert_near(input_masks[0], input_mask) - node.data['output_mask'] = input_masks[0] + def _assert_input_masks_close(cls, input_masks): + for input_mask in input_masks[1:]: + tf.debugging.assert_near(input_masks[0], input_mask) @TF_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') @@ -103,38 +96,15 @@ class TFConcat(OpConcat): InputOp = TFInput @classmethod - def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[tf.Tensor, None]: - """ - Generate output mask from input masks with all None replaced by 
identity masks. - If all input masks is None return None. - - :param node: Node to determine it's sources - :param graph: NNCF graph to work with - :return: Output mask - """ - input_edges = graph.get_input_edges(node) - previous_nodes = [edge.from_node for edge in input_edges] - input_masks = [input_node.data['output_mask'] for input_node in previous_nodes] - - if all(mask is None for mask in input_masks): - return None - - device = [m for m in input_masks if m is not None][0].device - - filled_input_masks = [] - for i, mask in enumerate(input_masks): - if mask is None: - with tf.device(device): - mask = tf.ones(input_edges[i].tensor_shape[-1]) - filled_input_masks.append(mask) - result_mask = tf.concat(filled_input_masks, 0) - return result_mask + def _get_unit_mask(cls, dim, device): + with tf.device(device): + mask = tf.ones(dim) + return mask @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - result_mask = None + def _get_masks_device(cls, input_masks): + return [m for m in input_masks if m is not None][0].device - if cls.check_concat(node, graph): - result_mask = cls.generate_output_mask(node, graph) - - node.data['output_mask'] = result_mask + @classmethod + def _concat_masks(cls, filled_input_masks): + return tf.concat(filled_input_masks, 0) diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/export_helpers.py index 590d6e920a1..b528d38d815 100644 --- a/nncf/torch/pruning/export_helpers.py +++ b/nncf/torch/pruning/export_helpers.py @@ -10,15 +10,12 @@ See the License for the specific language governing permissions and limitations under the License. 
""" -from typing import Union -from typing import List import torch from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry -from nncf.common.pruning.mask_propagation import get_input_masks from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm from nncf.torch.graph.operator_metatypes import ( AddMetatype, @@ -53,7 +50,6 @@ TanhMetatype, ) from nncf.common.pruning.export_helpers import ( - DefaultMetaOp, OpInput, OpOutput, OpIdentityMaskForwardOps, @@ -270,13 +266,9 @@ class PTElementwise(OpElementwise, PTPruner): subtypes = [AddMetatype, SubMetatype, DivMetatype, MulMetatype] @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = get_input_masks(node, graph) + def _assert_input_masks_close(cls, input_masks): + assert all(torch.allclose(input_masks[0], mask) for mask in input_masks) - node.data['input_masks'] = input_masks - if input_masks[0] is not None: - assert all(torch.allclose(input_masks[0], mask) for mask in input_masks) - node.data['output_mask'] = input_masks[0] @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @@ -313,43 +305,16 @@ class PTConcat(OpConcat, PTPruner): InputOp = PTInput @classmethod - def fill_input_masks(cls, node: NNCFNode, graph: NNCFGraph) -> Union[List[torch.Tensor], None]: - """ - Fill input masks with all None replaced by identity masks. - If all input masks is None return None. - - :param node: Node to determine it's sources. - :param graph: NNCF graph to work with. - :return: Filled input masks. 
- """ - input_edges = graph.get_input_edges(node) - previous_nodes = [edge.from_node for edge in input_edges] - input_masks = [input_node.data['output_mask'] for input_node in previous_nodes] - - if all(mask is None for mask in input_masks): - return None - - device = [m for m in input_masks if m is not None][0].device - - filled_input_masks = [] - for i, mask in enumerate(input_masks): - if mask is None: - mask = torch.ones(input_edges[i].tensor_shape[1], device=device) - filled_input_masks.append(mask) - return filled_input_masks + def _get_unit_mask(cls, dim, device): + return torch.ones(dim, device=device) @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - input_masks = None - output_mask = None + def _get_masks_device(cls, input_masks): + return [m for m in input_masks if m is not None][0].device - if cls.check_concat(node, graph): - input_masks = cls.fill_input_masks(node, graph) - if input_masks: - output_mask = torch.cat(input_masks) - - node.data['input_masks'] = input_masks - node.data['output_mask'] = output_mask + @classmethod + def _concat_masks(cls, filled_input_masks): + return torch.cat(filled_input_masks, 0) class ModelPruner(MaskPropagationAlgorithm): From 7c40fc1b8f89508f189188c2b08c366979a8615d Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Tue, 5 Oct 2021 11:32:48 +0300 Subject: [PATCH 10/19] Fix concat axis problems --- nncf/common/graph/utils.py | 36 ++++++++++++++++++++++++ nncf/tensorflow/graph/converter.py | 12 ++------ nncf/torch/graph/graph_builder.py | 11 ++------ tests/common/graph/test_utils.py | 21 ++++++++++++++ tests/tensorflow/test_model_converter.py | 27 +++++++++++------- 5 files changed, 78 insertions(+), 29 deletions(-) create mode 100644 nncf/common/graph/utils.py create mode 100644 tests/common/graph/test_utils.py diff --git a/nncf/common/graph/utils.py b/nncf/common/graph/utils.py new file mode 100644 index 00000000000..e0501548485 --- /dev/null +++ b/nncf/common/graph/utils.py @@ -0,0 +1,36 
@@ +""" + Copyright (c) 2021 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import List + + +def get_concat_axis(input_shape: List[int], output_shape: List[int]) -> int: + axis = None + # If it's dummy concat of one tensor + if len(input_shape) == 1: + axis = -1 + else: + none_dim = None + for idx, (dim_in, dim_out) in enumerate(zip(input_shape[0], output_shape[0])): + if dim_in != dim_out: + axis = idx + break + elif dim_in is None: + none_dim = idx + if not axis: + axis = none_dim + + if axis is None: + raise RuntimeError('Unexpected behaviour for concat op') + + return axis diff --git a/nncf/tensorflow/graph/converter.py b/nncf/tensorflow/graph/converter.py index 145b93d8246..66d84bc41c7 100644 --- a/nncf/tensorflow/graph/converter.py +++ b/nncf/tensorflow/graph/converter.py @@ -29,6 +29,7 @@ from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.utils import get_concat_axis from nncf.common.utils.logger import logger as nncf_logger from nncf.tensorflow.graph.metatypes.common import DECONV_LAYER_METATYPES from nncf.tensorflow.graph.metatypes.common import DEPTHWISE_CONV_LAYER_METATYPES @@ -662,16 +663,7 @@ def _get_multiple_input_layer_attributes(layer: tf.keras.layers.Layer) -> Multip else: input_shape = layer.input_shape output_shape = layer.output_shape - axis = None - # If 
it's dummy concat of one tensor - if len(input_shape) == 1: - axis = -1 - for idx, (dim_in, dim_out) in enumerate(zip(input_shape[0], output_shape[0])): - if dim_in is None or dim_in != dim_out: - axis = idx - break - if axis is None: - raise RuntimeError('Unexpected behaviour for concat op') + axis = get_concat_axis(input_shape, output_shape) return MultipleInputLayerAttributes(axis) diff --git a/nncf/torch/graph/graph_builder.py b/nncf/torch/graph/graph_builder.py index 852a31f96b6..3a7c711fc10 100644 --- a/nncf/torch/graph/graph_builder.py +++ b/nncf/torch/graph/graph_builder.py @@ -23,6 +23,7 @@ from nncf.common.graph import LayerName from nncf.common.graph.layer_attributes import Dtype from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes +from nncf.common.graph.utils import get_concat_axis from nncf.torch.dynamic_graph.graph import DynamicGraph from nncf.torch.dynamic_graph.graph_tracer import GraphTracer from nncf.torch.dynamic_graph.graph_tracer import ModelInputInfo @@ -100,17 +101,9 @@ def convert(dynamic_graph: DynamicGraph, input_infos: List[ModelInputInfo] = Non output_edges = nncf_graph.get_output_edges(node) # In case is intermediate node if input_edges and output_edges: - axis = None - if len(input_edges) == 1: - axis = -1 input_shape = input_edges[0].tensor_shape output_shape = output_edges[0].tensor_shape - for idx, (dim_in, dim_out) in enumerate(zip(input_shape, output_shape)): - if dim_in is None or dim_in != dim_out: - axis = idx - break - if axis is None: - raise RuntimeError('Unexpected behaviour for concat op') + axis = get_concat_axis(input_shape, output_shape) layer_attributes = MultipleInputLayerAttributes(axis) node.layer_attributes = layer_attributes return nncf_graph diff --git a/tests/common/graph/test_utils.py b/tests/common/graph/test_utils.py new file mode 100644 index 00000000000..da47ab7bce2 --- /dev/null +++ b/tests/common/graph/test_utils.py @@ -0,0 +1,21 @@ +import pytest + +from nncf.common.graph.utils 
import get_concat_axis + + +TEST_CASES = [ + ([(None, 1, 1, 5)], [(None, 1, 1, 5)], False, [3, -1]), + ([(None, 1, 1, 5), (None, 1, 1, 5)], [(None, 1, 1, 10)], False, [3, -1]), + ([(1, 1, None), (1, 1, None)], [(1, 1, None)], False, [2, -1]), + ([(1, 1, 32, 1), (1, 1, 32, 1)], [(1, 1, 64, 1)], False, [2, -1]), + ([(1, 1, 5), (1, 1, 5)], [(1, 1, 5)], True, None), +] + + +@pytest.mark.parametrize('input_shape,output_shape,raise_error,possible_axes', TEST_CASES) +def test_get_concat_axis(input_shape, output_shape, raise_error, possible_axes): + if not raise_error: + assert get_concat_axis(input_shape, output_shape) in possible_axes + else: + with pytest.raises(RuntimeError): + _ = get_concat_axis(input_shape, output_shape) diff --git a/tests/tensorflow/test_model_converter.py b/tests/tensorflow/test_model_converter.py index 2264c180ee7..1d7ae5922ff 100644 --- a/tests/tensorflow/test_model_converter.py +++ b/tests/tensorflow/test_model_converter.py @@ -13,6 +13,7 @@ import pytest import tensorflow as tf +from functools import partial from tensorflow.python.keras import backend from tensorflow.python.keras import layers from tensorflow.python.keras import models @@ -27,6 +28,7 @@ from tests.tensorflow.helpers import get_basic_conv_test_model from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test from tests.tensorflow.quantization.test_algorithm_quantization import get_basic_quantization_config +from tests.tensorflow.pruning.helpers import get_concat_test_model def test_struct_auxiliary_nodes_nncf_graph(): @@ -89,7 +91,7 @@ def test_get_custom_layers(): assert isinstance(custom_layers[CustomLayerForTest.CUSTOM_LAYER_NAME], CustomLayerForTest) -def ModelWithReshapesAndConcats(batch_size=None): +def get_model_with_reshapes_and_concats(batch_size=None): input =layers.Input((64, ), batch_size=batch_size) x = tf.reshape(input, (32, -1)) x = layers.Reshape((16, -1))(x) @@ -100,17 +102,22 @@ def ModelWithReshapesAndConcats(batch_size=None): return 
models.Model(input, y, name='ModelWithReshape') -@pytest.mark.parametrize('batch_size', [None, 8], ids=['no_batch_size', 'with_batch_size']) -def test_model_with_reshape_and_concat(batch_size): - model = ModelWithReshapesAndConcats(batch_size) - model.build((64,)) - graph = convert_keras_model_to_nncf_graph(model) - ref_concat_nodes = {'concatenate': {'axis': [-1, 2]}, +CONCAT_MODELS = [partial(get_concat_test_model, input_shape=[1, 8, 8, 1]), + get_model_with_reshapes_and_concats] +REF_CONCAT_ATTRS = [{'tf_op_layer_tf_concat_1': {'axis': [-1, 3]}, + 'tf_op_layer_tf_concat_2': {'axis': [-1, 3]}}, + {'concatenate': {'axis': [-1, 2]}, 'tf_op_layer_concat': {'axis': [-1, 2]}, - 'tf_op_layer_concat_1': {'axis': [-1, 2]}} + 'tf_op_layer_concat_1': {'axis': [-1, 2]}}] + + +@pytest.mark.parametrize('model, ref_attrs', [(m, r) for m, r in zip(CONCAT_MODELS, REF_CONCAT_ATTRS)]) +def test_model_with_reshape_and_concat(model, ref_attrs): + model = model() + graph = convert_keras_model_to_nncf_graph(model) for node in graph.get_all_nodes(): if node.metatype in LAYER_METATYPES_AGNOSTIC_TO_DATA_PRECISION_WITH_MULTIPLE_INPUTS: - assert node.node_name in ref_concat_nodes + assert node.node_name in ref_attrs assert node.layer_attributes is not None assert isinstance(node.layer_attributes, MultipleInputLayerAttributes) - assert node.layer_attributes.axis in ref_concat_nodes[node.node_name]['axis'] + assert node.layer_attributes.axis in ref_attrs[node.node_name]['axis'] From 289d4344361f3539c234aec8228801d7c3b830b9 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Tue, 5 Oct 2021 12:48:36 +0300 Subject: [PATCH 11/19] Fix naming --- nncf/common/pruning/default_pruning_op.py | 55 ++++++++++++++++ nncf/common/pruning/export_helpers.py | 65 ++++--------------- nncf/common/pruning/mask_propagation.py | 4 +- nncf/common/pruning/model_analysis.py | 4 +- nncf/tensorflow/pruning/base_algorithm.py | 6 +- nncf/tensorflow/pruning/export_helpers.py | 42 ++++++------ 
.../pruning/filter_pruning/algorithm.py | 10 +-- nncf/torch/pruning/export_helpers.py | 46 ++++++------- nncf/torch/pruning/filter_pruning/algo.py | 4 +- tests/common/pruning/dummy_types.py | 46 ++++++------- tests/common/pruning/test_export_helpers.py | 22 +++---- .../pruning/test_model_pruning_analysis.py | 8 +-- 12 files changed, 166 insertions(+), 146 deletions(-) create mode 100644 nncf/common/pruning/default_pruning_op.py diff --git a/nncf/common/pruning/default_pruning_op.py b/nncf/common/pruning/default_pruning_op.py new file mode 100644 index 00000000000..1e757b666d8 --- /dev/null +++ b/nncf/common/pruning/default_pruning_op.py @@ -0,0 +1,55 @@ +""" + Copyright (c) 2021 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.graph import NNCFGraph + + +class DefaultPruningOp: + """ + Determines meta operations which aggregate operations having common + properties of interaction with pruning masks + """ + + subtypes = [] + additional_types = [] + + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + """ + :return: accept_pruned_input - can this operation work with pruned input or not + """ + raise NotImplementedError + + @classmethod + def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): + """ + Propagates the pruning mask through a node using pruning masks of all inputs and the current node (if any). 
+ + :param node: The graph node to propagate mask through it + :param graph: The model graph to prune + """ + raise NotImplementedError + + @classmethod + def get_all_op_aliases(cls): + """ + :return: list of all aliases of types in metatype + """ + op_types = [] + for subtype in cls.subtypes: + op_types.extend(subtype.get_all_aliases()) + op_types = list(set(op_types)) + cls.additional_types + return op_types + + diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py index 866ad4860e9..3a23d428f58 100644 --- a/nncf/common/pruning/export_helpers.py +++ b/nncf/common/pruning/export_helpers.py @@ -23,47 +23,10 @@ from nncf.common.graph.layer_attributes import GroupNormLayerAttributes from nncf.common.pruning.mask_propagation import identity_mask_propagation from nncf.common.pruning.mask_propagation import get_input_masks +from nncf.common.pruning.default_pruning_op import DefaultPruningOp -class DefaultMetaOp: - """ - Determines meta operations which aggregate operations having common - properties of interaction with pruning masks - """ - - subtypes = [] - additional_types = [] - - @classmethod - def accept_pruned_input(cls, node: NNCFNode): - """ - :return: accept_pruned_input - can this operation work with pruned input or not - """ - raise NotImplementedError - - @classmethod - def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): - """ - Propagates the pruning mask through a node using pruning masks of all inputs and the current node (if any). 
- - :param node: The graph node to propagate mask through it - :param graph: The model graph to prune - """ - raise NotImplementedError - - @classmethod - def get_all_op_aliases(cls): - """ - :return: list of all aliases of types in metatype - """ - op_types = [] - for subtype in cls.subtypes: - op_types.extend(subtype.get_all_aliases()) - op_types = list(set(op_types)) + cls.additional_types - return op_types - - -class OpInput(DefaultMetaOp): +class InputPruningOp(DefaultPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): return False @@ -73,7 +36,7 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): node.data['output_mask'] = None -class OpOutput(DefaultMetaOp): +class OutputPruningOp(DefaultPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): return True @@ -83,7 +46,7 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): node.data['output_mask'] = None -class OpIdentityMaskForwardOps(DefaultMetaOp): +class IdentityMaskForwardPruningOp(DefaultPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): return True @@ -93,7 +56,7 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): identity_mask_propagation(node, graph) -class OpConvolution(DefaultMetaOp): +class ConvolutionPruningOp(DefaultPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): accept_pruned_input = True @@ -115,7 +78,7 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): node.data['output_mask'] = output_mask -class OpTransposeConvolution(DefaultMetaOp): +class TransposeConvolutionPruningOp(DefaultPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): accept_pruned_input = True @@ -138,7 +101,7 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): node.data['output_mask'] = output_mask -class OpBatchNorm(DefaultMetaOp): +class BatchNormPruningOp(DefaultPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): return True @@ -148,7 +111,7 @@ def 
mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): identity_mask_propagation(node, graph) -class OpGroupNorm(DefaultMetaOp): +class GroupNormPruningOp(DefaultPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): # For Instance Normalization @@ -160,10 +123,10 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): identity_mask_propagation(node, graph) -class OpConcat(DefaultMetaOp): - ConvolutionOp = None # type: OpConvolution - StopMaskForwardOp = None # type: OpStopMaskForwardOps - InputOp = None # type: OpInput +class ConcatPruningOp(DefaultPruningOp): + ConvolutionOp = None # type: ConvolutionPruningOp + StopMaskForwardOp = None # type: StopMaskForwardPruningOp + InputOp = None # type: InputPruningOp @classmethod def accept_pruned_input(cls, node: NNCFNode): @@ -242,7 +205,7 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): node.data['output_mask'] = result_mask -class OpElementwise(DefaultMetaOp): +class ElementwisePruningOp(DefaultPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): return True @@ -262,7 +225,7 @@ def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): node.data['output_mask'] = input_masks[0] -class OpStopMaskForwardOps(DefaultMetaOp): +class StopMaskForwardPruningOp(DefaultPruningOp): @classmethod def accept_pruned_input(cls, node: NNCFNode): return False diff --git a/nncf/common/pruning/mask_propagation.py b/nncf/common/pruning/mask_propagation.py index 7fdb3da5abe..3c44e18ec13 100644 --- a/nncf/common/pruning/mask_propagation.py +++ b/nncf/common/pruning/mask_propagation.py @@ -17,6 +17,8 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry +from nncf.common.pruning.default_pruning_op import DefaultPruningOp + TensorType = TypeVar('TensorType') @@ -39,7 +41,7 @@ def __init__(self, graph: NNCFGraph, pruning_operator_metatypes: PruningOperatio self._graph = graph 
self._pruning_operator_metatypes = pruning_operator_metatypes - def get_meta_operation_by_type_name(self, type_name: str) -> 'DefaultMetaOp': + def get_meta_operation_by_type_name(self, type_name: str) -> DefaultPruningOp: """ Returns class of metaop that corresponds to `type_name` type. diff --git a/nncf/common/pruning/model_analysis.py b/nncf/common/pruning/model_analysis.py index 4df7f4870a4..2eeb3c309f2 100644 --- a/nncf/common/pruning/model_analysis.py +++ b/nncf/common/pruning/model_analysis.py @@ -17,7 +17,7 @@ from nncf.common.graph import NNCFNode from nncf.common.pruning.clusterization import Cluster from nncf.common.pruning.clusterization import Clusterization -from nncf.common.pruning.export_helpers import DefaultMetaOp +from nncf.common.pruning.export_helpers import DefaultPruningOp from nncf.common.pruning.utils import find_next_nodes_not_of_types from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry @@ -143,7 +143,7 @@ def node_accept_different_inputs(self, nncf_node: NNCFNode) -> bool: """ return nncf_node.node_type in self._concat_op_metatype.get_all_op_aliases() - def get_meta_operation_by_type_name(self, type_name: str) -> DefaultMetaOp: + def get_meta_operation_by_type_name(self, type_name: str) -> DefaultPruningOp: """ Returns class of metaop that corresponds to `type_name` type. 
diff --git a/nncf/tensorflow/pruning/base_algorithm.py b/nncf/tensorflow/pruning/base_algorithm.py index 2e354cb28da..d76df6ff65f 100644 --- a/nncf/tensorflow/pruning/base_algorithm.py +++ b/nncf/tensorflow/pruning/base_algorithm.py @@ -40,8 +40,8 @@ from nncf.tensorflow.graph.transformations.layout import TFTransformationLayout from nncf.tensorflow.graph.utils import get_layer_identifier from nncf.tensorflow.graph.utils import collect_wrapped_layers -from nncf.tensorflow.pruning.export_helpers import TFElementwise -from nncf.tensorflow.pruning.export_helpers import TFIdentityMaskForwardOps +from nncf.tensorflow.pruning.export_helpers import TFElementwisePruningOp +from nncf.tensorflow.pruning.export_helpers import TFIdentityMaskForwardPruningOp from nncf.tensorflow.pruning.export_helpers import TF_PRUNING_OPERATOR_METATYPES from nncf.tensorflow.pruning.utils import get_filter_axis from nncf.tensorflow.pruning.utils import get_filters_num @@ -207,7 +207,7 @@ def _get_insertion_command_binary_mask(self, layer_name: str, @staticmethod def _get_bn_for_node(node: NNCFNode, bn_nodes: List[NNCFNode]) -> Tuple[bool, List[NNCFNode]]: is_finished = False - propagating_ops = [op_name for meta_op in [TFIdentityMaskForwardOps, TFElementwise] + propagating_ops = [op_name for meta_op in [TFIdentityMaskForwardPruningOp, TFElementwisePruningOp] for op_name in meta_op.get_all_op_aliases()] if node.node_type == 'BatchNormalization': is_finished = True diff --git a/nncf/tensorflow/pruning/export_helpers.py b/nncf/tensorflow/pruning/export_helpers.py index 9bab26ffece..6dde4065bfa 100644 --- a/nncf/tensorflow/pruning/export_helpers.py +++ b/nncf/tensorflow/pruning/export_helpers.py @@ -21,15 +21,15 @@ from nncf.common.graph.definitions import NNCFGraphNodeType from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry from nncf.common.pruning.export_helpers import ( - OpInput, - OpOutput, - OpIdentityMaskForwardOps, - OpConvolution, - OpTransposeConvolution, - 
OpBatchNorm, - OpConcat, - OpElementwise, - OpStopMaskForwardOps + InputPruningOp, + OutputPruningOp, + IdentityMaskForwardPruningOp, + ConvolutionPruningOp, + TransposeConvolutionPruningOp, + BatchNormPruningOp, + ConcatPruningOp, + ElementwisePruningOp, + StopMaskForwardPruningOp ) TF_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") @@ -40,17 +40,17 @@ def _get_types(operations_dict: Dict) -> List[str]: @TF_PRUNING_OPERATOR_METATYPES.register('model_input') -class TFInput(OpInput): +class TFInputPruningOp(InputPruningOp): additional_types = ['InputLayer', NNCFGraphNodeType.INPUT_NODE] @TF_PRUNING_OPERATOR_METATYPES.register('model_output') -class TFOutput(OpOutput): +class TFOutputPruningOp(OutputPruningOp): additional_types = [NNCFGraphNodeType.OUTPUT_NODE] @TF_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class TFIdentityMaskForwardOps(OpIdentityMaskForwardOps): +class TFIdentityMaskForwardPruningOp(IdentityMaskForwardPruningOp): additional_types = _get_types(KERAS_ACTIVATIONS_OPERATIONS) + _get_types(TF_ACTIVATIONS_OPERATIONS) \ + ['AvgPool2D', 'GlobalAvgPool2D', 'AveragePooling2D', 'GlobalAveragePooling2D'] \ + ['MaxPooling2D', 'GlobalMaxPooling2D', 'MaxPool2D', 'GlobalMaxPool2D'] \ @@ -58,22 +58,22 @@ class TFIdentityMaskForwardOps(OpIdentityMaskForwardOps): @TF_PRUNING_OPERATOR_METATYPES.register('convolution') -class TFConvolution(OpConvolution): +class TFConvolutionPruningOp(ConvolutionPruningOp): additional_types = ['Conv1D', 'Conv2D', 'Conv3D', 'DepthwiseConv2D'] @TF_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') -class TFTransposeConvolution(OpTransposeConvolution): +class TFTransposeConvolutionPruningOp(TransposeConvolutionPruningOp): additional_types = ['Conv1DTranspose', 'Conv2DTranspose', 'Conv3DTranspose'] @TF_PRUNING_OPERATOR_METATYPES.register('batch_norm') -class TFBatchNorm(OpBatchNorm): +class TFBatchNormPruningOp(BatchNormPruningOp): additional_types = 
['BatchNormalization', 'SyncBatchNormalization'] @TF_PRUNING_OPERATOR_METATYPES.register('elementwise') -class TFElementwise(OpElementwise): +class TFElementwisePruningOp(ElementwisePruningOp): additional_types = _get_types(ELEMENTWISE_OPERATIONS) @classmethod @@ -83,17 +83,17 @@ def _assert_input_masks_close(cls, input_masks): @TF_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') -class TFStopMaskForwardOps(OpStopMaskForwardOps): +class TFStopMaskForwardPruningOp(StopMaskForwardPruningOp): additional_types = ['Dense', 'MatMul'] @TF_PRUNING_OPERATOR_METATYPES.register('concat') -class TFConcat(OpConcat): +class TFConcatPruningOp(ConcatPruningOp): additional_types = ['Concatenate', 'ConcatV2'] - ConvolutionOp = TFConvolution - StopMaskForwardOp = TFStopMaskForwardOps - InputOp = TFInput + ConvolutionOp = TFConvolutionPruningOp + StopMaskForwardOp = TFStopMaskForwardPruningOp + InputOp = TFInputPruningOp @classmethod def _get_unit_mask(cls, dim, device): diff --git a/nncf/tensorflow/pruning/filter_pruning/algorithm.py b/nncf/tensorflow/pruning/filter_pruning/algorithm.py index 2146bb446b7..9fa2c15c5a1 100644 --- a/nncf/tensorflow/pruning/filter_pruning/algorithm.py +++ b/nncf/tensorflow/pruning/filter_pruning/algorithm.py @@ -52,9 +52,9 @@ from nncf.tensorflow.pruning.base_algorithm import BasePruningAlgoController from nncf.tensorflow.pruning.base_algorithm import PrunedLayerInfo from nncf.tensorflow.pruning.export_helpers import TF_PRUNING_OPERATOR_METATYPES -from nncf.tensorflow.pruning.export_helpers import TFConvolution -from nncf.tensorflow.pruning.export_helpers import TFElementwise -from nncf.tensorflow.pruning.export_helpers import TFTransposeConvolution +from nncf.tensorflow.pruning.export_helpers import TFConvolutionPruningOp +from nncf.tensorflow.pruning.export_helpers import TFElementwisePruningOp +from nncf.tensorflow.pruning.export_helpers import TFTransposeConvolutionPruningOp from nncf.tensorflow.pruning.filter_pruning.functions import 
calculate_binary_mask from nncf.tensorflow.pruning.filter_pruning.functions import FILTER_IMPORTANCE_FUNCTIONS from nncf.tensorflow.pruning.filter_pruning.functions import tensor_l2_normalizer @@ -85,11 +85,11 @@ def _is_pruned_layer(self, layer: tf.keras.layers.Layer) -> bool: return layer.__class__.__name__ in self._prunable_types def _get_op_types_of_pruned_layers(self) -> List[str]: - return [op_name for meta_op in [TFConvolution, TFTransposeConvolution] + return [op_name for meta_op in [TFConvolutionPruningOp, TFTransposeConvolutionPruningOp] for op_name in meta_op.get_all_op_aliases()] def _get_types_of_grouping_ops(self) -> List[str]: - return TFElementwise.get_all_op_aliases() + return TFElementwisePruningOp.get_all_op_aliases() @ADAPTIVE_COMPRESSION_CONTROLLERS.register('tf_filter_pruning') diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/export_helpers.py index b528d38d815..d28644fdc6c 100644 --- a/nncf/torch/pruning/export_helpers.py +++ b/nncf/torch/pruning/export_helpers.py @@ -50,16 +50,16 @@ TanhMetatype, ) from nncf.common.pruning.export_helpers import ( - OpInput, - OpOutput, - OpIdentityMaskForwardOps, - OpConvolution, - OpTransposeConvolution, - OpBatchNorm, - OpGroupNorm, - OpConcat, - OpElementwise, - OpStopMaskForwardOps + InputPruningOp, + OutputPruningOp, + IdentityMaskForwardPruningOp, + ConvolutionPruningOp, + TransposeConvolutionPruningOp, + BatchNormPruningOp, + GroupNormPruningOp, + ConcatPruningOp, + ElementwisePruningOp, + StopMaskForwardPruningOp ) from nncf.common.utils.logger import logger as nncf_logger from nncf.torch.nncf_network import NNCFNetwork @@ -92,24 +92,24 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('model_input') -class PTInput(OpInput, PTPruner): +class PTInputPruningOp(InputPruningOp, PTPruner): subtypes = [PTInputNoopMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('model_output') -class PTOutput(OpOutput, PTPruner): 
+class PTOutputPruningOp(OutputPruningOp, PTPruner): subtypes = [PTOutputNoopMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('identity_mask_propagation') -class PTIdentityMaskForwardOps(OpIdentityMaskForwardOps, PTPruner): +class PTIdentityMaskForwardPruningOp(IdentityMaskForwardPruningOp, PTPruner): subtypes = [HardTanhMetatype, TanhMetatype, RELUMetatype, PRELUMetatype, ELUMetatype, GELUMetatype, SigmoidMetatype, SoftmaxMetatype, AvgPool2dMetatype, MaxPool2dMetatype, DropoutMetatype] additional_types = ['h_sigmoid', 'h_swish', 'RELU'] @PT_PRUNING_OPERATOR_METATYPES.register('convolution') -class PTConvolution(OpConvolution, PTPruner): +class PTConvolutionPruningOp(ConvolutionPruningOp, PTPruner): subtypes = [Conv1dMetatype, Conv2dMetatype, Conv3dMetatype] @classmethod @@ -163,7 +163,7 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('transpose_convolution') -class PTTransposeConvolution(OpTransposeConvolution, PTPruner): +class PTTransposeConvolutionPruningOp(TransposeConvolutionPruningOp, PTPruner): subtypes = [ConvTranspose2dMetatype, ConvTranspose3dMetatype] @classmethod @@ -210,7 +210,7 @@ def output_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('batch_norm') -class PTBatchNorm(OpBatchNorm, PTPruner): +class PTBatchNormPruningOp(BatchNormPruningOp, PTPruner): subtypes = [BatchNormMetatype] @classmethod @@ -236,7 +236,7 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('group_norm') -class GroupNorm(OpGroupNorm, PTPruner): +class GroupNormPruningOp(GroupNormPruningOp, PTPruner): subtypes = [GroupNormMetatype] @classmethod @@ -262,7 +262,7 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('elementwise') -class PTElementwise(OpElementwise, PTPruner): +class 
PTElementwisePruningOp(ElementwisePruningOp, PTPruner): subtypes = [AddMetatype, SubMetatype, DivMetatype, MulMetatype] @classmethod @@ -292,17 +292,17 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('stop_propagation_ops') -class PTStopMaskForwardOps(OpStopMaskForwardOps, PTPruner): +class PTStopMaskForwardPruningOp(StopMaskForwardPruningOp, PTPruner): subtypes = [MeanMetatype, MaxMetatype, MinMetatype, LinearMetatype, MatMulMetatype] @PT_PRUNING_OPERATOR_METATYPES.register('concat') -class PTConcat(OpConcat, PTPruner): +class PTConcatPruningOp(ConcatPruningOp, PTPruner): subtypes = [CatMetatype] - ConvolutionOp = PTConvolution - StopMaskForwardOp = PTStopMaskForwardOps - InputOp = PTInput + ConvolutionOp = PTConvolutionPruningOp + StopMaskForwardOp = PTStopMaskForwardPruningOp + InputOp = PTInputPruningOp @classmethod def _get_unit_mask(cls, dim, device): diff --git a/nncf/torch/pruning/filter_pruning/algo.py b/nncf/torch/pruning/filter_pruning/algo.py index 95c53a24fa7..54131acfff3 100644 --- a/nncf/torch/pruning/filter_pruning/algo.py +++ b/nncf/torch/pruning/filter_pruning/algo.py @@ -61,7 +61,7 @@ from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.pruning.base_algo import BasePruningAlgoBuilder from nncf.torch.pruning.base_algo import BasePruningAlgoController -from nncf.torch.pruning.export_helpers import PTElementwise +from nncf.torch.pruning.export_helpers import PTElementwisePruningOp from nncf.torch.pruning.export_helpers import PT_PRUNING_OPERATOR_METATYPES from nncf.torch.pruning.filter_pruning.functions import FILTER_IMPORTANCE_FUNCTIONS from nncf.torch.pruning.filter_pruning.functions import calculate_binary_mask @@ -109,7 +109,7 @@ def get_op_types_of_pruned_modules(self) -> List[str]: return types def get_types_of_grouping_ops(self) -> List[str]: - return PTElementwise.get_all_op_aliases() + return PTElementwisePruningOp.get_all_op_aliases() 
@ADAPTIVE_COMPRESSION_CONTROLLERS.register('pt_filter_pruning') diff --git a/tests/common/pruning/dummy_types.py b/tests/common/pruning/dummy_types.py index 2eff563a64a..faeb26a4584 100644 --- a/tests/common/pruning/dummy_types.py +++ b/tests/common/pruning/dummy_types.py @@ -3,16 +3,16 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry from nncf.common.pruning.export_helpers import ( - OpInput, - OpOutput, - OpIdentityMaskForwardOps, - OpConvolution, - OpTransposeConvolution, - OpBatchNorm, - OpGroupNorm, - OpConcat, - OpElementwise, - OpStopMaskForwardOps, + InputPruningOp, + OutputPruningOp, + IdentityMaskForwardPruningOp, + ConvolutionPruningOp, + TransposeConvolutionPruningOp, + BatchNormPruningOp, + GroupNormPruningOp, + ConcatPruningOp, + ElementwisePruningOp, + StopMaskForwardPruningOp, ) @@ -68,53 +68,53 @@ class DummyStopPropoagtionMetatype(OperatorMetatype): @DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyInputMetatype.name) -class DummyOpInput(OpInput): +class DummyInputPruningOp(InputPruningOp): additional_types = [DummyInputMetatype.name] @DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyOutputMetatype.name) -class DummyOpOutput(OpOutput): +class DummyOutputPruningOp(OutputPruningOp): additional_types = [DummyOutputMetatype.name] @DUMMY_PRUNING_OPERATOR_METATYPES.register(DymmyIdentityMaskForwardMetatype.name) -class DummyOpIdentityMaskForward(OpIdentityMaskForwardOps): +class DummyIdentityMaskForward(IdentityMaskForwardPruningOp): additional_types = [DymmyIdentityMaskForwardMetatype.name] @DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyStopPropoagtionMetatype.name) -class DummyOpStopMaskForward(OpStopMaskForwardOps): +class DummyStopMaskForward(StopMaskForwardPruningOp): additional_types = [DummyStopPropoagtionMetatype.name] @DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyConvMetatype.name) -class DummyOpConv(OpConvolution): +class 
DummyConvPruningOp(ConvolutionPruningOp): additional_types = [DummyConvMetatype.name] @DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyTransposeConvolutionMetatype.name) -class DummyOpTransposeConv(OpTransposeConvolution): +class DummyTransposeConvPruningOp(TransposeConvolutionPruningOp): additional_types = [DummyTransposeConvolutionMetatype.name] @DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyBatchNormMetatype.name) -class DummyOpBatchNorm(OpBatchNorm): +class DummyBatchNormPruningOp(BatchNormPruningOp): additional_types = [DummyBatchNormMetatype.name] @DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyGroupNormMetatype.name) -class DummyOpGroupNorm(OpGroupNorm): +class DummyGroupNormPruningOp(GroupNormPruningOp): additional_types = [DummyGroupNormMetatype.name] @DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyElementwiseMetatype.name) -class DummyOpElementwise(OpElementwise): +class DummyElementwisePruningOp(ElementwisePruningOp): additional_types = [DummyElementwiseMetatype.name] @DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyConcatMetatype.name) -class DummyOpConcat(OpConcat): - ConvolutionOp = DummyOpConv - StopMaskForwardOp = DummyOpStopMaskForward - InputOp = DummyOpInput +class DummyConcatPruningOp(ConcatPruningOp): + ConvolutionOp = DummyConvPruningOp + StopMaskForwardOp = DummyStopMaskForward + InputOp = DummyInputPruningOp additional_types = [DummyConcatMetatype.name] diff --git a/tests/common/pruning/test_export_helpers.py b/tests/common/pruning/test_export_helpers.py index ddb0f6a8380..e767a733096 100644 --- a/tests/common/pruning/test_export_helpers.py +++ b/tests/common/pruning/test_export_helpers.py @@ -11,13 +11,13 @@ from nncf.common.graph.layer_attributes import MultipleInputLayerAttributes from nncf.common.graph.layer_attributes import GroupNormLayerAttributes from nncf.common.graph.layer_attributes import ConvolutionLayerAttributes -from nncf.common.pruning.export_helpers import DefaultMetaOp +from nncf.common.pruning.export_helpers import 
DefaultPruningOp from nncf.common.pruning.mask_propagation import MaskPropagationAlgorithm -@pytest.mark.parametrize('dummy_op_class,accept_pruned_input', [(dummy_types.DummyOpInput, False), - (dummy_types.DummyOpOutput, True), - (dummy_types.DummyOpStopMaskForward, False)]) +@pytest.mark.parametrize('dummy_op_class,accept_pruned_input', [(dummy_types.DummyInputPruningOp, False), + (dummy_types.DummyOutputPruningOp, True), + (dummy_types.DummyStopMaskForward, False)]) def test_stop_propagate_ops(dummy_op_class, accept_pruned_input): node = NNCFNode(0, 'dummy_node') assert dummy_op_class.accept_pruned_input(node) == accept_pruned_input @@ -25,7 +25,7 @@ def test_stop_propagate_ops(dummy_op_class, accept_pruned_input): assert node.data['output_mask'] is None -@pytest.mark.parametrize('dummy_op_class', [dummy_types.DummyOpIdentityMaskForward, dummy_types.DummyOpBatchNorm]) +@pytest.mark.parametrize('dummy_op_class', [dummy_types.DummyIdentityMaskForward, dummy_types.DummyBatchNormPruningOp]) def test_identity_mask_propogation_prune_ops(dummy_op_class): assert dummy_op_class.accept_pruned_input(None) graph = NNCFGraph() @@ -107,7 +107,7 @@ def test_group_norm_pruning_ops(num_channels, num_groups, accept_pruned_input_re group_norm_op = graph.add_nncf_node('identity', dummy_types.DummyGroupNormMetatype.name, dummy_types.DummyGroupNormMetatype, layer_attributes=group_norm_layer_attributes) - assert dummy_types.DummyOpGroupNorm.accept_pruned_input(group_norm_op) == accept_pruned_input_ref + assert dummy_types.DummyGroupNormPruningOp.accept_pruned_input(group_norm_op) == accept_pruned_input_ref graph.add_edge_between_nncf_nodes( from_node_id=conv_op.node_id, to_node_id=group_norm_op.node_id, @@ -129,7 +129,7 @@ class DummyMaskProducerMetatype(dummy_types.DummyDefaultMetatype): @dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyMaskProducerMetatype.name) -class MockOpMaskProducer(DefaultMetaOp): +class MockOpMaskProducer(DefaultPruningOp): additional_types = 
[DummyMaskProducerMetatype.name] @classmethod def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): @@ -172,7 +172,7 @@ def test_conv_pruning_ops(transpose, layer_attributes, ref_accept_pruned_input, input_port_id=0, output_port_id=0, dtype=Dtype.FLOAT) - pruning_op_class = dummy_types.DummyOpTransposeConv if transpose else dummy_types.DummyOpConv + pruning_op_class = dummy_types.DummyTransposeConvPruningOp if transpose else dummy_types.DummyConvPruningOp assert pruning_op_class.accept_pruned_input(conv_op_target) == ref_accept_pruned_input ones_input_mask = np.ones((layer_attributes['in_channels'],)) ones_output_mask = np.ones((layer_attributes['out_channels'],)) @@ -231,7 +231,7 @@ def test_stop_ops_elementwise_source_before_concat(with_elementwise): add_node(from_node_id=elementwise_op.node_id, to_node_id=concat_node.node_id) - assert not dummy_types.DummyOpConcat.check_concat(concat_node, graph) + assert not dummy_types.DummyConcatPruningOp.check_concat(concat_node, graph) MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() concat_node = graph.get_node_by_id(concat_node.node_id) assert concat_node.data['output_mask'] is None @@ -270,7 +270,7 @@ def test_convs_elementwise_source_before_concat(empty_mask_branch): to_node_id=concat_node.node_id) # Check without masks - assert dummy_types.DummyOpConcat.check_concat(concat_node, graph) + assert dummy_types.DummyConcatPruningOp.check_concat(concat_node, graph) # Set masks masked_convs = [conv_op_0, conv_op_1] if not empty_mask_branch: @@ -284,6 +284,6 @@ def test_convs_elementwise_source_before_concat(empty_mask_branch): MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() # Check with masks concat_node = graph.get_node_by_id(concat_node.node_id) - assert dummy_types.DummyOpConcat.check_concat(concat_node, graph) + assert dummy_types.DummyConcatPruningOp.check_concat(concat_node, graph) reference_mask = np.ones((20,)) 
np.testing.assert_equal(concat_node.data['output_mask'], reference_mask) diff --git a/tests/torch/pruning/test_model_pruning_analysis.py b/tests/torch/pruning/test_model_pruning_analysis.py index a6a42f68277..03cd8fcdff1 100644 --- a/tests/torch/pruning/test_model_pruning_analysis.py +++ b/tests/torch/pruning/test_model_pruning_analysis.py @@ -30,8 +30,8 @@ from nncf.torch.dynamic_graph.graph_tracer import ModelInputInfo from nncf.torch.layers import NNCF_PRUNING_MODULES_DICT from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.pruning.export_helpers import PTElementwise -from nncf.torch.pruning.export_helpers import PTIdentityMaskForwardOps +from nncf.torch.pruning.export_helpers import PTElementwisePruningOp +from nncf.torch.pruning.export_helpers import PTIdentityMaskForwardPruningOp from nncf.torch.pruning.export_helpers import PT_PRUNING_OPERATOR_METATYPES from nncf.common.pruning.utils import is_depthwise_conv from nncf.torch.pruning.filter_pruning.algo import FilterPruningBuilder @@ -195,7 +195,7 @@ def test_pruning_node_selector(test_input_info_struct_: GroupPruningModulesTestS prune_first, prune_last, prune_downsample = test_input_info_struct_.prune_params pruning_operations = [v.op_func_name for v in NNCF_PRUNING_MODULES_DICT] - grouping_operations = PTElementwise.get_all_op_aliases() + grouping_operations = PTElementwisePruningOp.get_all_op_aliases() from nncf.common.pruning.pruning_node_selector import PruningNodeSelector pruning_node_selector = PruningNodeSelector(PT_PRUNING_OPERATOR_METATYPES, pruning_operations, @@ -264,7 +264,7 @@ def test_group_special_nodes(test_special_ops_struct: GroupSpecialModulesTestStr special_ops_clusterization = cluster_special_ops(nncf_model.get_original_graph(), algo_builder.get_types_of_grouping_ops(), - PTIdentityMaskForwardOps.get_all_op_aliases()) + PTIdentityMaskForwardPruningOp.get_all_op_aliases()) for ref_cluster in test_special_ops_struct.eltwise_clusters: cluster = 
special_ops_clusterization.get_cluster_containing_element(ref_cluster[0]) From ebb7e217e1f95065d7613444487617b227dc458f Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Tue, 5 Oct 2021 13:04:35 +0300 Subject: [PATCH 12/19] Fix concat axis calculation --- nncf/common/graph/utils.py | 13 ++++++++++--- nncf/torch/graph/graph_builder.py | 6 +++--- tests/common/pruning/test_export_helpers.py | 1 - 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/nncf/common/graph/utils.py b/nncf/common/graph/utils.py index e0501548485..4f1ebbfc731 100644 --- a/nncf/common/graph/utils.py +++ b/nncf/common/graph/utils.py @@ -14,14 +14,21 @@ from typing import List -def get_concat_axis(input_shape: List[int], output_shape: List[int]) -> int: +def get_concat_axis(input_shapes: List[List[int]], output_shapes: List[List[int]]) -> int: + """ + Returns concatenation axis by given input and output shape of concat node. + + :param input_shapes: Input shapes of given concat node. + :param output_shapes: Output shapes of given concat node. + :returns: Concatenation axis of given concat node. 
+ """ axis = None # If it's dummy concat of one tensor - if len(input_shape) == 1: + if len(input_shapes) == 1: axis = -1 else: none_dim = None - for idx, (dim_in, dim_out) in enumerate(zip(input_shape[0], output_shape[0])): + for idx, (dim_in, dim_out) in enumerate(zip(input_shapes[0], output_shapes[0])): if dim_in != dim_out: axis = idx break diff --git a/nncf/torch/graph/graph_builder.py b/nncf/torch/graph/graph_builder.py index 3a7c711fc10..d5f64d12a26 100644 --- a/nncf/torch/graph/graph_builder.py +++ b/nncf/torch/graph/graph_builder.py @@ -101,9 +101,9 @@ def convert(dynamic_graph: DynamicGraph, input_infos: List[ModelInputInfo] = Non output_edges = nncf_graph.get_output_edges(node) # In case is intermediate node if input_edges and output_edges: - input_shape = input_edges[0].tensor_shape - output_shape = output_edges[0].tensor_shape - axis = get_concat_axis(input_shape, output_shape) + input_shapes = [edge.tensor_shape for edge in input_edges] + output_shapes = [edge.tensor_shape for edge in output_edges] + axis = get_concat_axis(input_shapes, output_shapes) layer_attributes = MultipleInputLayerAttributes(axis) node.layer_attributes = layer_attributes return nncf_graph diff --git a/tests/common/pruning/test_export_helpers.py b/tests/common/pruning/test_export_helpers.py index e767a733096..08ea56ae119 100644 --- a/tests/common/pruning/test_export_helpers.py +++ b/tests/common/pruning/test_export_helpers.py @@ -193,7 +193,6 @@ def test_conv_pruning_ops(transpose, layer_attributes, ref_accept_pruned_input, assert np.all(conv_op_target.data['output_mask'] == input_mask) - @pytest.mark.parametrize('with_elementwise', [False, True]) def test_stop_ops_elementwise_source_before_concat(with_elementwise): graph = NNCFGraph() From e022a1f496aaa8bfcca7c4173970ad99cf642047 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Tue, 5 Oct 2021 13:13:37 +0300 Subject: [PATCH 13/19] Fix pylint --- nncf/common/graph/utils.py | 2 +- nncf/common/pruning/default_pruning_op.py | 2 
-- nncf/torch/pruning/export_helpers.py | 4 ++-- tests/common/pruning/test_export_helpers.py | 7 ++++++- tests/tensorflow/test_model_converter.py | 13 +++++++------ 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/nncf/common/graph/utils.py b/nncf/common/graph/utils.py index 4f1ebbfc731..5f94244d7b0 100644 --- a/nncf/common/graph/utils.py +++ b/nncf/common/graph/utils.py @@ -32,7 +32,7 @@ def get_concat_axis(input_shapes: List[List[int]], output_shapes: List[List[int] if dim_in != dim_out: axis = idx break - elif dim_in is None: + if dim_in is None: none_dim = idx if not axis: axis = none_dim diff --git a/nncf/common/pruning/default_pruning_op.py b/nncf/common/pruning/default_pruning_op.py index 1e757b666d8..b0949771c36 100644 --- a/nncf/common/pruning/default_pruning_op.py +++ b/nncf/common/pruning/default_pruning_op.py @@ -51,5 +51,3 @@ def get_all_op_aliases(cls): op_types.extend(subtype.get_all_aliases()) op_types = list(set(op_types)) + cls.additional_types return op_types - - diff --git a/nncf/torch/pruning/export_helpers.py b/nncf/torch/pruning/export_helpers.py index d28644fdc6c..2c19f387ac2 100644 --- a/nncf/torch/pruning/export_helpers.py +++ b/nncf/torch/pruning/export_helpers.py @@ -69,7 +69,7 @@ PT_PRUNING_OPERATOR_METATYPES = PruningOperationsMetatypeRegistry("operator_metatypes") -class PTPruner(object): +class PTPruner: @classmethod def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): """ @@ -236,7 +236,7 @@ def input_prune(cls, model: NNCFNetwork, node: NNCFNode, graph: NNCFGraph): @PT_PRUNING_OPERATOR_METATYPES.register('group_norm') -class GroupNormPruningOp(GroupNormPruningOp, PTPruner): +class PTGroupNormPruningOp(GroupNormPruningOp, PTPruner): subtypes = [GroupNormMetatype] @classmethod diff --git a/tests/common/pruning/test_export_helpers.py b/tests/common/pruning/test_export_helpers.py index 08ea56ae119..30e66ee761b 100644 --- a/tests/common/pruning/test_export_helpers.py +++ 
b/tests/common/pruning/test_export_helpers.py @@ -3,8 +3,8 @@ from functools import partial -import tests.common.pruning.dummy_types as dummy_types +from tests.common.pruning import dummy_types from nncf.common.graph.graph import NNCFNode from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.layer_attributes import Dtype @@ -131,6 +131,11 @@ class DummyMaskProducerMetatype(dummy_types.DummyDefaultMetatype): @dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES.register(DummyMaskProducerMetatype.name) class MockOpMaskProducer(DefaultPruningOp): additional_types = [DummyMaskProducerMetatype.name] + + @classmethod + def accept_pruned_input(cls, node: NNCFNode): + pass + @classmethod def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): pass diff --git a/tests/tensorflow/test_model_converter.py b/tests/tensorflow/test_model_converter.py index 1d7ae5922ff..6402280ff9c 100644 --- a/tests/tensorflow/test_model_converter.py +++ b/tests/tensorflow/test_model_converter.py @@ -92,14 +92,15 @@ def test_get_custom_layers(): def get_model_with_reshapes_and_concats(batch_size=None): - input =layers.Input((64, ), batch_size=batch_size) - x = tf.reshape(input, (32, -1)) + inputs = layers.Input((64, ), batch_size=batch_size) + x = tf.reshape(inputs, (32, -1)) x = layers.Reshape((16, -1))(x) ones = tf.ones_like(x) t1 = layers.concatenate([x, ones]) + # pylint: disable=E1120,E1123 t2 = tf.concat([x, ones], axis=-1) y = tf.concat([t1, t2], axis=-1) - return models.Model(input, y, name='ModelWithReshape') + return models.Model(inputs, y, name='ModelWithReshape') CONCAT_MODELS = [partial(get_concat_test_model, input_shape=[1, 8, 8, 1]), @@ -107,11 +108,11 @@ def get_model_with_reshapes_and_concats(batch_size=None): REF_CONCAT_ATTRS = [{'tf_op_layer_tf_concat_1': {'axis': [-1, 3]}, 'tf_op_layer_tf_concat_2': {'axis': [-1, 3]}}, {'concatenate': {'axis': [-1, 2]}, - 'tf_op_layer_concat': {'axis': [-1, 2]}, - 'tf_op_layer_concat_1': {'axis': [-1, 2]}}] + 'tf_op_layer_concat': 
{'axis': [-1, 2]}, + 'tf_op_layer_concat_1': {'axis': [-1, 2]}}] -@pytest.mark.parametrize('model, ref_attrs', [(m, r) for m, r in zip(CONCAT_MODELS, REF_CONCAT_ATTRS)]) +@pytest.mark.parametrize('model, ref_attrs', list(zip(CONCAT_MODELS, REF_CONCAT_ATTRS))) def test_model_with_reshape_and_concat(model, ref_attrs): model = model() graph = convert_keras_model_to_nncf_graph(model) From 7c57aa967bca2d643e8553f15d710afbfcc04e6a Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Tue, 5 Oct 2021 16:58:52 +0300 Subject: [PATCH 14/19] Process stack operation --- nncf/common/graph/utils.py | 28 ++++++++++++++-------------- nncf/torch/graph/graph_builder.py | 5 ++++- tests/common/graph/test_utils.py | 21 +++++++++------------ 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/nncf/common/graph/utils.py b/nncf/common/graph/utils.py index 5f94244d7b0..23a8cee862a 100644 --- a/nncf/common/graph/utils.py +++ b/nncf/common/graph/utils.py @@ -13,6 +13,8 @@ from typing import List +from nncf.common.utils.logger import logger + def get_concat_axis(input_shapes: List[List[int]], output_shapes: List[List[int]]) -> int: """ @@ -23,21 +25,19 @@ def get_concat_axis(input_shapes: List[List[int]], output_shapes: List[List[int] :returns: Concatenation axis of given concat node. 
""" axis = None - # If it's dummy concat of one tensor - if len(input_shapes) == 1: - axis = -1 - else: - none_dim = None - for idx, (dim_in, dim_out) in enumerate(zip(input_shapes[0], output_shapes[0])): - if dim_in != dim_out: - axis = idx - break - if dim_in is None: - none_dim = idx - if not axis: - axis = none_dim + none_dim = None + for idx, (dim_in, dim_out) in enumerate(zip(input_shapes[0], output_shapes[0])): + if dim_in != dim_out: + axis = idx + break + if dim_in is None: + none_dim = idx if axis is None: - raise RuntimeError('Unexpected behaviour for concat op') + if none_dim is None: + axis = -1 + logger.warning('Identity concat node detected') + else: + axis = none_dim return axis diff --git a/nncf/torch/graph/graph_builder.py b/nncf/torch/graph/graph_builder.py index d5f64d12a26..b22310beee9 100644 --- a/nncf/torch/graph/graph_builder.py +++ b/nncf/torch/graph/graph_builder.py @@ -99,10 +99,13 @@ def convert(dynamic_graph: DynamicGraph, input_infos: List[ModelInputInfo] = Non if node.metatype is CatMetatype: input_edges = nncf_graph.get_input_edges(node) output_edges = nncf_graph.get_output_edges(node) - # In case is intermediate node + # Case of intermediate node if input_edges and output_edges: input_shapes = [edge.tensor_shape for edge in input_edges] output_shapes = [edge.tensor_shape for edge in output_edges] + # Case node is stack + if len(input_shapes[0]) != len(output_shapes[0]): + continue axis = get_concat_axis(input_shapes, output_shapes) layer_attributes = MultipleInputLayerAttributes(axis) node.layer_attributes = layer_attributes diff --git a/tests/common/graph/test_utils.py b/tests/common/graph/test_utils.py index da47ab7bce2..61183ab8e35 100644 --- a/tests/common/graph/test_utils.py +++ b/tests/common/graph/test_utils.py @@ -4,18 +4,15 @@ TEST_CASES = [ - ([(None, 1, 1, 5)], [(None, 1, 1, 5)], False, [3, -1]), - ([(None, 1, 1, 5), (None, 1, 1, 5)], [(None, 1, 1, 10)], False, [3, -1]), - ([(1, 1, None), (1, 1, None)], [(1, 1, None)], 
False, [2, -1]), - ([(1, 1, 32, 1), (1, 1, 32, 1)], [(1, 1, 64, 1)], False, [2, -1]), - ([(1, 1, 5), (1, 1, 5)], [(1, 1, 5)], True, None), + ([(None, 1, 1, 5)], [(None, 1, 1, 7)], [3, -1]), + ([(None, 1, 1, 5), (None, 1, 1, 5)], [(None, 1, 1, 10)], [3, -1]), + ([(1, 1, None), (1, 1, None)], [(1, 1, None)], [2, -1]), + ([(1, 1, 32, 1), (1, 1, 32, 1)], [(1, 1, 64, 1)], [2, -1]), + ([(1, 1, 5), (1, 1, 5)], [(1, 1, 5)], [-1]), ] -@pytest.mark.parametrize('input_shape,output_shape,raise_error,possible_axes', TEST_CASES) -def test_get_concat_axis(input_shape, output_shape, raise_error, possible_axes): - if not raise_error: - assert get_concat_axis(input_shape, output_shape) in possible_axes - else: - with pytest.raises(RuntimeError): - _ = get_concat_axis(input_shape, output_shape) +@pytest.mark.parametrize('input_shape,output_shape,possible_axes', TEST_CASES) +def test_get_concat_axis(input_shape, output_shape, possible_axes): + axis = get_concat_axis(input_shape, output_shape) + assert axis in possible_axes From 7d618ce97fa00e194839adb2823250a787b5fbab Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Tue, 5 Oct 2021 17:05:14 +0300 Subject: [PATCH 15/19] Add test case with concat zero axis --- tests/common/graph/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/common/graph/test_utils.py b/tests/common/graph/test_utils.py index 61183ab8e35..0f1b401cd76 100644 --- a/tests/common/graph/test_utils.py +++ b/tests/common/graph/test_utils.py @@ -4,6 +4,7 @@ TEST_CASES = [ + ([(1, 1), (1, 1)], [(2, 1)], [0]), ([(None, 1, 1, 5)], [(None, 1, 1, 7)], [3, -1]), ([(None, 1, 1, 5), (None, 1, 1, 5)], [(None, 1, 1, 10)], [3, -1]), ([(1, 1, None), (1, 1, None)], [(1, 1, None)], [2, -1]), From 5223436de2201c6ebf75ba7713300bd8ac240a79 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Wed, 6 Oct 2021 16:48:24 +0300 Subject: [PATCH 16/19] Add concat test case with different branch channel dim --- tests/common/pruning/test_export_helpers.py | 29 
++++++++++++--------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/common/pruning/test_export_helpers.py b/tests/common/pruning/test_export_helpers.py index 30e66ee761b..1b437cce286 100644 --- a/tests/common/pruning/test_export_helpers.py +++ b/tests/common/pruning/test_export_helpers.py @@ -181,6 +181,7 @@ def test_conv_pruning_ops(transpose, layer_attributes, ref_accept_pruned_input, assert pruning_op_class.accept_pruned_input(conv_op_target) == ref_accept_pruned_input ones_input_mask = np.ones((layer_attributes['in_channels'],)) ones_output_mask = np.ones((layer_attributes['out_channels'],)) + # Check all combinations of masks for input_mask in [None, ones_input_mask]: for output_mask in [None, ones_output_mask]: dummy_op_before = graph.get_node_by_id(dummy_op_before.node_id) @@ -242,7 +243,8 @@ def test_stop_ops_elementwise_source_before_concat(with_elementwise): @pytest.mark.parametrize('empty_mask_branch', [False, True]) -def test_convs_elementwise_source_before_concat(empty_mask_branch): +@pytest.mark.parametrize('right_branch_output_channels', [5, 10]) +def test_convs_elementwise_source_before_concat(empty_mask_branch, right_branch_output_channels): graph = NNCFGraph() conv_op_0 = graph.add_nncf_node('conv_op_0', 'conv', dummy_types.DummyConvMetatype) conv_op_1 = graph.add_nncf_node('conv_op_1', 'conv', dummy_types.DummyConvMetatype) @@ -252,42 +254,45 @@ def test_convs_elementwise_source_before_concat(empty_mask_branch): concat_node = graph.add_nncf_node('concat_node', 'concat', dummy_types.DummyConcatMetatype, layer_attributes=concat_layer_attributes) add_node = partial(graph.add_edge_between_nncf_nodes, - tensor_shape=[10] * 4, input_port_id=0, output_port_id=0, dtype=Dtype.FLOAT) # conv_op_0 -> elementwise_node add_node(from_node_id=conv_op_0.node_id, - to_node_id=elementwise_node.node_id) + to_node_id=elementwise_node.node_id, + tensor_shape=[10] * 4) # conv_op_1 -> elementwise_node add_node(from_node_id=conv_op_1.node_id, - 
to_node_id=elementwise_node.node_id) + to_node_id=elementwise_node.node_id, + tensor_shape=[10] * 4) # elementwise_node -> concat_node add_node(from_node_id=elementwise_node.node_id, - to_node_id=concat_node.node_id) + to_node_id=concat_node.node_id, + tensor_shape=[10] * 4) # conv_op_2 -> concat_node add_node(from_node_id=conv_op_2.node_id, - to_node_id=concat_node.node_id) + to_node_id=concat_node.node_id, + tensor_shape=[10, 10, right_branch_output_channels, 10]) # Check without masks assert dummy_types.DummyConcatPruningOp.check_concat(concat_node, graph) # Set masks - masked_convs = [conv_op_0, conv_op_1] - if not empty_mask_branch: - masked_convs.append(conv_op_2) - - for conv_op in masked_convs: + for conv_op in [conv_op_0, conv_op_1]: conv_op = graph.get_node_by_id(conv_op.node_id) conv_op.data['output_mask'] = np.ones(10) + if not empty_mask_branch: + conv_op = graph.get_node_by_id(conv_op_2.node_id) + conv_op.data['output_mask'] = np.ones(right_branch_output_channels) + # Propagate masks MaskPropagationAlgorithm(graph, dummy_types.DUMMY_PRUNING_OPERATOR_METATYPES).mask_propagation() # Check with masks concat_node = graph.get_node_by_id(concat_node.node_id) assert dummy_types.DummyConcatPruningOp.check_concat(concat_node, graph) - reference_mask = np.ones((20,)) + reference_mask = np.ones((10 + right_branch_output_channels,)) np.testing.assert_equal(concat_node.data['output_mask'], reference_mask) From 03b14ba6ca01519c22a58c961d65acd6695b11bb Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Thu, 7 Oct 2021 10:41:23 +0300 Subject: [PATCH 17/19] Check all models with concat in torch --- tests/torch/pruning/test_concat.py | 59 ++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 tests/torch/pruning/test_concat.py diff --git a/tests/torch/pruning/test_concat.py b/tests/torch/pruning/test_concat.py new file mode 100644 index 00000000000..b549f21fe9a --- /dev/null +++ b/tests/torch/pruning/test_concat.py @@ -0,0 +1,59 @@ +import 
numpy as np +import pytest + +from functools import partial + +from nncf.common.pruning.schedulers import BaselinePruningScheduler, ExponentialWithBiasPruningScheduler +from tests.torch.pruning.helpers import get_pruning_baseline_config, PruningTestModel, get_pruning_exponential_config +from tests.torch.helpers import create_compressed_model_and_algo_for_test +from tests.torch import test_models +from tests.torch.test_models.synthetic import EmbeddingCatLinearModel +from tests.torch.test_models.googlenet import GoogLeNet +from tests.torch.test_models.sr_small_model import SmallModel + + +MODELS = [ + {'model': EmbeddingCatLinearModel, + 'input_shape': (1, 10)}, + {'model': test_models.densenet_cifar, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.DPN26, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.DPN92, + 'input_shape': (1, 3, 32, 32)}, + {'model': GoogLeNet, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.inception_v3, + 'input_shape': (1, 3, 229, 229)}, + {'model': test_models.PNASNetA, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.PNASNetB, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.ShuffleNetG3, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.ShuffleNetG2, + 'input_shape': (1, 3, 32, 32)}, + {'model': partial(test_models.ShuffleNetV2, net_size=1), + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.squeezenet1_0, + 'input_shape': (1, 3, 32, 32)}, + {'model': test_models.squeezenet1_1, + 'input_shape': (1, 3, 32, 32)}, + {'model': SmallModel, + 'input_shape': ()}, + {'model': test_models.UNet, + 'input_shape': (1, 3, 360, 480)}, +] + +SKIP_LIST = [SmallModel, EmbeddingCatLinearModel] + + +@pytest.mark.parametrize('model,input_shape', [list(elem.values()) for elem in MODELS]) +def test_models_with_concat(model, input_shape): + if model in SKIP_LIST: + pytest.skip() + + config = get_pruning_baseline_config(list(input_shape)) + config['compression']['algorithm'] = 'filter_pruning' + model 
= model() + _, compression_ctrl = create_compressed_model_and_algo_for_test(model, config) From a4c3531d6f45a90c492bc946d3bf89c452a21354 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Thu, 7 Oct 2021 11:28:11 +0300 Subject: [PATCH 18/19] Check all models with concat in tf / make changes in export_helpers.py to make test clearer --- nncf/common/pruning/export_helpers.py | 23 ++++++++++++++++++++--- tests/tensorflow/pruning/test_concat.py | 25 +++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 tests/tensorflow/pruning/test_concat.py diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py index 3a23d428f58..f0005889407 100644 --- a/nncf/common/pruning/export_helpers.py +++ b/nncf/common/pruning/export_helpers.py @@ -144,14 +144,31 @@ def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: for input_node in graph.get_previous_nodes(node): # If input has mask -> it went from convolution (source of this node is a convolution) + node_has_mask = False if input_node.data.get('output_mask', None) is not None: - continue + node_has_mask = True source_nodes = get_sources_of_node(input_node, graph, cls.ConvolutionOp.get_all_op_aliases() + cls.StopMaskForwardOp.get_all_op_aliases() + cls.InputOp.get_all_op_aliases()) - sources_types = [node.node_type for node in source_nodes] + [input_node.node_type] - if any(t in sources_types for t in cls.StopMaskForwardOp.get_all_op_aliases()): + + source_types_old = [node.node_type for node in source_nodes] + sources_types_new = source_types_old + [input_node.node_type] + + decision_old_on_sources = any(t in source_types_old for t in cls.StopMaskForwardOp.get_all_op_aliases()) + decision_old = decision_old_on_sources and node_has_mask + + decision_new_on_sources = any(t in sources_types_new for t in cls.StopMaskForwardOp.get_all_op_aliases()) + decision_new = decision_new_on_sources and not node_has_mask + + if decision_new != decision_old: + 
is_on_sources_equal = decision_new_on_sources == decision_old_on_sources + if not is_on_sources_equal: + raise ValueError('ALERT') + + print(f'is_on_sources_equal = {is_on_sources_equal}') + print('behaviour changed!!!') + if decision_new: return False return True diff --git a/tests/tensorflow/pruning/test_concat.py b/tests/tensorflow/pruning/test_concat.py new file mode 100644 index 00000000000..291793983dc --- /dev/null +++ b/tests/tensorflow/pruning/test_concat.py @@ -0,0 +1,25 @@ +import pytest + +from tests.tensorflow.helpers import create_compressed_model_and_algo_for_test +from tests.tensorflow.pruning.helpers import get_basic_pruning_config +from tests.tensorflow import test_models + + +MODELS = [ + {'model': test_models.InceptionV3, + 'input_shape': (75, 75, 3)}, + {'model': test_models.InceptionResNetV2, + 'input_shape': (75, 75, 3)}, + {'model': test_models.NASNetMobile, + 'input_shape': (32, 32, 3)}, + {'model': test_models.DenseNet121, + 'input_shape': (32, 32, 3)}, +] + + +@pytest.mark.parametrize('model,input_shape', [list(elem.values()) for elem in MODELS]) +def test_concat(model, input_shape): + config = get_basic_pruning_config(input_shape[1]) + model = model(list(input_shape)) + + model, _ = create_compressed_model_and_algo_for_test(model, config) From 6b345842f7cd910842d841e939629f8199017718 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Thu, 7 Oct 2021 14:26:30 +0300 Subject: [PATCH 19/19] Check old `check_concat` method is always True --- nncf/common/pruning/export_helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nncf/common/pruning/export_helpers.py b/nncf/common/pruning/export_helpers.py index f0005889407..8dcaab9e2b3 100644 --- a/nncf/common/pruning/export_helpers.py +++ b/nncf/common/pruning/export_helpers.py @@ -168,7 +168,7 @@ def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool: print(f'is_on_sources_equal = {is_on_sources_equal}') print('behaviour changed!!!') - if decision_new: + if 
decision_old: return False return True @@ -216,6 +216,7 @@ def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[List[np def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph): result_mask = None + assert cls.check_concat(node, graph) if cls.check_concat(node, graph): result_mask = cls.generate_output_mask(node, graph)