-
Couldn't load subscription status.
- Fork 700
[Quantized DeConv Support] Dynamically Quantized Deconvolutions with groups ==1 #11864
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,10 @@ | |
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from executorch.backends.xnnpack.utils.utils import is_depthwise_conv | ||
| from executorch.backends.xnnpack.utils.utils import ( | ||
| get_groups_from_conv, | ||
| is_depthwise_conv, | ||
| ) | ||
| from torch._subclasses import FakeTensor | ||
| from torch.fx import Node | ||
| from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( | ||
|
|
@@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None: | |
| return decorator | ||
|
|
||
|
|
||
| def change_quantization_config( | ||
| original_qspec, | ||
| dtype=None, | ||
| quant_min=None, | ||
| quant_max=None, | ||
| qscheme=None, | ||
| ch_axis=None, | ||
| is_dynamic=None, | ||
| observer_or_fake_quant_ctr=None, | ||
| ): | ||
| return QuantizationSpec( | ||
| dtype=dtype or original_qspec.dtype, | ||
| quant_min=quant_min or original_qspec.quant_min, | ||
| quant_max=quant_max or original_qspec.quant_max, | ||
| qscheme=qscheme or original_qspec.qscheme, | ||
| ch_axis=ch_axis or original_qspec.ch_axis, | ||
| is_dynamic=is_dynamic or original_qspec.is_dynamic, | ||
| observer_or_fake_quant_ctr=observer_or_fake_quant_ctr | ||
| or original_qspec.observer_or_fake_quant_ctr, | ||
| ) | ||
|
|
||
|
|
||
| def is_relu_node(node: Node) -> bool: | ||
| """ | ||
| Check if a given node is a relu node | ||
|
|
@@ -231,31 +256,44 @@ def _do_annotate_conv( | |
| if is_relu_node(user): | ||
| continue | ||
|
|
||
| # Tracks conditions for whether or not to skip | ||
| skip = False | ||
|
|
||
| input_qspec_map = {} | ||
| input_act = conv_node.args[0] | ||
| assert isinstance(input_act, Node) | ||
| input_qspec_map[input_act] = get_input_act_qspec(quantization_config) | ||
|
|
||
| weight = conv_node.args[1] | ||
| assert isinstance(weight, Node) | ||
| input_qspec_map[weight] = get_weight_qspec(quantization_config) | ||
| weight_qspec = get_weight_qspec(quantization_config) | ||
| num_groups = get_groups_from_conv(conv_node) | ||
|
|
||
| # Only annotate dynamically quantized conv if it's 2D and not depthwise | ||
| if ( | ||
| # skip if transposed conv has more than 1 group | ||
| skip = skip or (is_conv_transpose and num_groups != 1) | ||
| print(f"{skip} conv transpose and num_groups") | ||
|
|
||
| if is_conv_transpose: | ||
| # transposed convs per output channel quantization | ||
| weight_qspec = change_quantization_config(weight_qspec, ch_axis=1) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we support group > 1 here in the quant flow? I know we are checking above but just curious. If yes, we can move this before the group check, if not then add an assert to avoid future issues when we allow groups > 1 for transposed_conv. |
||
|
|
||
| input_qspec_map[weight] = weight_qspec | ||
| is_dynamic = ( | ||
| quantization_config | ||
| and quantization_config.input_activation | ||
| and quantization_config.input_activation.is_dynamic | ||
| ): | ||
| ) | ||
|
|
||
| # Only annotate dynamically quantized conv if it's 2D and not depthwise | ||
| if is_dynamic: | ||
| weight_val = weight.meta.get("val", None) | ||
| weight_shape = getattr(weight_val, "shape", None) | ||
|
|
||
| # Skip if not a 4D weight tensor (i.e. not conv2d) | ||
| if weight_shape is not None and len(weight_shape) != 4: | ||
| continue | ||
|
|
||
| skip = skip or (weight_shape is not None and len(weight_shape) != 4) | ||
| # Skip if depthwise (default to groups=1 since it's not an arg) | ||
| if is_depthwise_conv(weight_shape, 1, is_conv_transpose): | ||
| continue | ||
| skip = skip or ( | ||
| not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False) | ||
| ) | ||
|
|
||
| # adding weight node to the partition as well | ||
| partition = [conv_node, conv_node.args[1]] | ||
|
|
@@ -265,7 +303,7 @@ def _do_annotate_conv( | |
| input_qspec_map[bias] = get_bias_qspec(quantization_config) | ||
| partition.append(bias) | ||
|
|
||
| if _is_annotated(partition): | ||
| if _is_annotated(partition) or skip: | ||
| continue | ||
|
|
||
| if filter_fn and any(not filter_fn(n) for n in partition): | ||
|
|
@@ -311,7 +349,12 @@ def _do_annotate_conv_relu( | |
|
|
||
| weight = conv_node.args[1] | ||
| assert isinstance(weight, Node) | ||
| input_qspec_map[weight] = get_weight_qspec(quantization_config) | ||
| weight_qspec = get_weight_qspec(quantization_config) | ||
| groups = get_groups_from_conv(conv_node) | ||
| if is_conv_transpose: | ||
| # transposed convs per output channel quantization | ||
| weight_qspec = change_quantization_config(weight_qspec, ch_axis=1) | ||
| input_qspec_map[weight] = weight_qspec | ||
|
|
||
| # adding weight node to the partition as well | ||
| partition = [relu_node, conv_node, conv_node.args[1]] | ||
|
|
@@ -323,6 +366,9 @@ def _do_annotate_conv_relu( | |
| if _is_annotated(partition): | ||
| continue | ||
|
|
||
| if is_conv_transpose and groups != 1: | ||
| continue | ||
|
|
||
| if filter_fn and any(not filter_fn(n) for n in partition): | ||
| continue | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove