Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
975681a
Refactor export helpers
daniil-lyakhov Sep 30, 2021
6f1fd11
Remove flatten and reshape ops to clean up pr
daniil-lyakhov Oct 1, 2021
5d8150f
Revert PTDefaultMetaOp
daniil-lyakhov Oct 1, 2021
c67e5c2
WIP export helper tests
daniil-lyakhov Oct 1, 2021
3de750b
Add axis attribute for concat layers
daniil-lyakhov Oct 4, 2021
b1a53b8
Add concat tests for common implementation
daniil-lyakhov Oct 4, 2021
e1e1b51
common export_helpers tests
daniil-lyakhov Oct 4, 2021
c2c26fb
Fix diamond inheritance
daniil-lyakhov Oct 5, 2021
676e9d6
Unify export helpers
daniil-lyakhov Oct 5, 2021
4943569
Fix concat axis problems
daniil-lyakhov Oct 5, 2021
075e34d
Fix naming
daniil-lyakhov Oct 5, 2021
b94ef61
Fix concat axis calculation
daniil-lyakhov Oct 5, 2021
ab90d34
Fix pylint
daniil-lyakhov Oct 5, 2021
f428fb4
Process stack operation
daniil-lyakhov Oct 5, 2021
f9fe604
Add test case with concat zero axis
daniil-lyakhov Oct 5, 2021
d4f940b
Add concat test case with different branch channel dim
daniil-lyakhov Oct 6, 2021
770ef1d
Make torch graph test more strict / add stack nodes to the test model
daniil-lyakhov Oct 7, 2021
738ada8
Rename export_helpers.py
daniil-lyakhov Oct 8, 2021
4c4f484
Introduce NNCFTensor to separate pruning operations from framework te…
daniil-lyakhov Oct 8, 2021
de4f792
Remove `check_concat` method
daniil-lyakhov Oct 8, 2021
cd1483a
Minor fix / add comments
daniil-lyakhov Oct 11, 2021
a8e033c
Apply comments
daniil-lyakhov Oct 13, 2021
60dfac8
Fix typo
daniil-lyakhov Oct 13, 2021
5c46689
Make mask_propagation for group_norm consistent even when mask propag…
daniil-lyakhov Oct 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def layer_name(self) -> LayerName:
def layer_attributes(self) -> BaseLayerAttributes:
return self.data.get(NNCFGraph.LAYER_ATTRIBUTES)

@layer_attributes.setter
def layer_attributes(self, data: Any) -> None:
    """
    Sets the layer attributes of this node.

    :param data: Layer attributes object to store under the
        ``NNCFGraph.LAYER_ATTRIBUTES`` key of the node's data dict,
        so that the corresponding property getter returns it.
    """
    self.data[NNCFGraph.LAYER_ATTRIBUTES] = data

@property
def ignored_algorithms(self) -> List[str]:
return self.data.get(NNCFGraph.IGNORED_ALGOS_ATTR, [])
Expand Down
25 changes: 20 additions & 5 deletions nncf/common/graph/layer_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from abc import ABC
from abc import abstractmethod
from enum import Enum
from typing import List
from typing import Tuple
from typing import List, Tuple, Any


class Dtype(Enum):
Expand All @@ -29,15 +28,30 @@ class BaseLayerAttributes(ABC):
"""


class MultipleInputLayerAttributes(BaseLayerAttributes):
    """
    Represents a layer that combines several inputs (e.g. concat or stack),
    characterized by the axis along which the inputs are joined.
    """

    def __init__(self, axis: int):
        # Axis along which the layer aggregates its inputs.
        self.axis = axis

    def __eq__(self, other: Any):
        if not isinstance(other, MultipleInputLayerAttributes):
            return False
        return self.axis == other.axis


class WeightedLayerAttributes(BaseLayerAttributes):
"""
Represents a layer with weights.
"""

def __init__(self, weight_requires_grad: bool, dtype: Dtype = Dtype.FLOAT):
self.weight_requires_grad = weight_requires_grad
self.dtype = dtype

def __eq__(self, other):
def __eq__(self, other: Any):
return isinstance(other, WeightedLayerAttributes) \
and self.weight_requires_grad == other.weight_requires_grad

Expand All @@ -59,6 +73,7 @@ class GenericWeightedLayerAttributes(WeightedLayerAttributes):
Represents a weighted layer for which there is no information ahead of time
of the exact meaning of the weight indices.
"""

def __init__(self, weight_requires_grad: bool, weight_shape: List[int],
filter_dimension_idx: int = 0):
super().__init__(weight_requires_grad)
Expand Down Expand Up @@ -112,7 +127,7 @@ def __init__(self,
self.transpose = transpose
self.padding_values = padding_values

def __eq__(self, other):
def __eq__(self, other: Any):
return isinstance(other, ConvolutionLayerAttributes) \
and super().__eq__(other) \
and self.in_channels == other.in_channels \
Expand Down Expand Up @@ -148,7 +163,7 @@ def __init__(self,
self.num_channels = num_channels
self.num_groups = num_groups

def __eq__(self, other):
def __eq__(self, other: Any):
return isinstance(other, GroupNormLayerAttributes) \
and super().__eq__(other) \
and self.num_channels == other.num_channels \
Expand Down
43 changes: 43 additions & 0 deletions nncf/common/graph/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Copyright (c) 2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import List

from nncf.common.utils.logger import logger


def get_concat_axis(input_shapes: List[List[int]], output_shapes: List[List[int]]) -> int:
    """
    Returns concatenation axis by given input and output shape of concat node.

    :param input_shapes: Input shapes of given concat node.
    :param output_shapes: Output shapes of given concat node.
    :returns: Concatenation axis of given concat node.
    """
    axis = None
    none_dim = None
    # The concat axis is the first dimension in which the (first) input shape
    # differs from the (first) output shape.
    for idx, (dim_in, dim_out) in enumerate(zip(input_shapes[0], output_shapes[0])):
        if dim_in != dim_out:
            axis = idx
            break
        if dim_in is None:
            # Remember the last dynamic dimension: if no static dimension
            # differs, the dynamic one is the best guess for the concat axis.
            none_dim = idx

    if axis is None:
        if none_dim is None:
            # All dimensions match and are static: this is a single-input
            # (identity) concat; fall back to the last axis.
            axis = -1
            logger.warning('Identity concat node detected')
        else:
            axis = none_dim

    return axis
52 changes: 0 additions & 52 deletions nncf/common/pruning/export_helpers.py

This file was deleted.

33 changes: 2 additions & 31 deletions nncf/common/pruning/mask_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,10 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List
from typing import Union
from typing import TypeVar

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.pruning.export_helpers import DefaultMetaOp
from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry

TensorType = TypeVar('TensorType')
from nncf.common.pruning.operations import BasePruningOp


class MaskPropagationAlgorithm:
Expand All @@ -40,7 +34,7 @@ def __init__(self, graph: NNCFGraph, pruning_operator_metatypes: PruningOperatio
self._graph = graph
self._pruning_operator_metatypes = pruning_operator_metatypes

def get_meta_operation_by_type_name(self, type_name: str) -> DefaultMetaOp:
def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
"""
Returns class of metaop that corresponds to `type_name` type.

Expand All @@ -60,26 +54,3 @@ def mask_propagation(self):
for node in self._graph.topological_sort():
cls = self.get_meta_operation_by_type_name(node.node_type)
cls.mask_propagation(node, self._graph)


def get_input_masks(node: NNCFNode, graph: NNCFGraph) -> List[Union[TensorType, None]]:
    """
    Returns input masks for all inputs of nx_node.

    :return: Input masks.
    """
    predecessors = graph.get_previous_nodes(node)
    return [predecessor.data['output_mask'] for predecessor in predecessors]


def identity_mask_propagation(node: NNCFNode, graph: NNCFGraph):
    """
    Propagates input mask through nx_node.
    """
    # A node disconnected from the NNCFGraph yields no input masks;
    # treat it as having a single empty (None) mask.
    input_masks = get_input_masks(node, graph) or [None]
    assert len(input_masks) == 1
    node.data['input_masks'] = input_masks
    node.data['output_mask'] = input_masks[0]
4 changes: 2 additions & 2 deletions nncf/common/pruning/model_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from nncf.common.graph import NNCFNode
from nncf.common.pruning.clusterization import Cluster
from nncf.common.pruning.clusterization import Clusterization
from nncf.common.pruning.export_helpers import DefaultMetaOp
from nncf.common.pruning.operations import BasePruningOp
from nncf.common.pruning.utils import find_next_nodes_not_of_types
from nncf.common.pruning.utils import PruningOperationsMetatypeRegistry

Expand Down Expand Up @@ -143,7 +143,7 @@ def node_accept_different_inputs(self, nncf_node: NNCFNode) -> bool:
"""
return nncf_node.node_type in self._concat_op_metatype.get_all_op_aliases()

def get_meta_operation_by_type_name(self, type_name: str) -> DefaultMetaOp:
def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
"""
Returns class of metaop that corresponds to `type_name` type.

Expand Down
Loading