Skip to content

Commit 07becc6

Browse files
Docstrings
1 parent ed67ed8 commit 07becc6

File tree

13 files changed

+86
-33
lines changed

13 files changed

+86
-33
lines changed

nncf/onnx/graph/nncf_graph_builder.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,14 @@ def get_bias_tensor_port_id(metatype: ONNXOpWithWeightsMetatype) -> Optional[int
119119
return None
120120

121121

122-
def _get_common_layer_attributes(node, metatype: ONNXOpMetatype):
122+
def _get_common_layer_attributes(node, metatype: ONNXOpMetatype) -> Optional[BaseLayerAttributes]:
123+
"""
124+
Returns layer-specific layer attributes for the given node.
125+
126+
:param node: Target Node to get layer attributes for.
127+
:param metatype: Target node metatype.
128+
:return: Target node layer attributes or None.
129+
"""
123130
if metatype == ONNXConcatMetatype:
124131
axis = [attr.i for attr in node.attribute if attr.name == "axis"][0]
125132
num_inputs = len(node.input)

nncf/openvino/graph/model_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,10 @@ def remove_fq_from_inputs(model: ov.Model, graph: NNCFGraph) -> ov.Model:
5353

5454

5555
def get_input_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
56+
"""
57+
Get all nodes from given nncf_graph that are identified as a input nodes.
58+
59+
:param nncf_graph: NNCFGraph to work with.
60+
:return: Target NNCFGraph input nodes.
61+
"""
5662
return list(set(nncf_graph.get_input_nodes()).upadte(nncf_graph.get_nodes_by_metatypes([OVReadValueMetatype])))

nncf/openvino/graph/nncf_graph_builder.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,14 +189,31 @@ def create_nncf_graph(model: ov.Model) -> NNCFGraph:
189189
GraphConverter._add_edges_to_nncf_graph(model, nncf_graph)
190190
return nncf_graph
191191

192-
def _set_non_weighted_layer_attributes(node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph):
192+
def _set_non_weighted_layer_attributes(node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph) -> None:
193+
"""
194+
Sets layer attributes for a non weighted node.
195+
196+
:param node: Target node.
197+
:param metatype: Target node metatype.
198+
:param nncf_graph: NNCFGraph to work with.
199+
"""
193200
if metatype == OVConcatMetatype:
194201
nncf_node = nncf_graph.get_node_by_name(node.get_friendly_name())
195202
nncf_node.layer_attributes = OVLayerAttributes(
196203
{}, MultipleInputLayerAttributes(axis=node.get_axis(), num_inputs=len(node.inputs()))
197204
)
198205

199-
def _set_weighted_layer_attributes(node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph, visited: Set[str]):
206+
def _set_weighted_layer_attributes(
207+
node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph, visited: Set[str]
208+
) -> None:
209+
"""
210+
Sets layer attributes for a weighted node.
211+
212+
:param node: Target node.
213+
:param metatype: Target node metatype.
214+
:param nncf_graph: NNCFGraph to work with.
215+
:param visited: Set with node names that were already processed by the GraphConverter.
216+
"""
200217
const_attrs, act_attrs = {}, {}
201218
for inp in GraphConverter._filter_weight_input_ports(node.inputs(), metatype):
202219
inp_name = inp.get_source_output().get_node().get_friendly_name()

nncf/quantization/algorithms/accuracy_control/backend.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,12 @@ def get_quantizable_metatypes() -> List[OperatorMetatype]:
5454
@staticmethod
5555
@abstractmethod
5656
def get_graph_inputs(nncf_graph: NNCFGraph) -> List[NNCFNode]:
57-
pass
57+
"""
58+
Returns a list of NNCFNodes that are identified as an inputs.
59+
60+
:param nncf_graph: The NNCF graph.
61+
:return: List of NNCFNodes that are identified as an inputs.
62+
"""
5863

5964
@staticmethod
6065
@abstractmethod

nncf/quantization/algorithms/min_max/backend.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,6 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
6767
Property for the backend-specific Dropout metatypes.
6868
"""
6969

70-
@abstractmethod
71-
def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
72-
pass
73-
7470
@property
7571
@abstractmethod
7672
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
@@ -143,6 +139,16 @@ def create_quantizer_insertion_command(
143139
:return: Backend-specific TransformationCommand for the quantizer insertion operation.
144140
"""
145141

142+
@staticmethod
143+
@abstractmethod
144+
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
145+
"""
146+
Returns a list of NNCFNodes that are identified as an inputs.
147+
148+
:param nncf_graph: NNCFGraph to get input nodes from.
149+
:return: List of NNCFNodes that are identified as an inputs.
150+
"""
151+
146152
@staticmethod
147153
@abstractmethod
148154
def unify_statistics(statistics: List[MinMaxTensorStatistic]) -> MinMaxTensorStatistic:

nncf/quantization/algorithms/min_max/onnx_backend.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
6464
def add_metatypes(self) -> List[OperatorMetatype]:
6565
return [om.ONNXAddLayerMetatype]
6666

67-
def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
68-
return nncf_graph.get_input_nodes()
69-
7067
@property
7168
def group_conv_metatypes(self) -> List[OperatorMetatype]:
7269
return self.conv_metatypes
@@ -95,6 +92,10 @@ def hw_config(self) -> HWConfig:
9592
def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
9693
return DEFAULT_ONNX_QUANT_TRAIT_TO_OP_DICT
9794

95+
@staticmethod
96+
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
97+
return nncf_graph.get_input_nodes()
98+
9899
@staticmethod
99100
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> ONNXTargetPoint:
100101
return ONNXTargetPoint(target_type, target_node_name, port_id)

nncf/quantization/algorithms/min_max/openvino_backend.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,6 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
8787
def read_variable_metatypes(self) -> List[OperatorMetatype]:
8888
return [om.OVReadValueMetatype]
8989

90-
def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
91-
return get_input_nodes(nncf_graph)
92-
9390
@property
9491
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
9592
return {om.OVConcatMetatype: self.overflow_fix_metatypes}
@@ -102,6 +99,10 @@ def hw_config(self) -> HWConfig:
10299
def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
103100
return DEFAULT_OV_QUANT_TRAIT_TO_OP_DICT
104101

102+
@staticmethod
103+
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
104+
return get_input_nodes(nncf_graph)
105+
105106
@staticmethod
106107
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
107108
return OVTargetPoint(target_type, target_node_name, port_id)

nncf/quantization/algorithms/min_max/torch_backend.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,6 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
7474
def read_variable_metatypes(self) -> List[OperatorMetatype]:
7575
return []
7676

77-
def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
78-
return get_inputs_for_graph_with_several_connected_components(nncf_graph)
79-
8077
@property
8178
def conv_metatypes(self) -> List[OperatorMetatype]:
8279
return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
@@ -113,6 +110,10 @@ def hw_config(self) -> HWConfig:
113110
def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
114111
return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
115112

113+
@staticmethod
114+
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
115+
return get_inputs_for_graph_with_several_connected_components(nncf_graph)
116+
116117
@staticmethod
117118
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint:
118119
if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION:

nncf/quantization/passes.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@ def transform_to_inference_graph(
2929
This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows.
3030
3131
:param nncf_graph: NNCFGraph instance for the transformation.
32+
:param input_nodes: List of input nodes for the given NNCFGraph.
3233
:param shapeof_metatypes: List of backend-specific ShapeOf metatypes.
3334
:param dropout_metatypes: List of backend-specific Dropout metatypes.
34-
:param read_variable_metatypes: List of backend-specific metatypes
35-
that also can be interpreted as inputs (ReadValue).
36-
:param nncf_graph_contains_constants: Whether NNCFGraph contains constant nodes or not.
3735
:return: NNCFGraph in the inference style.
3836
"""
3937
remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, input_nodes)
@@ -53,8 +51,7 @@ def remove_shapeof_subgraphs(
5351
5452
:param nncf_graph: NNCFGraph instance for the transformation.
5553
:param shapeof_metatypes: List of backend-specific ShapeOf metatypes.
56-
:param read_variable_metatypes: List of backend-specific metatypes
57-
that also can be interpreted as inputs (ReadValue).
54+
:param input_nodes: List of input nodes for the given NNCFGraph.
5855
:return: NNCFGraph without ShapeOf subgraphs.
5956
"""
6057
nodes_to_drop = set()
@@ -149,8 +146,7 @@ def filter_constant_nodes(
149146
The traversing starts from the input nodes and nodes with weights.
150147
151148
:param nncf_graph: NNCFGraph instance for the transformation.
152-
:param read_variable_metatypes: List of backend-specific metatypes
153-
that also can be interpreted as inputs (ReadValue).
149+
:param input_nodes: List of input nodes for the given NNCFGraph.
154150
:return: NNCFGraph without Constant nodes.
155151
"""
156152
if not input_nodes:

nncf/torch/graph/graph.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,15 @@ def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope:
7171
return matches[0]
7272

7373

74-
def get_inputs_for_graph_with_several_connected_components(nncf_graph: PTNNCFGraph):
74+
def get_inputs_for_graph_with_several_connected_components(nncf_graph: PTNNCFGraph) -> List[NNCFNode]:
75+
"""
76+
Returns a list of NNCFNodes that are identified as an inputs. Requires MultipleInputLayerAttributes
77+
for nodes with several inputs and right `input_edges_num_expected` parameter setted for
78+
nncf nodes metatypes.
79+
80+
:param nncf_graph: NNCFGraph to get input nodes from.
81+
:return: List of NNCFNodes that are identified as an inputs.
82+
"""
7583
input_nodes = set()
7684
for node in nncf_graph.get_all_nodes():
7785
input_edges_num_expected = None
@@ -84,6 +92,8 @@ def get_inputs_for_graph_with_several_connected_components(nncf_graph: PTNNCFGra
8492
if input_edges_num_expected:
8593
input_edges = nncf_graph.get_input_edges(node)
8694
if len(input_edges) < input_edges_num_expected:
95+
# If node has missed input edges we assume this node is an input node
96+
# that was disconected from an activation input.
8797
input_nodes.add(node)
8898
input_nodes.update(nncf_graph.get_input_nodes())
8999
return list(input_nodes)

0 commit comments

Comments
 (0)