-
Notifications
You must be signed in to change notification settings - Fork 752
Arm backend: Support channels-last input and output #14259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
861e98a
c996232
bccaa2a
9a4440e
cc7d899
843e600
adec5aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -249,15 +249,6 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { | |
| handles.inputs->io[i].elem_size); | ||
| return Error::InvalidProgram; | ||
| } | ||
| supported = executorch::runtime::is_contiguous_dim_order( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This implies it can handle anything, i.e. a transpose op is inserted if it is needed. But what about asserting expectations? E.g. if the user exported with NCHW and we inserted a transpose_to_nhwc AoT, what if the user now supplies NHWC (instead of the assumed NCHW)? Shouldn't we validate, since we don't "check and optionally transpose" at runtime?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point and I agree, however since we are past the branch cutoff date and we need this patch to unblock a major use case for us, may I ask to ignore this for now and fix this in a later PR?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No worries. I assumed that and stamped already :)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does the stamp mean we can merge, or do we still wait for Meta to merge?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. haha good I was in the non-GA mode already where stamp --> you merge --> Internal failure --> we revert But looking at the activity from @mergennachin he is, rightfully, still in GA mental mode for this GA critical PR. So if the internal CI is clean, he or I can merge this.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, I'm happy, and you handle it fast so no problem; I just want to avoid a "no one does it" situation 😆 as there might be a merge/sync to 1.0 coming up. And I also don't expect any PR not tagged with the 1.0 milestone to be merged. If so, it's just a bonus. |
||
| tensor_in.dim_order().data(), tensor_in.dim()); | ||
| if (!supported) { | ||
| ET_LOG( | ||
| Error, | ||
| "Input %d expected contiguous dim_order, but got non-contiguous dim_order", | ||
| i); | ||
| return Error::InvalidProgram; | ||
| } | ||
|
|
||
| // Select a compatible copy routine including checking for input layouts | ||
| // which require permutation. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| # Copyright 2024-2025 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
|
|
||
| from typing import Tuple | ||
|
|
||
| import torch | ||
| from executorch.backends.arm.test import common | ||
|
|
||
| from executorch.backends.arm.test.tester.test_pipeline import ( | ||
| EthosU55PipelineINT, | ||
| EthosU85PipelineINT, | ||
| TosaPipelineFP, | ||
| TosaPipelineINT, | ||
| ) | ||
|
|
||
|
|
||
| input_t1 = Tuple[torch.Tensor] # Input x | ||
|
|
||
|
|
||
| class ChannelsLastInput(torch.nn.Module): | ||
| """ | ||
| Test a complex case with (channels last, channels first) input, | ||
| and (channels first, channels last) output. | ||
| """ | ||
|
|
||
| inputs: input_t1 = ( | ||
| torch.arange(1, 25, dtype=torch.float32) | ||
| .reshape((1, 2, 3, 4)) | ||
| .to(memory_format=torch.channels_last), | ||
| torch.arange(1, 25, dtype=torch.float32).reshape((1, 2, 3, 4)), | ||
| ) | ||
|
|
||
| def forward(self, x, y): | ||
| x = x * x | ||
| return y, x | ||
|
|
||
|
|
||
| class ChannelsFirstOutput(torch.nn.Module): | ||
| """ | ||
| Test converting to channels_first inside the delegate. | ||
| """ | ||
|
|
||
| inputs: input_t1 = ( | ||
| torch.arange(1, 25, dtype=torch.float32) | ||
| .reshape((1, 2, 3, 4)) | ||
| .to(memory_format=torch.channels_last), | ||
| ) | ||
|
|
||
| def forward(self, x): | ||
| x = x.clone(memory_format=torch.contiguous_format) * x | ||
| return x | ||
|
|
||
|
|
||
| class ChannelsLastOutput(torch.nn.Module): | ||
| """ | ||
| Test changing of dim_order inside the delegate. | ||
| """ | ||
|
|
||
| inputs: input_t1 = (torch.arange(1, 9, dtype=torch.float32).reshape((1, 2, 2, 2)),) | ||
|
|
||
| def forward(self, x): | ||
| x = x * x | ||
| x = x.clone(memory_format=torch.channels_last) | ||
| return x | ||
|
|
||
|
|
||
| class ChannelsLastInsidePartition(torch.nn.Module): | ||
| """ | ||
| Test dim_order changes inside the partition, but no dim_order changes at input/output. | ||
| """ | ||
|
|
||
| inputs: input_t1 = (torch.randn((1, 2, 3, 3)),) | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
| self.conv2d = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=(3, 3)) | ||
|
|
||
| def forward(self, x): | ||
| return ( | ||
| self.conv2d(x.clone(memory_format=torch.channels_last)).clone( | ||
| memory_format=torch.contiguous_format | ||
| ) | ||
| * 1 | ||
| ) | ||
|
|
||
|
|
||
| test_modules = { | ||
| "channels_last_input": ChannelsLastInput, | ||
| "channels_first_output": ChannelsFirstOutput, | ||
| "channels_last_output": ChannelsLastOutput, | ||
| "channels_last_inside_partition": ChannelsLastInsidePartition, | ||
| } | ||
|
|
||
|
|
||
| @common.parametrize("module", test_modules) | ||
| def test_dim_order_tosa_FP(module): | ||
| pipeline = TosaPipelineFP[input_t1](module(), module.inputs, []) | ||
| pipeline.run() | ||
|
|
||
|
|
||
| @common.parametrize("module", test_modules) | ||
| def test_dim_order_tosa_INT(module): | ||
| pipeline = TosaPipelineINT[input_t1]( | ||
| module(), module.inputs, [], symmetric_io_quantization=True | ||
| ) | ||
| pipeline.run() | ||
|
|
||
|
|
||
| @common.XfailIfNoCorstone300 | ||
| @common.parametrize("module", test_modules) | ||
| def test_dim_order_u55_INT(module): | ||
| pipeline = EthosU55PipelineINT[input_t1](module(), module.inputs, []) | ||
| pipeline.run() | ||
|
|
||
|
|
||
| @common.XfailIfNoCorstone320 | ||
| @common.parametrize("module", test_modules) | ||
| def test_dim_order_u85_INT(module): | ||
| pipeline = EthosU85PipelineINT[input_t1](module(), module.inputs, []) | ||
| pipeline.run() |
Uh oh!
There was an error while loading. Please reload this page.