-
Notifications
You must be signed in to change notification settings - Fork 259
[FX] Support weight quantization for operations where weight_port_id
!= 1
#3334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 7 commits
b8203a5
474e6b7
7319f9f
3308b40
0ee4f8b
74f677a
3565ba3
7efa675
e8ab53a
5bb69f7
3104e58
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -40,6 +40,7 @@ | |||||
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand | ||||||
from nncf.torch.hardware.config import PTHWConfig | ||||||
from nncf.torch.model_graph_manager import get_weight_tensor_port_ids | ||||||
from nncf.torch.model_graph_manager import get_weight_channel_axes | ||||||
from nncf.torch.nncf_network import NNCFNetwork | ||||||
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT | ||||||
from nncf.torch.quantization.layers import QUANTIZATION_MODULES | ||||||
|
@@ -149,8 +150,7 @@ def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point | |||||
|
||||||
@staticmethod | ||||||
def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]: | ||||||
# TODO(dlyakhov): support transpose conv and other cases | ||||||
return (0,) | ||||||
return get_weight_channel_axes(node.metatype, ndims, target_point.input_port_id) | ||||||
|
||||||
@staticmethod | ||||||
def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]: | ||||||
|
@@ -177,16 +177,25 @@ def _get_input_scale_shape( | |||||
nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool | ||||||
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: | ||||||
is_weights = target_point.is_weight_target_point() | ||||||
input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) | ||||||
|
||||||
if is_weights: | ||||||
# TODO(dlyakhov): support transpose conv/ make channel_idx common | ||||||
channel_idx = 0 | ||||||
node = nncf_graph.get_node_by_name(target_point.target_node_name) | ||||||
channel_axes = get_weight_channel_axes(node.metatype, len(input_shape), target_point.input_port_id) | ||||||
else: | ||||||
channel_idx = 1 # channel dim for activations | ||||||
channel_axes = [1] | ||||||
|
||||||
input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) | ||||||
scale_shape = tuple( | ||||||
get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx) | ||||||
) | ||||||
channel_idx = channel_axes[0] if channel_axes else 0 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Since channel axes is already being checked and handled in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We also need to return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay |
||||||
|
||||||
if is_weights and not channel_axes: | ||||||
|
if is_weights and not channel_axes: | |
if not len(channel_axes): |
to cover the case of vector weights which are being quantized per channel
Uh oh!
There was an error while loading. Please reload this page.