1111
1212from typing import Any , Callable , Dict , List , Tuple
1313
14- import numpy as np
1514import torch
1615
1716import nncf .torch .graph .operator_metatypes as om
2120from nncf .common .graph .transformations .commands import TargetType
2221from nncf .common .quantization .quantizer_propagation .structs import QuantizationTrait
2322from nncf .common .tensor_statistics .statistic_point import StatisticPoint
23+ from nncf .experimental .common .check_feature import is_experimental_torch_tracing_enabled
2424from nncf .experimental .common .tensor_statistics .collectors import AbsMaxReducer
2525from nncf .experimental .common .tensor_statistics .collectors import MaxAggregator
2626from nncf .experimental .common .tensor_statistics .collectors import TensorCollector
27+ from nncf .experimental .torch2 .commands import PT2ConstUpdateCommand
28+ from nncf .experimental .torch2 .commands import PT2InsertionCommand
29+ from nncf .experimental .torch2 .function_hook .nncf_graph .nncf_graph_builder import GraphModelWrapper
2730from nncf .quantization .algorithms .smooth_quant .backend import SmoothQuantAlgoBackend
2831from nncf .tensor import Tensor
2932from nncf .torch .graph .transformations .command_creation import create_command_to_update_weight
@@ -119,6 +122,9 @@ def get_abs_max_channel_collector(
119122
120123 @staticmethod
121124 def get_weight_value (node_with_weight : NNCFNode , model : NNCFNetwork , nncf_graph : NNCFGraph ) -> Tensor :
125+ if isinstance (model , GraphModelWrapper ):
126+ model = model .model
127+
122128 weight_node = get_const_node (node_with_weight , node_with_weight .metatype .weight_port_ids [0 ], nncf_graph )
123129 if weight_node is None :
124130 msg = f"{ node_with_weight } node has no weight node."
@@ -127,7 +133,12 @@ def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork, nncf_graph:
127133 return Tensor (weight_data )
128134
129135 @staticmethod
130- def weight_update_command (node_with_weight : NNCFNode , weight_value : np .ndarray ) -> PTWeightUpdateCommand :
136+ def weight_update_command (
137+ node_with_weight : NNCFNode , nncf_graph : NNCFGraph , weight_value : torch .Tensor
138+ ) -> PTWeightUpdateCommand :
139+ if is_experimental_torch_tracing_enabled ():
140+ weight_node = get_const_node (node_with_weight , node_with_weight .metatype .weight_port_ids [0 ], nncf_graph )
141+ return PT2ConstUpdateCommand (weight_node , weight_value )
131142 return create_command_to_update_weight (node_with_weight , weight_value )
132143
133144 @staticmethod
@@ -145,6 +156,9 @@ def scale_insertion_command(
145156
146157 sq_multiply = SQMultiply (scale_value .shape )
147158 sq_multiply .scale = scale_value
159+
160+ if is_experimental_torch_tracing_enabled ():
161+ return PT2InsertionCommand (target_points = target_points , hook_module = sq_multiply )
148162 return PTSharedFnInsertionCommand (target_points , sq_multiply , scale_node_name )
149163
150164 @staticmethod
@@ -161,6 +175,10 @@ def get_weight_channel_axis(node: NNCFNode) -> int:
161175
162176 @staticmethod
163177 def is_node_with_shared_weight (node : NNCFNode , nncf_graph : NNCFGraph ) -> bool :
178+ if is_experimental_torch_tracing_enabled ():
179+ weight_node = get_const_node (node , node .metatype .weight_port_ids [0 ], nncf_graph )
180+ output_edges = nncf_graph .get_next_nodes (weight_node )
181+ return len (output_edges ) > 1
164182 return node .is_shared ()
165183
166184 @staticmethod
0 commit comments