
Conversation

@AkiSakurai AkiSakurai commented Jan 5, 2025

  • Add support for Transposed Convolution in XNNPACK delegate.
  • The test is copied from conv2d.
  • Some patterns are not quantized by XNNPACKQuantizer, and the quantization check is skipped.
  • Skip fusion of ReLU when the 2d Transposed Convolution node has multiple uses (see the sketch after this list).
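To illustrate the last point, a guard along these lines would skip the fusion (a minimal sketch; can_fuse_relu_into_deconv and the surrounding pass structure are illustrative, not the PR's actual code):

def can_fuse_relu_into_deconv(conv_node) -> bool:
    # Hypothetical helper: conv_node is the torch.fx.Node of the transposed
    # convolution. If anything besides the ReLU consumes its output, fusing
    # the ReLU into the deconv would corrupt those other users, so skip.
    return len(conv_node.users) == 1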

pytorch-bot bot commented Jan 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/7514

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 731506f with merge base 3f9324c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Jan 5, 2025
pytorch-bot bot commented Jan 5, 2025

Didn't find following labels among repository labels: release notes: Support 2d Transposed Convolution in XNNPACK delegate

pytorch-bot bot commented Jan 5, 2025

Didn't find following labels among repository labels: backends

@AkiSakurai
Contributor Author

@pytorchbot label "release notes: backends"

@mergennachin mergennachin added the module: xnnpack label Jan 8, 2025
Contributor

@digantdesai digantdesai left a comment

This is brilliant. Thanks for adding this. This was on our list for a while. Great quality PR!

Added a few nits and comments. I am happy with this; let's fix these and we can merge.

In the meantime, let me run some more internal tests to make sure we are looking good.

Thanks again.

    self._test(
        ConvTranspose2d(bias=has_bias),
        quant_config=get_symmetric_quantization_config(),
        check_quantized=False,  # XNNPackQuantizer does not quantize this pattern
    )
Contributor

Hmm, we should be able to add support for it here.

        self._test(ConvTranspose2d(groups=2, in_channels=2, out_channels=6))

    def test_fp32_conv_transpose2d_bn(self):
        class ConvTranspose2dBatchNorm(torch.nn.Module):
Contributor

Move this out, or use the more complex version here?
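For reference, a minimal version of the truncated module above might look like this (layer parameters are assumed; the PR's actual test module may differ):

import torch

class ConvTranspose2dBatchNorm(torch.nn.Module):
    # Deconv followed by batch norm, to exercise the
    # ConvTranspose2d + BatchNorm fusion path.
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.ConvTranspose2d(2, 2, kernel_size=2)
        self.bn = torch.nn.BatchNorm2d(2)

    def forward(self, x):
        return self.bn(self.conv(x))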

except:
    torchao_installed = False

# Set higher recompile limit to avoid exception on over-recompilation in tests
Contributor

How is this related to deconv? Did you just run into this when running the linear test?
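For context, the hunk above is presumably an import guard plus a Dynamo cache-size bump, roughly like this (the exact config knob is not visible in the excerpt; torch._dynamo.config.cache_size_limit is an assumption):

import torch

try:
    import torchao  # noqa: F401
    torchao_installed = True
except ImportError:
    torchao_installed = False

# Set a higher recompile limit to avoid an exception on over-recompilation
# in tests (assumed knob; the PR's actual setting is not shown here).
torch._dynamo.config.cache_size_limit = 128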

# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor

Thank you for adding these tests. I assume it was non-trivial to overload the conv tests for the transpose case?

fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
swap_in_out_for_transpose_weights: bool to indicate whether the tensor shape should be
    permuted and reshaped from (inc, oc/groups, height, width) to (oc, inc/groups, height, width)
groups: number of groups for swap_in_out_for_transpose_weights
Contributor

assert groups == 1 when the new flag is false.
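For readers unfamiliar with the layout swap the docstring describes, a standalone sketch (the function form is illustrative; the PR performs this inside the serialization path):

import torch

def swap_in_out_for_transpose(weight: torch.Tensor, groups: int) -> torch.Tensor:
    # ConvTranspose2d stores weights as (inc, oc/groups, H, W); rearrange to
    # the (oc, inc/groups, H, W) layout used for regular conv weights.
    inc, oc_per_group, h, w = weight.shape
    weight = weight.reshape(groups, inc // groups, oc_per_group, h, w)
    weight = weight.transpose(1, 2)  # swap in/out channels within each group
    return weight.reshape(groups * oc_per_group, inc // groups, h, w)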

else:
    # Note: `assert <f-string>` is always truthy; raise explicitly instead.
    raise AssertionError(f"Unsupported weight per channel quantization axis for depthwise conv2d: {quant_params.axis}, expecting 0.")

if swap_in_out_for_transpose_weights and (
Contributor

Nit: can we cleanly merge this with the depthwise clause above? OK if it gets harder to read.

kernel_shape = get_shape(kernel_node)
stride = cast(List[int], conv.args[3])

if is_transpose and (
Contributor

Nit: Add a comment about the kernel shape and strides matching.

@facebook-github-bot
Contributor

@digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

from typing import Optional

import executorch.extension.pybindings.portable_lib # noqa[F401]
import executorch.kernels.quantized # noqa[F401]
Contributor

When doing local testing I am running into ModuleNotFoundError: No module named 'executorch.kernels'. I guess we need this to run the quantized ops when not lowered to XNNPACK. Can we skip running in the quantized-graph case and stop at serializing? Perhaps that would eliminate the need to include this module.
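One way to degrade gracefully, per the suggestion: guard the import and only execute the program when the kernels are importable (a sketch; the flag name is assumed, and this is not necessarily what the PR ended up doing):

try:
    import executorch.kernels.quantized  # noqa: F401
    quantized_kernels_available = True  # assumed flag name
except ModuleNotFoundError:
    quantized_kernels_available = False

# Quantized-graph tests could then lower and serialize unconditionally but
# skip the run/compare step when quantized_kernels_available is False.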

Contributor

@mcr229 mcr229 left a comment

@AkiSakurai Thank you for this contribution! This is super valuable, we appreciate the work put into it, and it is the first external operator contribution from a user! We would love to hear some feedback on what it was like working in the code base; any ways to improve our code quality, readability, or documentation would be great to hear.

  • What prompted the contribution? Personal feature enablement? Model enablement?

  • What part of the development flow took the longest to set up?

  • How easy was it to test and verify your code changes?

  • What part of the code was the hardest to navigate and the most confusing? Quantization? Passes? Schema Serialization?

  • What (if any) parts of the code were easy to navigate and understand?

op_name = cast(torch._ops.OpOverload, node.target).name()

# Weight and Input should both be quantized
if op_name == exir_ops.edge.aten.convolution.default.name():
Contributor

What was the reasoning for this? I imagine it should have returned true in the previous implementation as well?

Contributor Author

The issue arises when a non-quantized operation is interleaved between two quantized operations: that operation also matches the dequantize-op-quantize pattern, but an operation with a quantized input and a float weight is not supported by XNNPACK.
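A minimal sketch of the resulting check, assuming a hypothetical helper is_dequant that tests whether a node is a dequantize op:

def conv_inputs_are_quantized(node) -> bool:
    # For convolution, the dq -> op -> q pattern only holds if both the
    # activation and the weight arrive through dequantize nodes; a float
    # weight with a quantized input is not supported by XNNPACK.
    # is_dequant is a hypothetical helper, not the PR's actual API.
    activation, weight = node.args[0], node.args[1]
    return is_dequant(activation) and is_dequant(weight)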

    and weight_quant_params.per_channel
    and (groups > 1 or weight_quant_params.axis != 1)
):
    why(
Contributor

Pretty neat way of checking this constraint.
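Pieced together, the constraint presumably reads roughly like this (only the two quoted conditions are verbatim; the surrounding lines and the why message are assumptions):

if (
    is_transpose
    and weight_quant_params is not None
    and weight_quant_params.per_channel
    and (groups > 1 or weight_quant_params.axis != 1)
):
    # Per-channel quantized deconvs are delegated only when groups == 1 and
    # the quantization axis is 1, i.e. the output-channel dim of the
    # (inc, oc/groups, H, W) weight layout.
    why(node, "unsupported per-channel quantization for transposed conv")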

Contributor

@digantdesai digantdesai left a comment

Thanks. The quantized import is no longer complaining.

@digantdesai
Contributor

@mcr229 - Added #7640 FYI.

@digantdesai
Contributor

@AkiSakurai I think you broke this test in a good way. Can you please fix it? 🙏

@digantdesai
Contributor

Once this is fixed, I can merge this.

@facebook-github-bot
Contributor

@digantdesai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@digantdesai digantdesai added the release notes: xnnpack label and removed the release notes: backends [DO NOT USE] label Jan 14, 2025
@digantdesai digantdesai merged commit 63e6136 into pytorch:main Jan 14, 2025
6 checks passed
YIWENX14 pushed a commit that referenced this pull request Jan 28, 2025
* Support Transposed Convolution in XNNPACK delegate

* Apply suggestions

* Remove invalid restriction for transpose convolution batch normalization fusion

* fix size analysis tool test
@digantdesai
Contributor

digantdesai commented Feb 6, 2025

FYI #8090
