Skip to content

Commit ec182c1

Browse files
Init Torch.fx BiasCorrection
1 parent 3b1e7f0 commit ec182c1

File tree

4 files changed

+198
-23
lines changed

4 files changed

+198
-23
lines changed

nncf/experimental/torch_fx/model_transformer.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@
1717

1818
import torch
1919
import torch.fx
20-
21-
# from torch import Tensor
22-
# from torch import nn
2320
from torch.ao.quantization.fx.utils import create_getattr_from_value
2421
from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
2522
from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
@@ -28,6 +25,7 @@
2825
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
2926
from torch.fx import GraphModule
3027
from torch.fx.passes.infra.pass_manager import PassManager
28+
from torch.fx.passes.split_utils import split_by_tags
3129

3230
from nncf.common.graph.model_transformer import ModelTransformer
3331

@@ -37,6 +35,10 @@
3735
from nncf.common.graph.transformations.commands import TargetType
3836
from nncf.common.graph.transformations.commands import TransformationPriority
3937
from nncf.common.graph.transformations.commands import TransformationType
38+
39+
# from torch import Tensor
40+
# from torch import nn
41+
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
4042
from nncf.torch.graph.transformations.commands import PTTargetPoint
4143

4244
# from nncf.torch.graph.transformations.commands import PTTargetPoint
@@ -80,13 +82,16 @@ class FXModelTransformer(ModelTransformer):
8082
Applies transformations upon Torch FX model.
8183
"""
8284

85+
# TODO: manage priorities of transformations
86+
8387
def __init__(self, model: torch.fx.GraphModule):
    """
    Initializes the transformer for the given torch.fx model.

    :param model: torch.fx.GraphModule instance upon which transformations will be applied.
    """
    super().__init__(model)

    # Ordered (command type, handler) pairs: commands of each type are applied
    # in this order by `transform`.
    self._command_transformation_ordered_pairs = [
        # TODO: Move the module insertion command to a transformation
        (FXApplyTransformationCommand, self._apply_transformation),
        (FXModuleInsertionCommand, self._apply_module_insertion),
        (PTModelExtractionCommand, self._apply_model_extraction),
    ]
9196

9297
def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule:
@@ -107,6 +112,34 @@ def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.G
107112
model.recompile()
108113
return model
109114

115+
@staticmethod
116+
def _apply_model_extraction(
117+
model: torch.fx.GraphModule,
118+
transformations: List[PTModelExtractionCommand],
119+
) -> torch.fx.GraphModule:
120+
transformation = transformations[-1]
121+
assert len(transformation.input_node_names) == 1
122+
assert transformation.input_node_names == transformation.output_node_names
123+
node_name = transformation.input_node_names[0]
124+
125+
tags = ["before", "extracted", "after"]
126+
i = 0
127+
for node in model.graph.nodes:
128+
if node.name == node_name:
129+
node.tag = tags[1]
130+
weights = [node.all_input_nodes[1]]
131+
while weights:
132+
w_node = weights.pop()
133+
assert w_node.tag in tags[0:2]
134+
w_node.tag = tags[1]
135+
weights.extend(w_node.all_input_nodes)
136+
i = 2
137+
continue
138+
node.tag = tags[i]
139+
140+
splitted_gm = split_by_tags(model, tags)
141+
return splitted_gm.extracted
142+
110143
@staticmethod
111144
def _apply_module_insertion(
112145
model: torch.fx.GraphModule,
@@ -141,15 +174,16 @@ def _apply_module_insertion(
141174
return model
142175

143176
@staticmethod
def get_graph_node_by_name(graph, name):
    """
    Looks up a node of the given graph by its name.

    :param graph: torch.fx graph to search.
    :param name: Name of the node to retrieve.
    :return: The matching node.
    :raises RuntimeError: If no node with the given name exists.
    """
    matches = (candidate for candidate in graph.nodes if candidate.name == name)
    found = next(matches, None)
    if found is None:
        raise RuntimeError(f"Node with name {name} is not found")
    return found
148182

149183
@staticmethod
150184
def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint):
151185
target_type = target_point.target_type
152-
target_node = FXModelTransformer._get_grah_node_by_name(graph, target_point.target_node_name)
186+
target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name)
153187
if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
154188
target_node = target_node.all_input_nodes[target_point.input_port_id]
155189
elif target_type == TargetType.OPERATOR_POST_HOOK:

nncf/experimental/torch_fx/transformations.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch.ao.quantization.fx.utils import create_getattr_from_value
1717
from torch.quantization.fake_quantize import FakeQuantize
1818

19+
from nncf.common.graph.graph import NNCFNode
1920
from nncf.common.graph.transformations.commands import TargetType
2021
from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
2122
from nncf.torch.graph.transformations.commands import PTTargetPoint
@@ -46,23 +47,20 @@ def fake_quantize_insertion_transformation(model: torch.fx.GraphModule):
4647
return fake_quantize_insertion_transformation
4748

4849

49-
def _set_module_to_the_graph_module(
50-
model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
51-
) -> str:
52-
"""
53-
Sets given module to the given torch.fx.GraphModule with unique name.
54-
"""
55-
module_to_insert = module_to_insert
56-
module_name_in_model = (
57-
";".join(
58-
"_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points
59-
)
60-
+ "_"
61-
+ str(id(module_to_insert))
62-
)
63-
assert not hasattr(model, module_name_in_model)
64-
setattr(model, module_name_in_model, module_to_insert)
65-
return module_name_in_model
50+
def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor):
    """
    Builds a transformation that replaces the bias of the given node with ``value``.

    :param node: NNCF node whose bias should be updated.
    :param value: New bias value to install into the graph.
    :return: Callable that applies the update to a torch.fx.GraphModule in place.
    """

    def bias_update_transformation(model: torch.fx.GraphModule):
        graph = model.graph
        target_node_name = node.node_name
        graph_node = FXModelTransformer.get_graph_node_by_name(graph, target_node_name)
        # NOTE(review): assumes the bias-add node is the sole (or first) user of
        # the target node — confirm against how biases are unfused upstream.
        bias_node = next(iter(graph_node.users))
        with graph.inserting_before(bias_node):
            # Register the new bias value as a get_attr constant placed right
            # before the bias-add node.
            new_constant = create_getattr_from_value(model, graph, target_node_name + "shifted_bias", value)
        # Rewire the second argument (presumably the bias input of the add —
        # verify against caller) to the freshly created constant.
        args = list(bias_node.args)
        args[1] = new_constant
        bias_node.args = tuple(args)
        # Remove the now-unreferenced old bias constant from the graph.
        graph.eliminate_dead_code()

    return bias_update_transformation
6664

6765

6866
def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]):
@@ -150,3 +148,22 @@ def insert_one_qdq(
150148

151149
for user, dq_node in user_dq_nodes:
152150
user.replace_input_with(target_node, dq_node)
151+
152+
153+
def _set_module_to_the_graph_module(
154+
model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
155+
) -> str:
156+
"""
157+
Sets given module to the given torch.fx.GraphModule with unique name.
158+
"""
159+
module_to_insert = module_to_insert
160+
module_name_in_model = (
161+
";".join(
162+
"_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points
163+
)
164+
+ "_"
165+
+ str(id(module_to_insert))
166+
)
167+
assert not hasattr(model, module_name_in_model)
168+
setattr(model, module_name_in_model, module_to_insert)
169+
return module_name_in_model

nncf/quantization/algorithms/fast_bias_correction/algorithm.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def __init__(
9393

9494
@property
def available_backends(self) -> List[BackendType]:
    """Backends supported by the FastBiasCorrection algorithm."""
    # Dead commented-out predecessor of this list removed.
    return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX]
9798

9899
def _set_backend_entity(self, model: TModel) -> None:
99100
"""
@@ -116,6 +117,12 @@ def _set_backend_entity(self, model: TModel) -> None:
116117
from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend
117118

118119
self._backend_entity = PTFastBiasCorrectionAlgoBackend()
120+
elif model_backend == BackendType.TORCH_FX:
121+
from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import (
122+
FXFastBiasCorrectionAlgoBackend,
123+
)
124+
125+
self._backend_entity = FXFastBiasCorrectionAlgoBackend()
119126
else:
120127
raise nncf.UnsupportedBackendError(
121128
"Cannot return backend-specific entity because {} is not supported!".format(model_backend.value)
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright (c) 2024 Intel Corporation
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Dict, List, Optional, Tuple
13+
14+
import numpy as np
15+
import torch
16+
import torch.fx
17+
from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node
18+
19+
import nncf.torch.graph.operator_metatypes as om
20+
from nncf.common.graph import NNCFGraph
21+
from nncf.common.graph import NNCFNode
22+
from nncf.common.graph.definitions import NNCFGraphNodeType
23+
from nncf.common.graph.transformations.commands import TargetType
24+
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
25+
from nncf.experimental.tensor import Tensor
26+
from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
27+
from nncf.experimental.torch_fx.transformations import bias_update_transformation_builder
28+
from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend
29+
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
30+
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
31+
from nncf.torch.graph.transformations.commands import PTTargetPoint
32+
from nncf.torch.nncf_network import NNCFNetwork
33+
from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector
34+
35+
36+
class FXFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend):
    """
    Torch FX backend entity for the FastBiasCorrection algorithm.
    """

    # Generic layer-relative target types have no direct FX counterpart;
    # map them to the hook-based target types used by PTTargetPoint.
    TARGET_TYPE_TO_PT_INS_TYPE_MAP = {
        TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK,
        TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK,
    }

    @staticmethod
    def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint:
        """
        Builds a PTTargetPoint for the given target type, node name and port.

        Model-input and post-layer targets are not tied to a particular port,
        so the port id is dropped for them.
        """
        if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION:
            port_id = None
        if target_type in FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP:
            target_type = FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type]
        return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)

    @staticmethod
    def create_bias_correction_command(
        node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
    ) -> PTBiasCorrectionCommand:
        """Returns a command that replaces the bias of ``node`` with ``bias_value``."""
        return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data))

    @staticmethod
    def model_extraction_command(
        input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]]
    ) -> PTModelExtractionCommand:
        """Returns a command extracting the subgraph for the first input/output node names."""
        return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]])

    @staticmethod
    def mean_statistic_collector(
        channel_axis: int,
        inplace: bool,
        num_samples: Optional[int] = None,
        window_size: Optional[int] = None,
    ) -> TensorCollector:
        """Returns a mean statistic collector; ``inplace`` is ignored by the torch backend."""
        return get_mean_statistic_collector(num_samples, channel_axis, window_size)

    @staticmethod
    def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]:
        """Returns (None, None): extracted torch subgraphs have no named input/output."""
        # Pytorch does not have name for extracted node
        return None, None

    @staticmethod
    def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int) -> torch.Tensor:
        """
        Builds a blob of ``shape`` filled with per-channel statistics from ``data``
        broadcast along ``channel_axis``. ``input_name`` is unused: torch subgraphs
        are called positionally.
        """
        blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device)
        for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
            index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
            blob[index] = data[j].data
        return blob

    @staticmethod
    def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor:
        """
        Reads the bias constant of the bias-add node following ``node``.

        Assumes biases were unfused, i.e. the next node holds the bias as its
        second input (a constant get_attr node).
        """
        # TODO: make a node_name_vs_node map to speed up the process
        from nncf.experimental.torch_fx.model_transformer import FXModelTransformer

        bias_node = nncf_graph.get_next_nodes(node)[0]
        graph_bias_node = FXModelTransformer.get_graph_node_by_name(model.graph, bias_node.node_name)
        return Tensor(_get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model))

    @staticmethod
    def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]:
        """Activation is always on port 0 for both input and output of the bias node."""
        return 0, 0

    @staticmethod
    def process_model_output(raw_data: Dict, output_name: str) -> Tensor:
        """Wraps the raw model output; ``output_name`` is unused for torch FX."""
        return Tensor(raw_data)

    @staticmethod
    def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
        """Returns True if the node's weight input comes from a per-channel dequantize op."""
        weight_node = nncf_graph.get_previous_nodes(node)[1]
        return weight_node.node_type == "dequantize_per_channel"

    @staticmethod
    def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
        """
        Returns True if ``node`` is a conv/linear whose single consumer is an add
        (i.e. an unfused bias).
        """
        # Assumes that all biases were unfused
        if node.metatype in (om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype):
            next_nodes = nncf_graph.get_next_nodes(node)
            if len(next_nodes) != 1:
                return False
            return next_nodes[0].metatype in (om.PTAddMetatype,)
        # Fix: explicitly return False (the original fell through and returned None).
        return False

    @staticmethod
    def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]:
        """Input and output statistics are both collected at the node itself."""
        return node.node_name, node.node_name

0 commit comments

Comments
 (0)