Skip to content

Commit 03e3faa

Browse files
Merge branch 'dl/conv_layer_attrs_update' into dl/quantization/passes_for_splitted_graphs
2 parents 6bc2f86 + 3fb5f9a commit 03e3faa

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

nncf/openvino/graph/layer_attributes.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,7 @@ def get_backend_agnostic_attributes(self):
8080
]
8181

8282

83-
OVConvLayout = List[OVLayoutElem]
84-
85-
86-
def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:
83+
def get_conv_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]:
8784
"""
8885
Calculates weights layout for a target convolution node.
8986
@@ -97,7 +94,7 @@ def get_conv_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:
9794
)
9895

9996

100-
def get_linear_weights_layout_from_node(node: NNCFNode) -> OVConvLayout:
97+
def get_linear_weights_layout_from_node(node: NNCFNode) -> List[OVLayoutElem]:
10198
"""
10299
Calculates weights layout for a target linear node.
103100
@@ -126,7 +123,7 @@ def _get_constant_port_id_from_layer_attributes(layer_attributes: OVLayerAttribu
126123
return port_ids[0]
127124

128125

129-
def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> OVConvLayout:
126+
def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int, ...]) -> List[OVLayoutElem]:
130127
"""
131128
Calculates weights layout for a target convolution node.
132129
@@ -140,7 +137,7 @@ def get_conv_weights_layout(ov_metatype: OVOpMetatype, weights_shape: Tuple[int,
140137
return tuple(weights_layout)
141138

142139

143-
def get_linear_weights_layout(weights_shape: Tuple[int, ...], transpose: bool, port_id: int) -> OVConvLayout:
140+
def get_linear_weights_layout(weights_shape: Tuple[int, ...], transpose: bool, port_id: int) -> List[OVLayoutElem]:
144141
"""
145142
Calculates weights layout for a target linear node.
146143

nncf/openvino/graph/metatypes/openvino_metatypes.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry
2121
from nncf.common.graph.operator_metatypes import UnknownMetatype
2222
from nncf.common.hardware.opset import HWConfigOpName
23+
from nncf.openvino.graph.layout import OVLayoutElem
2324

2425
OV_OPERATOR_METATYPES = OperatorMetatypeRegistry("openvino_operator_metatypes")
2526

@@ -58,7 +59,8 @@ class OVConvolutionMetatype(OVOpMetatype):
5859
name = "ConvOp"
5960
op_names = ["Convolution"]
6061
hw_config_names = [HWConfigOpName.CONVOLUTION]
61-
const_channel_axis = [0] # const layout: [C_OUT, C_IN, Z, Y, X]
62+
const_channel_axis = [0]
63+
const_layout = [OVLayoutElem.C_OUT, OVLayoutElem.C_IN]
6264
output_channel_axis = 1
6365

6466

@@ -67,7 +69,8 @@ class OVConvolutionBackpropDataMetatype(OVOpMetatype):
6769
name = "ConvBackpropDataOp"
6870
op_names = ["ConvolutionBackpropData"]
6971
hw_config_names = [HWConfigOpName.CONVOLUTION]
70-
const_channel_axis = [1] # const layout: [C_IN, C_OUT, Z, Y, X]
72+
const_channel_axis = [1]
73+
const_layout = [OVLayoutElem.C_IN, OVLayoutElem.C_OUT]
7174
output_channel_axis = 1
7275

7376

@@ -76,7 +79,8 @@ class OVDepthwiseConvolutionMetatype(OVOpMetatype):
7679
name = "DepthwiseConvolutionOp"
7780
op_names = ["GroupConvolution"]
7881
hw_config_names = [HWConfigOpName.DEPTHWISECONVOLUTION]
79-
const_channel_axis = [0, 1] # const layout: [GROUPS, C_OUT / GROUPS, C_IN / GROUPS, Z, Y, X]
82+
const_channel_axis = [0, 1]
83+
const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_OUT, OVLayoutElem.C_IN]
8084
output_channel_axis = 1
8185

8286
@classmethod
@@ -90,7 +94,8 @@ class OVGroupConvolutionMetatype(OVOpMetatype):
9094
op_names = ["GroupConvolution"]
9195
hw_config_names = [HWConfigOpName.CONVOLUTION]
9296
subtypes = [OVDepthwiseConvolutionMetatype]
93-
const_channel_axis = [0, 1] # const layout: [GROUPS, C_OUT / GROUPS, C_IN / GROUPS, Z, Y, X]
97+
const_channel_axis = [0, 1]
98+
const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_OUT, OVLayoutElem.C_IN]
9499
output_channel_axis = 1
95100

96101

@@ -99,7 +104,8 @@ class OVGroupConvolutionBackpropDataMetatype(OVOpMetatype):
99104
name = "GroupConvolutionBackpropDataOp"
100105
op_names = ["GroupConvolutionBackpropData"]
101106
hw_config_names = [HWConfigOpName.CONVOLUTION]
102-
const_channel_axis = [0, 2] # const layout: [GROUPS, C_IN / GROUPS, C_OUT / GROUPS, Z, Y, X]
107+
const_channel_axis = [0, 2]
108+
const_layout = [OVLayoutElem.GROUPS, OVLayoutElem.C_IN, OVLayoutElem.C_OUT]
103109
output_channel_axis = 1
104110

105111

0 commit comments

Comments
 (0)