Skip to content

Commit 999c5e9

Browse files
[PT] FBC for linear (#3808)
### Changes

Add support for the linear operation in FBC (Fast Bias Correction).

### Related tickets

CVS-111111

### Tests

https://github.com/openvinotoolkit/nncf/actions/runs/20462273440
manual/job/post_training_quantization/766/
1 parent ebed81c commit 999c5e9

File tree

7 files changed

+92
-21
lines changed

7 files changed

+92
-21
lines changed

src/nncf/torch/function_hook/extractor.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def extract_conv(
127127
"""
128128
Extracts a convolutional layer from an NNCF graph and constructs an ExtractedFunc module.
129129
130-
:param model: The NNCF network containing the layer.
130+
:param model: The nn.Module containing the layer.
131131
:param graph: The NNCF graph.
132132
:param input_nodes: The name of input node.
133133
:param output_nodes: The name of output node.
@@ -145,13 +145,13 @@ def extract_conv(
145145

146146
weight_node = get_const_node(input_node, 1, graph)
147147
if weight_node is None:
148-
msg = "Weight node not found for {input_node}"
148+
msg = f"Weight node not found for {input_node}"
149149
raise nncf.InternalError(msg)
150150
weight = get_const_data(weight_node, model)
151151

152152
hook_storage = get_hook_storage(model)
153153
with torch.no_grad():
154-
# Calculate weight after execution all hook fro weight data
154+
# Calculate weight after execution all hook for weight data
155155
weight = hook_storage.execute_post_function_hooks(weight_node.node_name, 0, weight)
156156
weight = hook_storage.execute_pre_function_hooks(input_node.node_name, 1, weight)
157157

@@ -189,17 +189,66 @@ def extract_conv(
189189
return nn.Sequential(conv_module, bn_module)
190190

191191

192+
def extract_linear(
193+
model: nn.Module,
194+
graph: PTNNCFGraph,
195+
input_node: NNCFNode,
196+
output_node: NNCFNode,
197+
) -> ExtractedFunc:
198+
"""
199+
Extracts a linear layer from an NNCF graph and constructs an ExtractedFunc module.
200+
201+
:param model: The nn.Module containing the layer.
202+
:param graph: The NNCF graph.
203+
:param input_node: The name of input node.
204+
:param output_node: The name of output node.
205+
:return: The extracted linear layer as an ExtractedFunc module.
206+
"""
207+
if input_node != output_node:
208+
msg = "Only one input and output node supported."
209+
raise nncf.InternalError(msg)
210+
211+
layer_attrs = input_node.layer_attributes
212+
213+
if not isinstance(layer_attrs, PT2OpLayerAttributes):
214+
msg = f"Expected PT2OpLayerAttributes for input_node.layer_attributes, actual: {type(layer_attrs)}"
215+
raise nncf.InternalError(msg)
216+
217+
weight_node = get_const_node(input_node, 1, graph)
218+
if weight_node is None:
219+
msg = f"Weight node not found for {input_node}"
220+
raise nncf.InternalError(msg)
221+
weight = get_const_data(weight_node, model)
222+
223+
hook_storage = get_hook_storage(model)
224+
with torch.no_grad():
225+
# Calculate weight after execution all hook for weight data
226+
weight = hook_storage.execute_post_function_hooks(weight_node.node_name, 0, weight)
227+
weight = hook_storage.execute_pre_function_hooks(input_node.node_name, 1, weight)
228+
229+
bias_node = get_const_node(input_node, 2, graph)
230+
bias = get_const_data(bias_node, model) if bias_node is not None else None
231+
232+
layer_kwarg = {
233+
"weight": weight,
234+
"bias": bias,
235+
}
236+
linear_module = ExtractedFunc(layer_attrs.func, layer_kwarg)
237+
return linear_module
238+
239+
192240
def extract_model(
193241
model: nn.Module, graph: PTNNCFGraph, input_nodes: list[str], output_nodes: list[str]
194242
) -> Optional[nn.Module]:
195243
"""
196-
Extracts a submodule from a given NNCF network containing only the nodes from the input to the output node.
244+
Extracts a submodule from a given nn.Module containing only the nodes from the input to the output node.
197245
198246
Supported subgraph:
199247
- Conv
200248
- Conv + BatchNorm
249+
- Linear
201250
202-
:param model: The NNCF network to extract the submodule from.
251+
:param model: The nn.Module to extract the submodule from.
203252
:param input_nodes: List containing names of the input nodes for the submodule.
204253
:param output_nodes: List containing names of the output nodes for the submodule.
205254
:return: An nn.Module containing the extracted submodel, or None if extraction is not supported.
@@ -214,5 +263,8 @@ def extract_model(
214263
if input_node.metatype in CONV_METATYPES:
215264
return extract_conv(model, graph, input_node, output_node)
216265

217-
nncf_logger.debug(f"Can`t extract module for {input_node.node_name}")
266+
if input_node.metatype is om.PTLinearMetatype:
267+
return extract_linear(model, graph, input_node, output_node)
268+
269+
nncf_logger.debug(f"Can not extract module for {input_node.node_name}")
218270
return None

src/nncf/torch/graph/operator_metatypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,7 @@ class PTLinearMetatype(PTOperatorMetatype):
287287
output_channel_axis = -1
288288
num_expected_input_edges = 2
289289
weight_port_ids = [1]
290+
bias_port_id = 2
290291

291292

292293
@PT_OPERATOR_METATYPES.register()

src/nncf/torch/model_graph_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
om.PTConvTranspose3dMetatype,
3636
]
3737

38-
OPERATORS_WITH_BIAS_METATYPES = CONV_META_TYPES
38+
OPERATORS_WITH_BIAS_METATYPES = CONV_META_TYPES + [om.PTLinearMetatype]
3939
CONV_FUSED_META_TYPES = [om.PTBatchNormMetatype]
4040

4141

tests/cross_fw/test_templates/helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,13 @@ def forward(self, x):
163163

164164

165165
class FCTestModel(nn.Module):
166-
INPUT_SIZE = [1, 1, 4, 4]
166+
INPUT_SIZE = [1, 1, 3, 3]
167167

168168
def __init__(self):
169169
super().__init__()
170-
self.fc = nn.Linear(4, 2)
171-
self.fc.weight.data = torch.Tensor([[0.1, 0.2, 0.3, 0.2], [0.3, -0.1, 0.2, 0.4]])
172-
self.fc.bias.data = torch.Tensor([1.0, 1.1])
170+
self.fc = nn.Linear(3, 2)
171+
self.fc.weight.data = torch.Tensor([[0.1, 0.2, 0.3], [0.3, -0.1, 0.2]])
172+
self.fc.bias.data = torch.Tensor([1.0, 2.0])
173173

174174
def forward(self, x):
175175
x = self.fc(x)

tests/cross_fw/test_templates/test_fast_bias_correction.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
# limitations under the License.
1111

1212
from abc import abstractmethod
13+
from dataclasses import dataclass
14+
from pathlib import Path
1315
from typing import TypeVar
1416

1517
import pytest
@@ -22,12 +24,22 @@
2224
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
2325
from tests.cross_fw.test_templates.helpers import ConvBNTestModel
2426
from tests.cross_fw.test_templates.helpers import ConvTestModel
27+
from tests.cross_fw.test_templates.helpers import FCTestModel
2528
from tests.cross_fw.test_templates.helpers import get_static_dataset
2629

2730
TModel = TypeVar("TModel")
2831
TTensor = TypeVar("TTensor")
2932

3033

34+
@dataclass
35+
class TestCase:
36+
model_cls: type
37+
ref_bias: list
38+
39+
def __str__(self):
40+
return self.model_cls.__name__
41+
42+
3143
class TemplateTestFBCAlgorithm:
3244
@staticmethod
3345
@abstractmethod
@@ -104,18 +116,19 @@ def get_quantization_algorithm():
104116
)
105117

106118
@pytest.mark.parametrize(
107-
"model_cls, ref_bias",
119+
"params",
108120
(
109-
(ConvTestModel, [0.0288348, 1.0838453]),
110-
(ConvBNTestModel, [0.08396978, 1.1676897]),
121+
TestCase(ConvTestModel, [0.0288348, 1.0838453]),
122+
TestCase(ConvBNTestModel, [0.08396978, 1.1676897]),
123+
TestCase(FCTestModel, [0.9999, 1.9989]),
111124
),
125+
ids=str,
112126
)
113-
def test_update_bias(self, model_cls, ref_bias, tmpdir):
114-
model = self.backend_specific_model(model_cls(), tmpdir)
115-
dataset = get_static_dataset(model_cls.INPUT_SIZE, self.get_transform_fn(), self.fn_to_type)
127+
def test_update_bias(self, params: TestCase, tmpdir: Path):
128+
model = self.backend_specific_model(params.model_cls(), tmpdir)
129+
dataset = get_static_dataset(params.model_cls.INPUT_SIZE, self.get_transform_fn(), self.fn_to_type)
116130

117131
quantization_algorithm = self.get_quantization_algorithm()
118132
graph = NNCFGraphFactory.create(model)
119133
quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset)
120-
121-
self.check_bias(quantized_model, ref_bias)
134+
self.check_bias(quantized_model, params.ref_bias)

tests/torch/test_model_graph_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_get_potential_fused_node(self, model_desc):
111111
"ConvBiasBNTestModel": True,
112112
"ConvBNTestModel": True,
113113
"ConvTestModel": True,
114-
"FCTestModel": False,
114+
"FCTestModel": True,
115115
"MultipleConvTestModel": True,
116116
"CustomConvTestModel": True,
117117
"CustomConvBNTestModel": True,
@@ -152,7 +152,7 @@ def test_get_const_node(self, model_desc, port_id):
152152
[[[[0.1000, -2.0000], [1.0000, 0.1000]]], [[[0.1000, 2.0000], [-1.0000, 0.1000]]]],
153153
[0.1000, 1.0000],
154154
),
155-
"FCTestModel": ([[0.1000, 0.2000, 0.3000, 0.2000], [0.3000, -0.1000, 0.2000, 0.4000]], [1.0000, 1.1000]),
155+
"FCTestModel": ([[0.1000, 0.2000, 0.3000], [0.3000, -0.1000, 0.2000]], [1.0000, 2.0000]),
156156
"MultipleConvTestModel": (
157157
[[[[-2.4661, 0.3623], [0.3765, -0.1808]]], [[[0.3930, 0.4327], [-1.3627, 1.3564]]]],
158158
[0.6688, -0.7077],

tests/torch2/function_hook/test_extractor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@
4848
"conv/conv2d/0",
4949
"conv/conv2d/0",
5050
),
51+
(
52+
helpers.FCTestModel,
53+
"fc/linear/0",
54+
"fc/linear/0",
55+
),
5156
)
5257

5358

0 commit comments

Comments
 (0)