Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def layer_name(self) -> LayerName:
def layer_attributes(self) -> BaseLayerAttributes:
return self.data.get(NNCFGraph.LAYER_ATTRIBUTES)

@layer_attributes.setter
def layer_attributes(self, data) -> None:
self.data[NNCFGraph.LAYER_ATTRIBUTES] = data

@property
def ignored_algorithms(self) -> List[str]:
return self.data.get(NNCFGraph.IGNORED_ALGOS_ATTR, [])
Expand Down
13 changes: 13 additions & 0 deletions nncf/common/graph/layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ class BaseLayerAttributes(ABC):
"""


class MultipleInputLayerAttributes(BaseLayerAttributes):
"""
Represents a layer with multiple inputs.
"""
def __init__(self,
axis: int):
self.axis = axis

def __eq__(self, other):
return isinstance(other, MultipleInputLayerAttributes) \
and self.axis == other.axis


class WeightedLayerAttributes(BaseLayerAttributes):
"""
Represents a layer with weights.
Expand Down
43 changes: 43 additions & 0 deletions nncf/common/graph/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Copyright (c) 2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import List

from nncf.common.utils.logger import logger


def get_concat_axis(input_shapes: List[List[int]], output_shapes: List[List[int]]) -> int:
"""
Returns concatenation axis by given input and output shape of concat node.

:param input_shapes: Input_shapes of given concat node.
:param output_shapes: Input_shapes of given concat node.
:returns: Concatenation axis of given concat node.
"""
axis = None
none_dim = None
for idx, (dim_in, dim_out) in enumerate(zip(input_shapes[0], output_shapes[0])):
if dim_in != dim_out:
axis = idx
break
if dim_in is None:
none_dim = idx

if axis is None:
if none_dim is None:
axis = -1
logger.warning('Identity concat node detected')
else:
axis = none_dim

return axis
53 changes: 53 additions & 0 deletions nncf/common/pruning/default_pruning_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
Copyright (c) 2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.graph import NNCFGraph


class DefaultPruningOp:
"""
Determines meta operations which aggregate operations having common
properties of interaction with pruning masks
"""

subtypes = []
additional_types = []

@classmethod
def accept_pruned_input(cls, node: NNCFNode):
"""
:return: accept_pruned_input - can this operation work with pruned input or not
"""
raise NotImplementedError

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
"""
Propagates the pruning mask through a node using pruning masks of all inputs and the current node (if any).

:param node: The graph node to propagate mask through it
:param graph: The model graph to prune
"""
raise NotImplementedError

@classmethod
def get_all_op_aliases(cls):
"""
:return: list of all aliases of types in metatype
"""
op_types = []
for subtype in cls.subtypes:
op_types.extend(subtype.get_all_aliases())
op_types = list(set(op_types)) + cls.additional_types
return op_types
245 changes: 223 additions & 22 deletions nncf/common/pruning/export_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,43 +10,244 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from typing import Union, List

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.pruning.utils import is_grouped_conv
from nncf.common.pruning.utils import get_sources_of_node
from nncf.common.pruning.utils import is_depthwise_conv
from nncf.common.graph.layer_attributes import GroupNormLayerAttributes
from nncf.common.pruning.mask_propagation import identity_mask_propagation
from nncf.common.pruning.mask_propagation import get_input_masks
from nncf.common.pruning.default_pruning_op import DefaultPruningOp


class DefaultMetaOp:
"""
Determines meta operations which aggregate operations having common
properties of interaction with pruning masks
"""
class InputPruningOp(DefaultPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode):
return False

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
node.data['output_mask'] = None

subtypes = []
additional_types = []

class OutputPruningOp(DefaultPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode):
"""
:return: accept_pruned_input - can this operation work with pruned input or not
"""
raise NotImplementedError
return True

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
node.data['output_mask'] = None


class IdentityMaskForwardPruningOp(DefaultPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode):
return True

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
identity_mask_propagation(node, graph)


class ConvolutionPruningOp(DefaultPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode):
accept_pruned_input = True
if is_grouped_conv(node):
if not is_depthwise_conv(node):
accept_pruned_input = False
return accept_pruned_input

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
input_masks = get_input_masks(node, graph)
output_mask = node.data.get('output_mask', None)

if is_grouped_conv(node):
output_mask = None
if is_depthwise_conv(node):
output_mask = input_masks[0]

node.data['output_mask'] = output_mask


class TransposeConvolutionPruningOp(DefaultPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode):
accept_pruned_input = True
if is_grouped_conv(node):
if not is_depthwise_conv(node):
accept_pruned_input = False
return accept_pruned_input

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
input_masks = get_input_masks(node, graph)
output_mask = node.data.get('output_mask', None)

# In case of group convs we can't prune by output filters
if is_grouped_conv(node):
output_mask = None
if is_depthwise_conv(node):
output_mask = input_masks[0]

node.data['output_mask'] = output_mask


class BatchNormPruningOp(DefaultPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode):
return True

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
identity_mask_propagation(node, graph)


class GroupNormPruningOp(DefaultPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode):
# For Instance Normalization
return isinstance(node.layer_attributes, GroupNormLayerAttributes) \
and node.layer_attributes.num_groups == node.layer_attributes.num_channels

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
identity_mask_propagation(node, graph)


class ConcatPruningOp(DefaultPruningOp):
ConvolutionOp = None # type: ConvolutionPruningOp
StopMaskForwardOp = None # type: StopMaskForwardPruningOp
InputOp = None # type: InputPruningOp

@classmethod
def accept_pruned_input(cls, node: NNCFNode):
return True

@classmethod
def check_concat(cls, node: NNCFNode, graph: NNCFGraph) -> bool:
"""
Propagates the pruning mask through a node using pruning masks of all inputs and the current node (if any).
Return whether all input sources of node is convolutions or not.

:param node: The graph node to propagate mask through it
:param graph: The model graph to prune
:param node: Node to determine it's sources
:param graph: NNCF graph to work with
:return: True if all input sources of node is convolutions
"""
raise NotImplementedError

for input_node in graph.get_previous_nodes(node):
# If input has mask -> it went from convolution (source of this node is a convolution)
node_has_mask = False
if input_node.data.get('output_mask', None) is not None:
node_has_mask = True

source_nodes = get_sources_of_node(input_node, graph, cls.ConvolutionOp.get_all_op_aliases() +
cls.StopMaskForwardOp.get_all_op_aliases() +
cls.InputOp.get_all_op_aliases())

source_types_old = [node.node_type for node in source_nodes]
sources_types_new = source_types_old + [input_node.node_type]

decision_old_on_sources = any(t in source_types_old for t in cls.StopMaskForwardOp.get_all_op_aliases())
decision_old = decision_old_on_sources and node_has_mask

decision_new_on_sources = any(t in sources_types_new for t in cls.StopMaskForwardOp.get_all_op_aliases())
decision_new = decision_new_on_sources and not node_has_mask

if decision_new != decision_old:
is_on_sources_equal = decision_new_on_sources == decision_old_on_sources
if not is_on_sources_equal:
raise ValueError('ALERT')

print(f'is_on_sources_equal = {is_on_sources_equal}')
print('behaviour changed!!!')
if decision_old:
return False
return True

@classmethod
def _get_unit_mask(cls, dim, device):
return np.ones(dim)

@classmethod
def get_all_op_aliases(cls):
def _get_masks_device(cls, input_masks):
return None

@classmethod
def _concat_masks(cls, filled_input_masks):
return np.concatenate(filled_input_masks, 0)

@classmethod
def generate_output_mask(cls, node: NNCFNode, graph: NNCFGraph) -> Union[List[np.array], None]:
"""
:return: list of all aliases of types in metatype
Generate output mask from input masks with all None replaced by identity masks.
If all input masks is None return None.

:param node: Node to determine it's sources.
:param graph: NNCF graph to work with.
:return: Filled input masks.
"""
op_types = []
for subtype in cls.subtypes:
op_types.extend(subtype.get_all_aliases())
op_types = list(set(op_types)) + cls.additional_types
return op_types
input_edges = graph.get_input_edges(node)
previous_nodes = [edge.from_node for edge in input_edges]
input_masks = [input_node.data['output_mask'] for input_node in previous_nodes]

if all(mask is None for mask in input_masks):
return None

filled_input_masks = []
for i, mask in enumerate(input_masks):
if mask is None:
concat_axis = node.layer_attributes.axis
concat_dim = input_edges[i].tensor_shape[concat_axis]
device = cls._get_masks_device(input_masks)
mask = cls._get_unit_mask(concat_dim, device)
filled_input_masks.append(mask)
result_mask = cls._concat_masks(filled_input_masks)
return result_mask

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
result_mask = None

assert cls.check_concat(node, graph)
if cls.check_concat(node, graph):
result_mask = cls.generate_output_mask(node, graph)

node.data['output_mask'] = result_mask


class ElementwisePruningOp(DefaultPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode):
return True

@classmethod
def _assert_input_masks_close(cls, input_masks):
for input_mask in input_masks[1:]:
np.testing.assert_allclose(input_masks[0], input_mask)

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
input_masks = get_input_masks(node, graph)

node.data['input_masks'] = input_masks
if input_masks[0] is not None:
cls._assert_input_masks_close(input_masks)
node.data['output_mask'] = input_masks[0]


class StopMaskForwardPruningOp(DefaultPruningOp):
@classmethod
def accept_pruned_input(cls, node: NNCFNode):
return False

@classmethod
def mask_propagation(cls, node: NNCFNode, graph: NNCFGraph):
node.data['output_mask'] = None
Loading