Skip to content
27 changes: 18 additions & 9 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.hardware.config import PTHWConfig
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids
from nncf.torch.model_graph_manager import get_weight_channel_axes
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.quantization.layers import QUANTIZATION_MODULES
Expand Down Expand Up @@ -149,8 +150,7 @@ def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point

@staticmethod
def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]:
# TODO(dlyakhov): support transpose conv and other cases
return (0,)
return get_weight_channel_axes(node.metatype, ndims, target_point.input_port_id)

@staticmethod
def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]:
Expand All @@ -177,16 +177,25 @@ def _get_input_scale_shape(
nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
is_weights = target_point.is_weight_target_point()
input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)

if is_weights:
# TODO(dlyakhov): support transpose conv/ make channel_idx common
channel_idx = 0
node = nncf_graph.get_node_by_name(target_point.target_node_name)
channel_axes = get_weight_channel_axes(node.metatype, len(input_shape), target_point.input_port_id)
else:
channel_idx = 1 # channel dim for activations
channel_axes = [1]

input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point)
scale_shape = tuple(
get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx)
)
channel_idx = channel_axes[0] if channel_axes else 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
channel_idx = channel_axes[0] if channel_axes else 0

Since channel axes is already being checked and handled in the if-else block below. channel_axes[0] can directly be passed to channel_idx parameter of get_scale_shape

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to return channel_idx in this function, so I think it is needed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay


if not len(channel_axes):
scale_shape = (1,)
else:
scale_shape = tuple(get_scale_shape(
input_shape,
is_weights=is_weights,
per_channel=per_channel,
channel_idx=channel_idx
))

return input_shape, scale_shape, channel_idx

Expand Down
Loading