diff --git a/nncf/common/factory.py b/nncf/common/factory.py index d3a05add911..0d3ab6bf8fd 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -58,7 +58,7 @@ def create(model: TModel) -> NNCFGraph: from nncf.torch.nncf_network import NNCFNetwork if isinstance(model, GraphModelWrapper): - return model.build_graph() + return model.get_graph() if isinstance(model, NNCFNetwork): return model.nncf.get_graph() msg = f"Unexpected type of model {type(model)} for TORCH backend" diff --git a/nncf/experimental/torch2/function_hook/extractor.py b/nncf/experimental/torch2/function_hook/extractor.py new file mode 100644 index 00000000000..bb85b318ddf --- /dev/null +++ b/nncf/experimental/torch2/function_hook/extractor.py @@ -0,0 +1,220 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import torch +from torch import nn + +import nncf +from nncf import nncf_logger +from nncf.common.graph.graph import NNCFNode +from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes +from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage +from nncf.torch.graph import operator_metatypes as om +from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.model_graph_manager import get_const_data +from nncf.torch.model_graph_manager import get_const_data_on_port +from nncf.torch.model_graph_manager import get_const_node + +CONV_METATYPES = ( + om.PTConv1dMetatype, + om.PTConv2dMetatype, + om.PTConv3dMetatype, + om.PTDepthwiseConv1dSubtype, + om.PTDepthwiseConv2dSubtype, + om.PTDepthwiseConv3dSubtype, +) + + +class ExtractedFunc(nn.Module): + """ + Module that calls a stored function with a fixed set of keyword arguments. + Only functions that take a single tensor as their first positional argument are supported. + + :param fn: Function to execute. + :param kwargs: Keyword arguments passed to fn on every call. + """ + + def __init__(self, fn: Callable[..., torch.Tensor], kwargs: Dict[str, Any]) -> None: + super().__init__() + self.fn = fn + self.kwargs = kwargs + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fn(x, **self.kwargs) + + +def apply_args_to_kwargs( + args: Sequence[Any], kwargs: Dict[str, Any], indexed_args: List[Tuple[int, str]] +) -> Dict[str, Any]: + """ + Builds a kwargs dictionary for the named arguments, reading each from args by index or else from kwargs by name. + + :param args: The positional arguments. + :param kwargs: The keyword arguments. + :param indexed_args: Pairs of (positional index, argument name) to look up. + :return: Dictionary with the arguments that were found; names present in neither args nor kwargs are omitted. 
+ """ + args_dict: Dict[str, Any] = dict() + for idx, arg_name in indexed_args: + if idx < len(args): + args_dict[arg_name] = args[idx] + elif arg_name in kwargs: + args_dict[arg_name] = kwargs[arg_name] + + return args_dict + + +def extract_bn(model: nn.Module, graph: PTNNCFGraph, node: NNCFNode) -> ExtractedFunc: + """ + Extracts a batch_norm operation as a standalone module. + + :param model: Source model. + :param graph: Graph of source model. + :param node: Target batch_norm node. + :return: ExtractedFunc that calls the batch_norm function with the constants of the source model and training=False. + """ + layer_attr = node.layer_attributes + if not isinstance(layer_attr, PT2OpLayerAttributes): + msg = f"Expected PT2OpLayerAttributes for input_node.layer_attributes, actual: {type(layer_attr)}" + raise nncf.InternalError(msg) + + # torch.batch_norm( + # 0 - input: Tensor, + # 1 - weight: Optional[Tensor] + # 2 - bias: Optional[Tensor] + # 3 - running_mean: Optional[Tensor] + # 4 - running_var: Optional[Tensor] + # 5 - training: _bool + # 6 - momentum: _float + # 7 - eps: _float + # 8 - cudnn_enabled: _bool + # ) -> Tensor: ... + + weight = get_const_data_on_port(model, graph, node, 1) + bias = get_const_data_on_port(model, graph, node, 2) + running_mean = get_const_data_on_port(model, graph, node, 3) + running_var = get_const_data_on_port(model, graph, node, 4) + + bn_kwargs = apply_args_to_kwargs( + layer_attr.op_args, + layer_attr.op_kwargs, + [(6, "momentum"), (7, "eps"), (8, "cudnn_enabled")], + ) + bn_kwargs["weight"] = weight + bn_kwargs["bias"] = bias + bn_kwargs["running_mean"] = running_mean + bn_kwargs["running_var"] = running_var + bn_kwargs["training"] = False + + return ExtractedFunc(layer_attr.func, bn_kwargs) + + +def extract_conv( + model: nn.Module, + graph: PTNNCFGraph, + input_node: NNCFNode, + output_node: NNCFNode, +) -> nn.Module: + """ + Extracts a convolutional layer from an NNCF graph and constructs an ExtractedFunc module. + + :param model: The NNCF network containing the layer. 
+ :param graph: The NNCF graph. + :param input_node: The convolution node to extract. + :param output_node: The last node of the extracted subgraph: input_node itself or the batch norm that follows it. + :return: ExtractedFunc for the convolution alone, or nn.Sequential of the convolution and the batch norm. + """ + + # torch.conv*d( + # 0 - input: Tensor + # 1 - weight: Tensor + # 2 - bias: Optional[Tensor] + # 3 - stride: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] + # 4 - padding: Union[Union[_int, SymInt] | str + # 5 - dilation: Union[Union[_int, SymInt], Sequence[Union[_int, SymInt]]] + # 6 - groups: Union[_int, SymInt] + # ) -> Tensor: ... + + weight_node = get_const_node(input_node, 1, graph) + if weight_node is None: + msg = f"Weight node not found for {input_node}" + raise nncf.InternalError(msg) + weight = get_const_data(weight_node, model) + + hook_storage = get_hook_storage(model) + with torch.no_grad(): + # Calculate the weight after executing all hooks for the weight data + weight = hook_storage.execute_post_function_hooks(weight_node.node_name, 0, weight) + weight = hook_storage.execute_pre_function_hooks(input_node.node_name, 1, weight) + + bias_node = get_const_node(input_node, 2, graph) + bias = get_const_data(bias_node, model) if bias_node is not None else None + + layer_attrs = input_node.layer_attributes + + if not isinstance(layer_attrs, PT2OpLayerAttributes): + msg = f"Expected PT2OpLayerAttributes for input_node.layer_attributes, actual: {type(layer_attrs)}" + raise nncf.InternalError(msg) + + conv_kwargs = apply_args_to_kwargs( + layer_attrs.op_args, + layer_attrs.op_kwargs, + [(3, "stride"), (4, "padding"), (5, "dilation"), (6, "groups")], + ) + conv_kwargs["weight"] = weight + conv_kwargs["bias"] = bias + conv_module = ExtractedFunc(layer_attrs.func, conv_kwargs) + + if input_node == output_node: + return conv_module + + if output_node.metatype is not om.PTBatchNormMetatype: + msg = f"Support only PTBatchNormMetatype as output node, actual: {output_node.metatype}" + raise nncf.InternalError(msg) + + next_nodes = 
graph.get_next_nodes(input_node) + if output_node not in next_nodes: + msg = f"Output node {output_node} not found after {input_node}" + raise nncf.InternalError(msg) + + bn_module = extract_bn(model, graph, output_node) + return nn.Sequential(conv_module, bn_module) + + +def extract_model( + model: nn.Module, graph: PTNNCFGraph, input_nodes: List[str], output_nodes: List[str] +) -> Optional[nn.Module]: + """ + Extracts a submodule from a given NNCF network containing only the nodes from the input to the output node. + + Supported subgraph: + - Conv + - Conv + BatchNorm + + :param model: The NNCF network to extract the submodule from. + :param input_nodes: List containing names of the input nodes for the submodule. + :param output_nodes: List containing names of the output nodes for the submodule. + :return: An nn.Module containing the extracted submodel, or None if extraction is not supported. + """ + + if len(input_nodes) != 1 or len(output_nodes) != 1: + msg = "input_nodes and output_nodes should contain only one node." 
+ raise nncf.InternalError(msg) + + input_node = graph.get_node_by_name(input_nodes[0]) + output_node = graph.get_node_by_name(output_nodes[0]) + + if input_node.metatype in CONV_METATYPES: + return extract_conv(model, graph, input_node, output_node) + + nncf_logger.debug(f"Can`t extract module for {input_node.node_name}") + return None diff --git a/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py b/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py index 85993f15ad6..be8efd387be 100644 --- a/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py +++ b/nncf/experimental/torch2/function_hook/graph/build_graph_mode.py @@ -32,6 +32,7 @@ from nncf.experimental.torch2.function_hook.weak_map import WeakUnhashableKeyMap from nncf.experimental.torch2.function_hook.wrapper import ForwardWithHooks from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage +from nncf.torch.utils import training_mode_switcher class GraphBuilderMode(FunctionHookMode): @@ -358,12 +359,12 @@ def build_graph(model: nn.Module, *args: Any, **kwargs: Any) -> nx.MultiDiGraph: :param model: The PyTorch model for which the computational graph will be built. :return: A nx.MultiDiGraph where nodes represent operations of model. 
""" - - with torch.enable_grad(): # type: ignore - # Gradient use to get information about __get__ functions to detect tensor.(T, mT) attributes - with GraphBuilderMode(model=model, hook_storage=get_hook_storage(model)) as ctx: - args, kwargs = ctx.process_model_inputs(args, kwargs) - wrapped_forward = cast(ForwardWithHooks, model.forward) - outputs = wrapped_forward._func(*args, **kwargs) - outputs = ctx.process_model_outputs(outputs) + with training_mode_switcher(model, is_training=False): + with torch.enable_grad(): # type: ignore + # Gradient use to get information about __get__ functions to detect tensor.(T, mT) attributes + with GraphBuilderMode(model=model, hook_storage=get_hook_storage(model)) as ctx: + args, kwargs = ctx.process_model_inputs(args, kwargs) + wrapped_forward = cast(ForwardWithHooks, model.forward) + outputs = wrapped_forward._func(*args, **kwargs) + outputs = ctx.process_model_outputs(outputs) return ctx.graph diff --git a/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py b/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py index b298ad5a7b0..f036891000f 100644 --- a/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py +++ b/nncf/experimental/torch2/function_hook/nncf_graph/nncf_graph_builder.py @@ -21,6 +21,7 @@ import nncf.torch.graph.operator_metatypes as om from nncf.common.graph.graph import NNCFNode from nncf.common.graph.layer_attributes import BaseLayerAttributes +from nncf.common.graph.layer_attributes import ConstantLayerAttributes from nncf.common.graph.layer_attributes import Dtype from nncf.experimental.torch2.function_hook.graph.build_graph_mode import build_graph from nncf.experimental.torch2.function_hook.graph.graph_utils import ConstMeta @@ -157,7 +158,8 @@ def get_layer_attributes( if isinstance(meta, FunctionMeta): constant_port_ids = get_constant_port_ids(nx_graph, node) return PT2OpLayerAttributes(meta.func, meta.args, meta.kwargs, constant_port_ids) - + if 
isinstance(meta, ConstMeta): + return ConstantLayerAttributes(meta.name_in_model, list(meta.shape)) return None @@ -228,17 +230,16 @@ class GraphModelWrapper: """ A class that wraps a PyTorch model with examples inputs and provides an interface to build a computational graph of the model. - - :param model: The PyTorch model to be wrapped. - :param example_input: A tuple of example input for the model. """ def __init__(self, model: nn.Module, example_input: Any) -> None: """ - Initialize the GraphModelWrapper. + :param model: The PyTorch model to be wrapped. + :param example_input: A tuple of example input for the model. """ self.model = model self.example_input = example_input + self.graph: Optional[PTNNCFGraph] = None def build_graph(self) -> PTNNCFGraph: """ @@ -254,3 +255,19 @@ def build_graph(self) -> PTNNCFGraph: if isinstance(self.example_input, tuple): return build_nncf_graph(self.model, *self.example_input) return build_nncf_graph(self.model, self.example_input) + + def get_graph(self) -> PTNNCFGraph: + """ + Returns the computational graph of the model. + + :return: The PTNNCFGraph representing the model. + """ + if self.graph is None: + self.graph = self.build_graph() + return self.graph + + def reset_graph(self) -> None: + """ + Resets the computational graph of the model. + """ + self.graph = None diff --git a/nncf/experimental/torch2/model_transformer.py b/nncf/experimental/torch2/model_transformer.py index 6654ae6550e..ae1ff896200 100644 --- a/nncf/experimental/torch2/model_transformer.py +++ b/nncf/experimental/torch2/model_transformer.py @@ -10,7 +10,7 @@ # limitations under the License. 
from collections import defaultdict -from typing import Dict, List +from typing import Any, Callable, Dict, List, Tuple, Type, cast from torch import nn @@ -23,7 +23,11 @@ from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper from nncf.experimental.torch2.function_hook.wrapper import register_post_function_hook from nncf.experimental.torch2.function_hook.wrapper import register_pre_function_hook +from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.model_graph_manager import update_fused_bias + +TRANSFORMATION_PAIRS = Tuple[Tuple[Type[Any], Callable[[GraphModelWrapper, List[Any]], GraphModelWrapper]], ...] class PT2ModelTransformer(ModelTransformer[GraphModelWrapper]): @@ -34,9 +38,10 @@ class PT2ModelTransformer(ModelTransformer[GraphModelWrapper]): def __init__(self, model: GraphModelWrapper): super().__init__(model) - self._command_transformation_ordered_pairs = [ - (PT2InsertionCommand, self._apply_insertion_transformation), - ] + self._command_transformation_ordered_pairs: TRANSFORMATION_PAIRS = ( + (PT2InsertionCommand, self._apply_insertion_transformations), + (PTBiasCorrectionCommand, self._apply_bias_correction_transformations), + ) def transform(self, transformation_layout: TransformationLayout) -> GraphModelWrapper: """ @@ -58,20 +63,23 @@ def transform(self, transformation_layout: TransformationLayout) -> GraphModelWr raise ValueError(msg) aggregated_transformations[transformation.__class__].append(transformation) - model = self._model.model - + model = self._model for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs: transformations = aggregated_transformations[transformation_cls] if transformations: - model = transformation_fn(model, transformations) # type: ignore[arg-type] - return self._model + model = transformation_fn(model, transformations) - def 
_apply_insertion_transformation( - self, model: nn.Module, transformations: List[PT2InsertionCommand] - ) -> nn.Module: + if aggregated_transformations.get(PT2InsertionCommand, []): + model.reset_graph() + return model + + def _apply_insertion_transformations( + self, wrapped_model: GraphModelWrapper, transformations: List[PT2InsertionCommand] + ) -> GraphModelWrapper: """ Applies insertion transformation to the model. + :param wrapped_model: Model to apply transformations. :param command: Insertion transformation command. """ for command in transformations: @@ -80,10 +88,31 @@ def _apply_insertion_transformation( handle_storage = command.handle_storage for target_point in target_points: - handle = insert_hook(model, hook_module, target_point) + handle = insert_hook(wrapped_model.model, hook_module, target_point) if handle_storage is not None: handle_storage.append(handle) - return model + return wrapped_model + + @staticmethod + def _apply_bias_correction_transformations( + wrapped_model: GraphModelWrapper, transformations: List[PTBiasCorrectionCommand] + ) -> GraphModelWrapper: + """ + Applies bias correction transformations on the model. + + :param wrapped_model: Model to apply transformations. + :param transformations: List of the bias correction transformations. + :return: Model with corrected bias. 
+ """ + for transformation in transformations: + pt_target_point = cast(PTTargetPoint, transformation.target_point) + update_fused_bias( + target_node_name=pt_target_point.target_node_name, + new_bias=transformation.bias_value, + nncf_graph=wrapped_model.get_graph(), + model=wrapped_model.model, + ) + return wrapped_model def insert_hook(model: nn.Module, hook: nn.Module, target_point: PTTargetPoint) -> RemovableHookHandle: diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py index ba14710f7b1..4ee15603465 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -17,8 +17,11 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import TargetType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch2.function_hook.extractor import extract_model +from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend from nncf.tensor import Tensor from nncf.torch.graph.transformations.command_creation import create_bias_correction_command @@ -82,8 +85,10 @@ def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, ch return blob @staticmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: NNCFNetwork) -> 
Tensor: - return Tensor(get_fused_bias_value(node, model)) + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: Union[NNCFNetwork, GraphModelWrapper]) -> Tensor: + if isinstance(model, GraphModelWrapper): + model = model.model + return Tensor(get_fused_bias_value(node, nncf_graph, model)) @staticmethod def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: @@ -109,5 +114,13 @@ def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFG return input_node_name, output_node_name @staticmethod - def get_activation_channel_axis(node: NNCFNode, pord_id: int, input_shape: Tuple[int]) -> int: + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int: return node.metatype.output_channel_axis + + def extract_submodel( + self, model_transformer: ModelTransformer, input_id: List[Tuple[str, int]], output_id: List[Tuple[str, int]] + ): + model = model_transformer._model + if isinstance(model, GraphModelWrapper): + return extract_model(model.model, model.get_graph(), [input_id[0]], [output_id[0]]) + return super().extract_submodel(model_transformer, input_id, output_id) diff --git a/nncf/torch/__init__.py b/nncf/torch/__init__.py index 957175b1aea..4627b8d3f8f 100644 --- a/nncf/torch/__init__.py +++ b/nncf/torch/__init__.py @@ -77,7 +77,4 @@ if torch.__version__ >= "2.5.0": from torch._dynamo.polyfills import loader -from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled - -if not is_experimental_torch_tracing_enabled(): - patch_torch_operators() +patch_torch_operators() diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py index 4c756ac465b..2b5242ae95f 100644 --- a/nncf/torch/dynamic_graph/patch_pytorch.py +++ b/nncf/torch/dynamic_graph/patch_pytorch.py @@ -24,6 +24,7 @@ import nncf from nncf import nncf_logger from nncf.common.utils.api_marker import api +from nncf.experimental.common.check_feature import 
is_experimental_torch_tracing_enabled from nncf.torch.dynamic_graph.patch_pytorch_state import PATCHING_STATE from nncf.torch.dynamic_graph.structs import NamespaceTarget from nncf.torch.dynamic_graph.structs import PatchedOperatorInfo @@ -352,6 +353,9 @@ def remove_private_functions(names: List[str]) -> List[str]: def patch_torch_operators(): + if is_experimental_torch_tracing_enabled(): + return + # Only patch torch.jit.script during first patch_torch_operators call if not PATCHING_STATE.jit_is_wrapped: patch_torch_jit() diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index 825c63c5765..d4ddc2639aa 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -22,6 +22,7 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.operator_metatypes import OperatorMetatypeRegistry from nncf.common.hardware.opset import HWConfigOpName +from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled from nncf.torch.dynamic_graph.graph import DynamicGraph from nncf.torch.dynamic_graph.structs import NamespaceTarget @@ -727,8 +728,15 @@ class PTBatchNormMetatype(PTOperatorMetatype): NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training", "cudnn_batch_norm"], } subtypes = [PTModuleBatchNormMetatype] - weight_port_ids = [3] - bias_port_id = 4 + + if is_experimental_torch_tracing_enabled(): + # torch.batch_norm + weight_port_ids = [1] + bias_port_id = 2 + else: + # torch.nn.functional.batch_norm + weight_port_ids = [3] + bias_port_id = 4 @PT_OPERATOR_METATYPES.register() diff --git a/nncf/torch/model_graph_manager.py b/nncf/torch/model_graph_manager.py index dff5230796f..a60b90246ff 100644 --- a/nncf/torch/model_graph_manager.py +++ b/nncf/torch/model_graph_manager.py @@ -12,6 +12,7 @@ from typing import List, Optional, Tuple, Type, Union import torch +from torch import nn import nncf from nncf.common.graph.graph import 
NNCFGraph @@ -51,8 +52,8 @@ def find_const_node_in_constant_subgraph(node: NNCFNode, graph: NNCFGraph) -> Op :return: The constant node found within the subgraph, or None if no constant node is found. """ if node.metatype == om.PTNoopMetatype or node.node_type in om.QUANTIZE_NODE_TYPES: - prev_nodes = graph.get_previous_nodes(node) - if len(prev_nodes) != 1: + prev_nodes = [e.from_node for e in graph.get_input_edges(node)] + if not prev_nodes: return None return find_const_node_in_constant_subgraph(prev_nodes[0], graph) if node.metatype in CONST_NOOP_METATYPES: @@ -118,12 +119,12 @@ def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Mod return curr_module -def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor: +def get_const_data(const_node: NNCFNode, model: nn.Module) -> torch.Tensor: """ Retrieves a detached constant tensor associated with a given node. :param const_node: The node associated with const data. - :param model: The NNCFNetwork object. + :param model: The nn.Module object. :return: A torch.Tensor object containing the constant value. """ const_name = const_node.layer_attributes.name @@ -135,16 +136,16 @@ def get_const_data(const_node: NNCFNode, model: NNCFNetwork) -> torch.Tensor: return data.detach() -def get_const_data_on_port(node: NNCFNode, port_id: int, model: NNCFNetwork) -> torch.Tensor: +def get_const_data_on_port(model: nn.Module, graph: NNCFGraph, node: NNCFNode, port_id: int) -> torch.Tensor: """ Retrieves a constant tensor associated with a given node and input port in an NNCF graph. + :param model: The nn.Module object. + :param graph: The NNCF graph containing the nodes. :param node: The node to retrieve the constant from. :param port_id: The port id within the node that holds the constant. - :param model: The NNCFNetwork object. :return: A torch.Tensor object containing the constant value, or None if the constant is not found. 
""" - graph = model.nncf.get_graph() const_node = get_const_node(node, port_id, graph) if const_node is None: return None @@ -189,30 +190,30 @@ def is_node_with_fused_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return bias is not None -def get_fused_bias_value(node: NNCFNode, model: NNCFNetwork) -> Optional[torch.Tensor]: +def get_fused_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: nn.Module) -> Optional[torch.Tensor]: """ Returns the bias tensor for the node or for potential fused node. :param node: The node that corresponds to the operation with bias. + :param nncf_graph: The NNCF graph. :param model: The model that contains this operation. :return: The bias value that is applied to the output tensor of the node's operation. """ - nncf_graph = model.nncf.get_graph() fused_node = get_potential_fused_node(node.node_name, nncf_graph) - bias = get_const_data_on_port(node, node.metatype.bias_port_id, model) + bias = get_const_data_on_port(model, nncf_graph, node, node.metatype.bias_port_id) if fused_node is None: return bias - fused_bias = get_const_data_on_port(fused_node, fused_node.metatype.bias_port_id, model) + fused_bias = get_const_data_on_port(model, nncf_graph, fused_node, fused_node.metatype.bias_port_id) if bias is None: return fused_bias - fused_weight = get_const_data_on_port(fused_node, fused_node.metatype.weight_port_ids[0], model) + fused_weight = get_const_data_on_port(model, nncf_graph, fused_node, fused_node.metatype.weight_port_ids[0]) return bias * fused_weight + fused_bias -def update_fused_bias(target_node_name: str, new_bias: torch.Tensor, model: NNCFNetwork) -> None: +def update_fused_bias(target_node_name: str, new_bias: torch.Tensor, nncf_graph: NNCFGraph, model: nn.Module) -> None: """ Update bias for target module or potential fused module. @@ -220,11 +221,10 @@ def update_fused_bias(target_node_name: str, new_bias: torch.Tensor, model: NNCF :param new_bias: New bias value. :param model: The model. 
""" - nncf_graph = model.nncf.get_graph() target_node = nncf_graph.get_node_by_name(target_node_name) fused_node = get_potential_fused_node(target_node_name, nncf_graph) if fused_node is None: - set_const_data_to_port_id(new_bias, target_node, target_node.metatype.bias_port_id, model) + set_const_data_to_port_id(new_bias, target_node, target_node.metatype.bias_port_id, nncf_graph, model) return target_bias_node = get_const_node(target_node, target_node.metatype.bias_port_id, nncf_graph) @@ -256,13 +256,13 @@ def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[int]: return weight_port_ids -def set_const_data(data: torch.Tensor, const_node: NNCFNode, model: NNCFNetwork) -> None: +def set_const_data(data: torch.Tensor, const_node: NNCFNode, model: nn.Module) -> None: """ Sets the constant data associated with a specific constant node in an NNCF network model. :param data: The constant data tensor to be set. :param const_node: The NNCF node representing the constant data. - :param model: The NNCF network model. + :param model: The model. """ const_name = const_node.layer_attributes.name module_name, const_attr_name = split_const_name(const_name) @@ -274,16 +274,17 @@ def set_const_data(data: torch.Tensor, const_node: NNCFNode, model: NNCFNetwork) setattr(module, const_attr_name, data) -def set_const_data_to_port_id(data: torch.Tensor, node: NNCFNode, port_id: int, model: NNCFNetwork) -> None: +def set_const_data_to_port_id( + data: torch.Tensor, node: NNCFNode, port_id: int, graph: NNCFGraph, model: nn.Module +) -> None: """ - Sets the value of a constant tensor within a specified node in an NNCFNetwork. + Sets the value of a constant tensor within a specified node in the target model. :param data: The tensor containing the new value to be set for the constant. :param node: The NNCF node representing the operation that uses the constant. :param const_port_id: The input port id of the node that receives the constant. 
:param model: The NNCF network containing the module to be modified. """ - graph = model.nncf.get_graph() const_node = get_const_node(node, port_id, graph) if const_node is None: msg = f"No found node with constant for {node.node_name} on {port_id} port" diff --git a/nncf/torch/model_transformer.py b/nncf/torch/model_transformer.py index e8954e77354..0e9330449de 100644 --- a/nncf/torch/model_transformer.py +++ b/nncf/torch/model_transformer.py @@ -221,6 +221,7 @@ def _apply_bias_correction_transformations( update_fused_bias( target_node_name=transformation.target_point.target_node_name, new_bias=transformation.bias_value, + nncf_graph=model.nncf.get_graph(), model=model, ) return model diff --git a/tests/torch/ptq/test_fast_bias_correction.py b/tests/torch/ptq/test_fast_bias_correction.py index 1fa76053605..e21ebde6239 100644 --- a/tests/torch/ptq/test_fast_bias_correction.py +++ b/tests/torch/ptq/test_fast_bias_correction.py @@ -55,7 +55,7 @@ def check_bias(model: NNCFNetwork, ref_bias: list): for node in nncf_graph.get_all_nodes(): if not is_node_with_fused_bias(node, nncf_graph): continue - bias_value = get_fused_bias_value(node, model) + bias_value = get_fused_bias_value(node, nncf_graph, model).cpu() # TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189 assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}" return @@ -77,17 +77,3 @@ def backend_specific_model(model: bool, tmp_dir: str): @staticmethod def fn_to_type(tensor): return torch.Tensor(tensor).cuda() - - @staticmethod - def check_bias(model: NNCFNetwork, ref_bias: list): - ref_bias = torch.Tensor(ref_bias) - nncf_graph = NNCFGraphFactory.create(model) - for node in nncf_graph.get_all_nodes(): - if not is_node_with_fused_bias(node, nncf_graph): - continue - bias_value = get_fused_bias_value(node, model).cpu() - # TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189 - assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), 
f"{bias_value} != {ref_bias}" - return - msg = "Not found node with bias" - raise ValueError(msg) diff --git a/tests/torch/test_model_graph_manager.py b/tests/torch/test_model_graph_manager.py index e841fc24e40..8f2567efd65 100644 --- a/tests/torch/test_model_graph_manager.py +++ b/tests/torch/test_model_graph_manager.py @@ -178,7 +178,7 @@ def test_get_const_data_on_port(self, model_desc, port_id): model_name, desc = model_desc ref = self.REF_GET_CONST_DATA[model_name][port_id - 1] - data = get_const_data_on_port(desc.node, port_id, desc.model) + data = get_const_data_on_port(desc.model, desc.model.nncf.get_graph(), desc.node, port_id) if ref is None: assert data is None else: @@ -338,7 +338,7 @@ def test_get_fused_bias_value(model_cls, ref): graph = model.nncf.get_graph() target_node = graph.get_nodes_by_types("conv2d")[0] - bias = get_fused_bias_value(target_node, model) + bias = get_fused_bias_value(target_node, graph, model) assert torch.all(torch.isclose(bias, torch.tensor(ref))) @@ -356,8 +356,8 @@ def test_update_fused_bias(model_cls): graph = model.nncf.get_graph() target_node = graph.get_nodes_by_types("conv2d")[0] - update_fused_bias(target_node.node_name, ref_new_bias, model) - bias = get_fused_bias_value(target_node, model) + update_fused_bias(target_node.node_name, ref_new_bias, graph, model) + bias = get_fused_bias_value(target_node, graph, model) assert torch.all(torch.isclose(bias, ref_new_bias)) if model_cls == helpers.ConvTestModel: diff --git a/tests/torch2/function_hook/quantization/test_fast_bias_correction.py b/tests/torch2/function_hook/quantization/test_fast_bias_correction.py new file mode 100644 index 00000000000..09c67004104 --- /dev/null +++ b/tests/torch2/function_hook/quantization/test_fast_bias_correction.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import pytest +import torch + +from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper +from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend +from nncf.torch.model_graph_manager import get_fused_bias_value +from nncf.torch.model_graph_manager import is_node_with_fused_bias +from tests.cross_fw.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm + + +class TestTorchFBCAlgorithm(TemplateTestFBCAlgorithm): + @staticmethod + def list_to_backend_type(data: List) -> torch.Tensor: + return torch.Tensor(data) + + @staticmethod + def get_backend() -> PTFastBiasCorrectionAlgoBackend: + return PTFastBiasCorrectionAlgoBackend + + @staticmethod + def backend_specific_model(model: bool, tmp_dir: str): + return GraphModelWrapper(wrap_model(model), torch.ones(model.INPUT_SIZE)) + + @staticmethod + def fn_to_type(tensor): + return torch.Tensor(tensor) + + @staticmethod + def get_transform_fn(): + def transform_fn(data_item): + tensor, _ = data_item + return tensor + + return transform_fn + + @staticmethod + def check_bias(model: GraphModelWrapper, ref_bias: list): + ref_bias = torch.Tensor(ref_bias) + nncf_graph = model.get_graph() + for node in nncf_graph.get_all_nodes(): + if not is_node_with_fused_bias(node, nncf_graph): + continue + bias_value = get_fused_bias_value(node, nncf_graph, model.model).cpu() + # TODO(AlexanderDokuchaev): return atol=0.0001 after fix 
109189 + assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}" + return + msg = "Not found node with bias" + raise ValueError(msg) + + +@pytest.mark.cuda +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skipping for CPU-only setups") +class TestTorchCudaFBCAlgorithm(TestTorchFBCAlgorithm): + @staticmethod + def list_to_backend_type(data: List) -> torch.Tensor: + return torch.Tensor(data).cuda() + + @staticmethod + def backend_specific_model(model: bool, tmp_dir: str): + return GraphModelWrapper(wrap_model(model.cuda()), torch.ones(model.INPUT_SIZE).cuda()) + + @staticmethod + def fn_to_type(tensor): + return torch.Tensor(tensor).cuda() diff --git a/tests/torch2/function_hook/test_extractor.py b/tests/torch2/function_hook/test_extractor.py new file mode 100644 index 00000000000..e786896ec69 --- /dev/null +++ b/tests/torch2/function_hook/test_extractor.py @@ -0,0 +1,97 @@ +# Copyright (c) 2025 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch +from torch import nn + +import tests.cross_fw.test_templates.helpers as helpers +from nncf.experimental.torch2.function_hook.extractor import extract_model +from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph +from nncf.experimental.torch2.function_hook.wrapper import register_pre_function_hook +from nncf.experimental.torch2.function_hook.wrapper import wrap_model +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import QuantizationMode +from nncf.torch.quantization.layers import SymmetricQuantizer + +TEST_PARAMS = ( + ( + helpers.ConvBiasBNTestModel, + "conv/conv2d/0", + "bn/batch_norm/0", + ), + ( + helpers.ConvBNTestModel, + "conv/conv2d/0", + "bn/batch_norm/0", + ), + ( + helpers.ConvTestModel, + "conv/conv2d/0", + "conv/conv2d/0", + ), + ( + helpers.CustomConvBNTestModel, + "conv/conv2d/0", + "bn/batch_norm/0", + ), + ( + helpers.CustomConvTestModel, + "conv/conv2d/0", + "conv/conv2d/0", + ), +) + + +@pytest.mark.parametrize("model_cls, input_node_name, output_node_name", TEST_PARAMS) +def test_extract_model(model_cls: type, input_node_name: str, output_node_name: str): + example_input = torch.ones(model_cls.INPUT_SIZE) + + model: nn.Module = wrap_model(model_cls().eval()) + graph = build_nncf_graph(model, example_input) + + extracted_module = extract_model(model, graph, [input_node_name], [output_node_name]) + with torch.no_grad(): + ret1 = model(example_input) + ret2 = extracted_module(example_input) + assert torch.any(torch.isclose(ret1, ret2)) + + +@pytest.mark.parametrize("model_cls, input_node_name, output_node_name", TEST_PARAMS) +def test_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_name): + example_input = torch.ones(model_cls.INPUT_SIZE) + + model = wrap_model(model_cls().eval()) + + qspec = PTQuantizerSpec( + num_bits=8, + mode=QuantizationMode.SYMMETRIC, + signedness_to_force=None, + 
scale_shape=(1,), + narrow_range=False, + half_range=False, + logarithm_scale=False, + ) + fq = SymmetricQuantizer(qspec) + + register_pre_function_hook(model, input_node_name, 1, fq) + + graph = build_nncf_graph(model, example_input) + + extracted_module = extract_model(model, graph, [input_node_name], [output_node_name]) + with torch.no_grad(): + ret1 = model(example_input) + ret2 = extracted_module(example_input) + assert torch.all(torch.isclose(ret1, ret2)) + + extracted_fn = extracted_module + if isinstance(extracted_fn, nn.Sequential): + extracted_fn = extracted_module[0]