
Conversation

@roman-janik-nxp (Collaborator) commented Aug 20, 2025

Summary

This PR adds delegation of `aten.conv1d` to Neutron, fixes the `input_shapes` type hint in `to_quantized_edge_program()`, and fixes the `operators_not_to_delegate` assignment in the partitioner.
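For context, a minimal sketch of the kind of module whose `aten.conv1d` node can now be delegated. The module, shapes, and export call below are illustrative assumptions, not code from this PR:

```python
import torch


class Conv1dModule(torch.nn.Module):
    """Hypothetical module; nn.Conv1d lowers to the aten.conv1d op this PR delegates."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


# conv1d input layout is (N, C, W).
model = Conv1dModule().eval()
example_inputs = (torch.randn(1, 4, 32),)
exported_program = torch.export.export(model, example_inputs)
```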

Test plan

Unit tests are provided in `backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py`.

cc @digantdesai @JakeStevens @robert-kalmar

@pytorch-bot (bot) commented Aug 20, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/13549

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit a8ffed0 with merge base bd92f1a:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla (bot) added the CLA Signed label on Aug 20, 2025
@roman-janik-nxp (Collaborator, Author)

@pytorchbot label "module: nxp" "release notes: nxp"

@pytorch-bot (bot) added the module: nxp and release notes: nxp labels on Aug 20, 2025
@robert-kalmar robert-kalmar self-requested a review August 20, 2025 14:02
@JakeStevens (Contributor) left a comment

How does this intersect with #13576? It seems you are extending with 0s; I think this will re-introduce that bug for conv1d now.

You can use the same checks introduced in the tests in #13576 for the 1D case.

@roman-janik-nxp force-pushed the upstream/main-nxp/EIEX-449-upstream-1d-convolution branch from 0d530e4 to 5182658 on August 21, 2025 14:34
@roman-janik-nxp (Collaborator, Author)

> How does this intersect with #13576? It seems you are extending with 0s; I think this will re-introduce that bug for conv1d now.
>
> You can use the same checks introduced in the tests in #13576 for the 1D case.

Yes, this is a conflict. I added a fix for padding with the zero-point to this PR. In our IR, a Conv1D operator doesn't exist; it is emulated by converting to Conv2D, and the result is then converted back to the Conv1D format. So this fix also changes the behavior for Conv2D. However, the other PR is still needed, as it also changes Average Pool.
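An illustrative sketch of the emulation described above, using PyTorch ops in channels-first layout (the Neutron IR itself is channels-last NWC/NHWC); this is not the converter's actual code:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 32)  # conv1d input, (N, C, W)
w = torch.randn(8, 4, 3)   # conv1d weight, (out_C, in_C, kW)

out_1d = F.conv1d(x, w, padding=1)

# Emulate conv1d as conv2d: insert a singleton H dimension into the input and
# the kernel, run conv2d, then squeeze H back out. The padding attribute must
# likewise grow from (pW,) to (pH, pW) = (0, pW).
out_2d = F.conv2d(x.unsqueeze(2), w.unsqueeze(2), padding=(0, 1)).squeeze(2)

assert torch.allclose(out_1d, out_2d, atol=1e-5)
```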

+ fix input_shapes type hint in to_quantized_edge_program()
+ add test cases for Conv1D operator
+ add fix for padding with zero-point
@roman-janik-nxp force-pushed the upstream/main-nxp/EIEX-449-upstream-1d-convolution branch from 5182658 to a8ffed0 on August 22, 2025 13:00
@robert-kalmar (Collaborator)

@JakeStevens, Roman has updated the PR based on your findings. Can you please re-review?
The PR cannot be merged until you close the "change request".

```python
def extend_1d_padding_to_2d(tflite_1d_padding: MutableSequence):
    """Extend the PyTorch 'padding' operator attribute that represents padding for a 1D kernel to 2D, by adding '0's."""
    if tflite_1d_padding is not None:
        tflite_1d_padding.append(0)
```
@JakeStevens (Contributor)

This is my specific concern. We are padding with zeros, not the zero point, just like the 2D case previously.

@roman-janik-nxp (Collaborator, Author) commented Aug 25, 2025

This value expresses the amount of padding applied to the input, not the padding value itself. A zero here means the tensor will not be padded along the H dimension: this is a conversion from a 1D NWC tensor to a 2D NHWC tensor, so the padding attribute also has to be extended/converted. The amount of padding for the W dimension is kept.
The padding value is the zero-point, added on L352 and L390 in convolution_converter.py.
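A small sketch of that distinction, with an assumed zero-point of 128 (illustrative only, not the converter code): the appended 0 is a padding amount for the new H dimension, while the fill value of the padded region is the tensor's zero-point:

```python
import torch
import torch.nn.functional as F

zero_point = 128  # assumed quantization zero-point of the input tensor
x = torch.full((1, 32, 4), 100, dtype=torch.uint8)  # 1D tensor in NWC layout

# Amount: extending the 1D padding (pW,) = (1,) to 2D gives (pH, pW) = (0, 1),
# i.e. no padding along the inserted H dimension -- that is what append(0) means.
# Value: the padded region is filled with the zero-point (the quantized zero),
# not the literal integer 0, which would decode to a nonzero real value.
padded = F.pad(x, pad=(0, 0, 1, 1), value=zero_point)  # pad W by 1 on each side
```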

@JakeStevens (Contributor)

Got it, thanks!

@JakeStevens (Contributor)

Sorry, I think my comment was misunderstood. I was not concerned that this doesn't include the changes from the 2D padding PR; both will get into mainline eventually, and this is fine and makes perfect sense.

Instead, I believe this PR has the same bug: in the 1D conversion case, we are extending the existing padding by adding a zero instead of the zero point. See the inline comment.

@robert-kalmar merged commit 25973c1 into pytorch:main on Aug 25, 2025
101 of 104 checks passed
@robert-kalmar deleted the upstream/main-nxp/EIEX-449-upstream-1d-convolution branch on August 26, 2025 07:16
agrima1304 pushed a commit to agrima1304/executorch that referenced this pull request Aug 26, 2025
…3549)

### Summary
Add delegation of `aten.conv1d` to Neutron. Fixes
`input_shapes` type hint in `to_quantized_edge_program()`. Fixes
`operators_not_to_delegate` assignment in partitioner.

### Test plan
Unit tests provided in
backends/nxp/tests/ir/converter/node_converter/test_conv_converter.py.


cc @digantdesai @JakeStevens @robert-kalmar
kimishpatel pushed a commit that referenced this pull request Sep 2, 2025
