Arm Backend: Add support for ELU.default operator #12996

Status: Closed
agrima1304 wants to merge 423 commits into pytorch:main from agrima1304:Arm-Backend-Add-support-for-ELU.default-operator

Changes from 1 commit (423 commits)
68df797
Import Fix
lucylq f478800
Fix pyre errors
izaitsevfb 5ab8f52
[Backend Tester] Add LSTM tests (#13238)
GregoryComer 6157370
[Backend Tester] Add avgpool tests (#13239)
GregoryComer fa059f2
[Backend Tester] Add maxpool tests (#13240)
GregoryComer eeaaad6
[Backend Tester] Add adaptive avgpool tests (#13241)
GregoryComer 04689e3
[Backend Tester] Add adaptive maxpool tests (#13242)
GregoryComer 732d9de
[Backend Tester] Add Qualcomm tester and register flow (#12739)
GregoryComer 5e05ca8
[Backend Tester] Add CSV report generation (#12741)
GregoryComer 252f6d3
Update build from source docs (#13210)
GregoryComer da3a5e5
Move to Span<EValue*> instead of EValue** in delegate interface
JacobSzwejbka f86285c
Update XNNPACK to 3131afe (#13234)
mcr229 3c9e77b
[Backend Tester] Add tensor error statistic reporting (#12809)
GregoryComer 6bd15dd
[Backend Tester] Report quantization and lowering times (#12838)
GregoryComer fbba9c2
Added ETDump to Wasm (#13304)
Conarnar 55b6e74
[Backend Tester] Report delegation statistics (#12846)
GregoryComer 172539b
[Backend Tester] Report PTE size (#13249)
GregoryComer 9b3b270
[ET-VK] Add optional blocklist and allowlist to vulkan partitioner to…
SS-JIA 5692111
[ET-VK][ez] Fix `vulkan_schema` interface and link Vulkan to pybindin…
SS-JIA acdc2e4
[ET-VK][examples] Create export script for Vulkan examples (#13294)
SS-JIA bb444f8
[ET-VK][CI] Add vulkan CI to test exporting and running models (#13295)
SS-JIA ab9eea8
Add csv and etdump files to gitignore (#13301)
mcr229 da39ab7
[Backend Tester] Add portable test flow (#13250)
GregoryComer 7989804
Static attention IO manager: fix causal mask bug for last input position
sxu 8cfb917
Update buck srcs for tester harness
GregoryComer 78ceda2
Check the number of inputs in Method::set_inputs (#13341)
shoumikhin 1a67f6c
Add TorchAO wrapper config to allow filter_fn for quantize_ (#13264)
abhinaykukkadapu c275f40
Add coreml quant recipes (#13265)
abhinaykukkadapu 400f31a
Fix 8w8a qat qconfig setting activations
navsud 07f8a4f
NXP backend: Return number of Neutron nodes in exception (#13126)
StrycekSimon 4f3b0e7
Enable strongly typed ops for deployment
mcremon-meta 353aff0
Arm backend: Run pytest model tests in parallell (#13322)
Erik-Lundell 012af23
Arm backend: Update arm_tester.py with generate_etrecord (#13370)
perheld d0cf666
Arm backend: Add logit decomposition pass and test (#13366)
tbergkvist 51efa78
Arm backend: Initial int16 extension (#13318)
per 587aee8
Arm backend: Remove submodule serialization_lib (#13182)
zingo fcd9538
Arm backend: Enable semihosting in build_executor_runner.sh (#13325)
Erik-Lundell 13eb90d
Fix torchao deps (#13107)
metascroy 9675ed7
Revert "Add TorchAO wrapper config to allow filter_fn for quantize_ (…
swolchok ea8331b
Introduce set input/output API to Module. (#13363)
shoumikhin 9493377
Update Llama Example Readme (#12956)
anzr299 c69af4c
Support for Running Arm Zephyr Toolchain on aarch64 Host Machines (#1…
BujSet da601a8
[ET-VK] Better work group sizes for matmul (#13378)
pytorchbot 33176f9
[ET-VK] Add mechanism to trigger command buffer re-encode only when n…
pytorchbot caf61ed
Introduce unload method API to Module. (#13364)
shoumikhin 4b6dfa0
[Backend Tester] Clean up a few test issues (#13258)
GregoryComer 2404756
[Backend Tester] Clean up report output (#13306)
GregoryComer db3fd27
[Backend Tester] Write report progressively (#13308)
GregoryComer d14d4c5
create and validate build_variables.bzl (#8326)
swolchok bc8a57f
[Backend Tester] Add subtest index field (#13311)
GregoryComer 4800771
[Backend Tester] Reduce log verbosity / spam (#13312)
GregoryComer 5983be9
NXP Backend: Add infrastructure for pre processing passes in edge dia…
MartinPavella b36b71a
Android preset (#11119)
kirklandsign dcdf9db
Qualcomm AI Engine Direct - Static Decoder Runner Support 16bit KV IO…
winskuo-quic 5a88920
Qualcomm AI Engine Direct - Phase out QCIR flow since it is no longer…
DannyYuyang-quic ee111e0
Bump Zephyr SDK Version in CI Image from v0.16.0 -> v0.17.2 (#13380)
BujSet 69ba8e9
[Backend Tester] Seed based on test name (#13313)
GregoryComer 58cfa13
Add support for strongly typed op_quantized_linear_out
ethansfng 1f1cc08
removed cron schedule runs to workflow until the GitHub token issue i…
nil-is-all 15b51ce
Set a doc build variable for executorch version (#13351)
kirklandsign b9ebbe7
lintrunner -a backends/qualcomm/CMakeLists.txt (#13396)
swolchok 39fd4b7
Update the comments of calculate_numeric_gap
Gasoonjia a64aa44
Use dtype agnostic op_cat implementation, add op_cat testcases
ethansfng dd69066
Arm backend: Allocate buffers with alignment (#13412)
Erik-Lundell 4c0f087
NXP backend: Improve cifarnet speed by removing the initial pading. (…
jirioc 4565979
Arm backend: Generate ETRecord from arm_aot_compiler (#13273)
zingo cc10e16
Use unlifted export pass to tag delegated constants (#13407)
pytorchbot 9bb3bbc
Add support for strongly typed op_quantized_relu (#13345)
ethansfng bd8f812
[Backend Tester] Add test flow CLI arg (#13360)
GregoryComer 1fec15c
Ensure the correct output data type for the full op.
eigen-k a61bb5a
Arm backend: Add decomposition and test for acos (#13414)
emmakujala 347afd1
Arm backend: Introduce documentation for VGF (#13369)
robell 9729c5d
Update ExecuTorchValue.mm (#13425)
shoumikhin dcecaab
Remove outdated comments and names in OSS
mcremon-meta d286910
Fix typo in op_quantized_relu_asym8u_asym8u
ethansfng 6627cbc
Fix unused-local-typedef issue
r-barnes a48dbfc
Refactoring Portable Operators to Standardize op_name Format (#12941)
BujSet 180baba
Android scheduled build add vulkan (#13428)
kirklandsign 1fdbafe
NXP Backend: Buck fixes for the PassManager
digantdesai 5f1dd11
Add support for strongly typed op_quantized_matmul, generalize dispat…
ethansfng 0d1c3dc
Arm backend: use tosa_ref_model only if installed
digantdesai 6421fd3
[executorch] Add TorchAO wrapper config to allow filter_fn for quanti…
pytorchbot bf7de85
[executorch] Add coreml quant recipes (#13441)
pytorchbot 0e90935
Static attention: do not specialize on input sequence length
sxu b4b1ac5
Fix typo in target file (#13443)
abhinaykukkadapu de54277
forward fix
cccclai 85b9577
Switch to conda-forge on MacOS (#13442)
huydhn 6be925a
use dtype agnostic implementation for non optimized op_permute_copy
ethansfng 881bd12
Extend `PyBundledModule` with `extension.BundledModule`
Gasoonjia 06366c5
Improve optimum coverage in ET (more models, xnnpack on mac) (#13400)
jackzhxng e489db3
build_variables.bzl: split PROGRAM_NO_PRIM_OPS_SRCS from EXECUTORCH_C…
swolchok 1221ace
Split quantized convolutions into NCHW and NHWC variants
mcremon-meta 9439d8a
[ez] Fix idx in duplicate_constant_node_pass (#13461)
SS-JIA 6b7001d
build_variables.bzl: split PLATFORM_SRCS from EXECUTORCH_CORE_SRCS (#…
swolchok d93e407
Arm backend: Add missing using-declaration in VGFBackend.cpp (#13460)
YufengShi-dudu 605b10c
[EZ] Replace `pytorch-labs` with `meta-pytorch`
ZainRizvi 0b6e14a
Allow for HOP to be in the etreord graph
kimishpatel e8b1082
build_variables.bzl: split PATTERN_SRCS from PORTABLE_KERNELS_SRCS (#…
swolchok 2c9436d
build_variables.bzl: make THREADPOOL_SRCS usable in BUCK (#13354)
swolchok 7dae813
build_variables.bzl: make MPS_BACKEND_SRCS usable in BUCK (#13355)
swolchok ccb00ee
build_variables.bzl: make XNNPACK_BACKEND_SRCS usable in BUCK (#13356)
swolchok b5c558b
build_variables.bzl: make CUSTOM_OPS_SRCS usable in BUCK (#13357)
swolchok 4f6b029
Expose portable ops as utils (add/stack)
manuelcandales d5d91bc
Introduce out shape utils (add/stack) (#13199)
manuelcandales 6d00d37
switch top-level ExecuTorch build from executorch_srcs.cmake to build…
swolchok 724dcb1
Alternative format specifier for %zd
cmt0 b8ab343
[ET-VK] Move rotary embedding custom op to be handled via graph pass …
pytorchbot dc7f9de
Qualcomm AI Engine Direct - Static LLM Decoder Refactor (#13314)
winskuo-quic 9a5af57
[ET-VK] Enable IntxWeightOnlyConfig (#13466)
pytorchbot ed4e59f
Add aten::_upsample_bilinear2d_aa.out (#13458)
mergennachin 29256b0
Add support for strongly typed conv_nchw and conv_nhwc
ethansfng a06b3da
Arm backend: Add dim_order_ops:: to the auto gen_oplist generations (…
zingo 3d30f7f
Arm backend: Replace .export_for_training with .export (#13280)
AdrianLundell fce39c0
Arm backend: Move TOSA operators to dialect (#13408)
per 6c506ae
Arm backend: Add example linkerscripts for U55/U85 (#13404)
perheld d99c9d2
Arm backend: Add cumsum support (#13457)
AdrianLundell 892db7a
Arm backend: Add partial vulkan runtime support for VgfPipeline (#13471)
YufengShi-dudu fc25fd8
NXP backend: Remove optimization in fuse_quanitze_into_preceding_ops.…
roman-janik-nxp 543a3c5
NXP backend: Remove optimization in prune_cast_operators.py (#13377)
roman-janik-nxp af656dc
NXP backend: Remove optimization in prune_reshape_operators.py (#13413)
roman-janik-nxp 4f4c34b
Remove outdated NCHW to NHWC pass and rename the current one to Repla…
mcremon-meta d210198
NXP backend: Improve target support checks. (#13367)
MartinPavella f287e0a
Add a default image prefiller implementation
larryliu0820 1680283
Switch non-top-level ExecuTorch builds (size test, examples, etc.) fr…
swolchok 9e38ee1
Enable BNNS copy for FP16 to FP32
metascroy e1e3933
Fix test-binary-size-linux -Wsign-compare failure with c10::irange (#…
swolchok 455071c
Arm backend: Update tosa dialect buck file
digantdesai 0d039c9
[ET-VK][ez] Move execute node threshold calculation from `prepare_pip…
pytorchbot 259aa8b
[ET-VK] Runtime support for NamedDataMap (#13498)
pytorchbot bc5d91c
[ET-VK][AOT] Serialize constant tensors via NamedDataMap (#13499)
pytorchbot 2b7e058
[ET-VK] Allocate memory for weight and activation tensors lazily (#13…
pytorchbot 6b5f73b
[ET-VK][ez] Fix erroneous cherry-pick bot merge (#13512)
SS-JIA 0c86282
Arm backend: Add support for QAT+per-channel combo (#13511)
oscarandersson8218 93eb208
Reset Temp Allocator after each use
cmt0 be3b509
Stop validating that build_variables.bzl matches buck-generated execu…
swolchok b0d5391
Only support int8 and quant dtypes for quant operators (#11685)
aaron-ang fc00827
Stop looking for buck2 in the top-level ExecuTorch build (#13393)
swolchok 9d64ccf
Summary: Add MCU model script to validate and run the models (#13439)
psiddh 293072c
Add a generic multimodal runner (#13166)
larryliu0820 d85205d
Force -O3 for executorch op_div.cpp in clang 19
akrieger 55dfc90
Remove NTSTATUS cast
SamGondelman ae6d536
Clean up apparently-unnecessary mentions of BUCK2 in scripts (#13394)
swolchok 38b86a3
Qwen and Phi-4-Mini targets (#13449)
rohansjoshi 8f286f3
Wrap LLM runner tests in anonymous namespace (unbreak unittest-releas…
swolchok 2bc6b0d
Delete extract_sources.py and cmake_deps.toml (#13395)
swolchok 171451b
[Backend Tester] Add nightly CI job for XNNPACK (#13390)
GregoryComer 09a4511
Add Model Profiling Automation Script (#13493)
leafs1 ee9f94c
Remove unused PROGRAM_SCHEMA_SRCS from build_variables.bzl (#13432)
swolchok f621114
removed lines of cron schedule runs until the github token issue is f…
nil-is-all 4313e48
[Backend Tester] Mark adaptive avgpool2d as an unsupported portable o…
GregoryComer 02a4657
[Backend Tester] Run Vulkan tests in nightly CI (#13445)
GregoryComer 86c9ee1
[Backend Tester] Run Core ML tests in nightly CI (#13446)
GregoryComer 4174b03
Improve softmax perf when transpose is not needed
mcremon-meta a08eb08
Remove other extensions' source files from EXTENSION_TRAINING_SRCS an…
swolchok f013ba4
Fix `buck query //extension/flat_tensor:` in OSS (#13484)
swolchok 38ba8cf
Qualcomm AI Engine Direct - Remove input_list dependencies (#13411)
chenweng-quic faff634
Qualcomm AI Engine Direct - GA Static Smollm2 (#13406)
chenweng-quic c06b947
Remove pinning to ET commit for Zephyr CI job (#13388)
BujSet 58ddf4f
Refactor pybinding unit test
cccclai fe255a2
Update project.pbxproj (#13537)
shoumikhin 4797f2e
Fix Olmo trunk test and skip XNNPack trunk tests on mac (#13528)
jackzhxng ab4fd57
Re-enable model tests with recipes for xnnpack backend (#13519)
abhinaykukkadapu 4c084e8
QNN Llama Runner implement IRunner (#13171)
rohansjoshi 7d13b2e
Arm backend: Remove get_output_nodes from runner_utils. (#13417)
AdrianLundell 88642c0
Arm backend: Fix for combo neg(x)+1 + tests (#13517)
wwwind 2eb7d7e
Adding smollm2 to examples/models/__init__.py (#13514)
SaoirseARM ba6a40c
Arm backend: Add limited support for fish shell in setup_path.sh (#13…
zingo 1f63c65
Arm backend: Update examples/arm/README.md (#13546)
zingo 58b3199
Arm backend: Added VGF minimal example (#13545)
SaoirseARM c5a8e8f
Temporarily disable test-models-arm-zephyr (#13548)
digantdesai 8b3261f
Fix buck cquery //extension/llm/runner: in OSS (#13527)
swolchok 075988d
Remove unused sleef.h which breaks cross-compilation on windows (#13079)
ykhrustalev e9754ab
Bump the PyTorch pin to 20250811 (#13334)
swolchok af9f8a5
Link QNN backend to pybinding lib when built (#13467)
GregoryComer c9f0159
Split on dilation for strongly typed convs
ethansfng 5da6516
Fix sigmoid operator to support boolean tensor inputs (#13515)
notkisk 19ea8a6
Fix method meta error in wasm build (#13496)
Conarnar 9ed517c
[Backend Tester] Add test flows for QNN (#13469)
GregoryComer cb4eeb4
[Backend Tester] Add additional quantized test flows for XNNPACK and …
GregoryComer 5c44446
[Backend Tester] Add markdown summary in CI (#13535)
GregoryComer d638703
Add support for strongly typed quantized_op_add
ethansfng 4019da4
Whisper audio processor
rohansjoshi 454b8a1
Non-fatal error when ET_SWITCH encounters unsupported dtype
manuelcandales 6fc8ede
Update test_remove_unused_parameters_pass.py (#13563)
digantdesai cc0609b
Fix memory planning for greed with heuristic algo.
hsharma35 45765ae
Input position accessor for static attention IO manager
sxu f73d44d
Unbreak build-benchmark-app (apple) after pin bump in #13334 (#13582)
swolchok 65dc152
Summary: Add ExecutorchRuntimeException: Throw relevant exceptions f…
psiddh 7616da9
Update tokenizer to include tekken implementation (#13601)
mergennachin 39afccc
Add missing backslashes in example run section (#13603)
Conarnar e610f23
Make IOManager use Module instead of Method. (#13542)
shoumikhin 391ae3c
Access Method directly from TrainingModule. (#13602)
shoumikhin 99b4216
Make TensorPtr constructor check the data dize matches the shape. (#1…
shoumikhin 5997ee3
Add set_outputs() API. (#13609)
shoumikhin 80adad5
Create stale.yml workflow to label stale PRs (#13565)
nil-is-all 619bc30
Add get_output API. (#13610)
shoumikhin 20e60bf
Set an empty EValue input for models that expect None arg. (#13621)
shoumikhin bdea7d0
Split on depthwise for strongly typed convs
ethansfng 88588bf
migrate all test_aten_ops to facto
zonglinpeng 9c0280c
fix MM nullptr from zero bias
zonglinpeng 3bb031b
Call .detach() in static attention cache update helper
sxu 64d88aa
Event Tracer Constraint
cmt0 0d4fd84
[ET-VK][ez] Fix partitioner logic of finding keepdim arg of reduce op…
SS-JIA bb0ec6e
[ET-VK][ez] Support grouped convolutions (#13599)
SS-JIA e98da6d
[ET-VK][ez] Use XNNPACK's FuseBatchNorm pass (#13600)
SS-JIA 67b48c3
[ET-VK][testing] Add scripts to facilitate operator testiing (#13593)
SS-JIA fb99e23
[ET-VK][ez] Consolidate tensor metadata calculation + buffer binding …
SS-JIA 3a8edfd
[ET-VK] Introduce `BufferMetadata` GLSL struct to abstract tensor lay…
SS-JIA 1653dbf
[ET-VK][ez] Allow high dimensional tensors (for buffer storage) (#13596)
SS-JIA 4f7871a
[ET-VK] High dim tensor support for view, unsqueeze, squeeze, clone (…
SS-JIA c91401e
[Core ML] Improve asset management (#13560)
cymbalrush 6142858
Rename stale to stale.yml (#13619)
nil-is-all 4298ff1
fix mismatch sub dtype (#13447)
cccclai 50b7913
NXP backend: Add support for the `aten.cat` operator. (#13505)
MartinPavella 044bdcd
NXP backend: Add implementation of Tanh operator converter (#13510)
MartinPavella 6f05c35
Arm backend: Dont try to fuse const for TOSA ops (#13575)
per 9b7c80a
NXP backend: Fix `tanh` merge conflict. (#13626)
MartinPavella 1dba47f
Cortex_m backend: Loosen edge op check. (#13550)
Erik-Lundell bf2abab
NXP backend: Use zero point for quantized padding. (#13576)
MartinPavella 4c510f1
Fix aten.amax lowering issue (#13381)
cccclai f55769d
Update coremltools to 9b1 (#13614)
metascroy 7854fe7
Add check_for_installed_private_headers_in_cmake_out (#13485)
swolchok 4e316d4
NXP backend: Add support for conversion of Conv1D operator (#13549)
roman-janik-nxp f9593d2
Update cpuinfo pin to latest (#13624)
pytorchbot 290a8f5
Added JS bindings for tokenizers library (#13566)
Conarnar ed11370
Run all periodic models when ciflow/periodic label is present (#13634)
shoumikhin 3c84d53
Add support for data path in iOS (#13620)
lucylq fbff62e
Summary: Follow up fix to pr#13526 (#13640)
psiddh fd921c4
Allow none and string input types for Method (#13645)
shoumikhin ecf5be2
Resurface low level runtime API page (#13651)
mergennachin 315c837
Fully enable the stale PR workflow (#13656)
nil-is-all be22ad5
Qualcomm AI Engine Direct - Scripts and accuracy improvement for Qwen…
winskuo-quic dc4ff25
Qualcomm AI Engine Direct - Improve GA Static Phi-4-mini accuracy (#1…
shewu-quic c72accb
Qualcomm AI Engine Direct - Fix broken unpacking in T5 dataset loadin…
DannyYuyang-quic 1797ba1
Fix error reporting in Windows preset build job (#13247)
GregoryComer 1feb7c7
Fix devtools CMake build failure on Windows (#13251)
GregoryComer ceb6f32
Create Windows CMake preset (#13257)
GregoryComer b78e768
Temporarily disable windows preset build in CI (#13669)
GregoryComer 68a9a42
Increase binary size limit by 8 bytes (#13671)
shoumikhin c92def6
NXP backend: Remove IR optimization to remove dead branches. (#13574)
MartinPavella 80d1407
Inline requantize kernels
mcremon-meta aae7baa
Smollm targets
rohansjoshi 01ca904
Override unload_method in training_module to erase the tensors pointi…
silverguo 4df836d
Disable mm + add -> addmm fusion if added tensor rank >2
hsharma35 9d6a7f2
Fix bad optimized kernel for add.
hsharma35 e1cd63e
Allow zero-element inputs for method.
hsharma35 f93d524
Arm backend: Use cmake for building in Ethos-U jupyter example (#13630)
AdrianLundell 36cdaec
NXP backend: Add MobileNetV2 example model and test (#12892)
StrycekSimon 7bb115b
Arm Backend: Add support for ELU.default operator
agrima1304 c9cbad7
Arm Backend: Add support for ELU.default operator
agrima1304
@@ -0,0 +1,46 @@

```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ConvertELUParamsPass(ExportPass):
    """
    Pass to convert the input_scale kwarg of the ELU operator from float to
    int.

    It has been set to 2 as the outputs appear to stay the same regardless of
    the value of input_scale, as long as that value is not 1.
    """

    def call(self, graph_module: torch.fx.GraphModule):
        modified_graph = False
        graph = graph_module.graph
        node_list = graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.elu.default
        )
        for node in node_list:
            with graph.inserting_after(node):
                replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
                replace_node.args = (
                    node.args[0],
                    int(node.args[1]) if len(node.args) > 1 else 1,
                )
                updated_kwargs = dict(node.kwargs)
                updated_kwargs["input_scale"] = int(2)
                replace_node.kwargs = updated_kwargs

                node.replace_all_uses_with(replace_node)
                graph.erase_node(node)

                modified_graph = True
        if modified_graph:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified_graph)
```
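The pass above follows the standard FX rewrite sequence: find matching `call_function` nodes, insert a replacement node after each match, migrate all uses, then erase the original. A minimal, self-contained sketch of that same sequence on a toy `torch.add` graph (the `Toy` module and the scalar values are illustrative, not part of this PR):

```python
import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.add(x, 1.0)


gm = torch.fx.symbolic_trace(Toy())
graph = gm.graph

# Same find / insert_after / replace_all_uses_with / erase pattern as
# ConvertELUParamsPass, here rewriting the scalar argument from 1.0 to 2.0.
for node in list(graph.nodes):
    if node.op == "call_function" and node.target is torch.add:
        with graph.inserting_after(node):
            new_node = graph.call_function(torch.add, (node.args[0], 2.0))
        node.replace_all_uses_with(new_node)
        graph.erase_node(node)

gm.recompile()
print(gm(torch.zeros(2)))  # tensor([2., 2.])
```

Creating the replacement inside `graph.inserting_after(node)` before calling `replace_all_uses_with` keeps the graph topologically valid, which is why the ELU pass is structured the same way.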
@@ -0,0 +1,100 @@

```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops

edge_elu_ops = (exir_ops.edge.aten.elu.default,)
aten_elu_ops = (torch.ops.aten.elu.default, torch.ops.aten.elu_.default)


def get_elu_decomposition(op) -> tuple:
    """
    Returns the decomposition of the given aten.elu operation into
    its equivalent TOSA-supported operations.

    This handles both edge dialect ops and core PyTorch ops. The decomposition
    strategy is:
        elu(x, alpha) -> where(ge(x, 0), x, alpha * (exp(x) - 1))

    Returns:
        A tuple (add_op, exp_op, ge_op, where_op, mul_op) corresponding to the
        appropriate operator overloads for the input op.

    Raises:
        RuntimeError: If the provided operator is not a supported elu variant.
    """
    if op in edge_elu_ops:
        return (
            exir_ops.edge.aten.add.Scalar,
            exir_ops.edge.aten.exp.default,
            exir_ops.edge.aten.ge.Scalar,
            exir_ops.edge.aten.where.self,
            exir_ops.edge.aten.mul.Scalar,
        )

    if op in aten_elu_ops:
        return (
            torch.ops.aten.add.Scalar,
            torch.ops.aten.exp.default,
            torch.ops.aten.ge.Scalar,
            torch.ops.aten.where.self,
            torch.ops.aten.mul.Scalar,
        )

    raise RuntimeError(f"Can't get elu decomposition for op {op}")


class DecomposeEluPass(ArmPass):
    """
    A transformation pass that decomposes unsupported 'aten.elu' operations
    into a combination of supported TOSA-equivalent operations.

    Since TOSA does not provide a native ELU operator, this pass rewrites:
        elu(x) -> where(ge(x, 0), x, alpha * (exp(x) - 1))

    Supported input ops:
    - aten.elu(x)
    - aten.elu_(x)
    - exir_ops.edge.aten.elu.default(x)

    These are replaced with:
    - aten.exp or exir_ops.edge.aten.exp
    - aten.add.Scalar or exir_ops.edge.aten.add.Scalar
    - aten.ge.Scalar or exir_ops.edge.aten.ge.Scalar
    - aten.where.self or exir_ops.edge.aten.where.self
    - aten.mul.Scalar or exir_ops.edge.aten.mul.Scalar
    """

    def call_operator(self, op, args, kwargs, meta):
        if op not in (edge_elu_ops + aten_elu_ops):
            return super().call_operator(op, args, kwargs, meta, updated=False)

        (
            add_op,
            exp_op,
            ge_op,
            where_op,
            mul_op,
        ) = get_elu_decomposition(op)

        input = args[0]
        alpha = int(args[1]) if len(args) > 1 else 1

        exp_node = super().call_operator(exp_op, (input,), {}, meta, updated=True)
        sub_node = super().call_operator(
            add_op, (exp_node, -1.0), {}, meta, updated=True
        )
        mul_node = super().call_operator(
            mul_op, (sub_node, alpha), {}, meta, updated=True
        )
        ge_node = super().call_operator(ge_op, (input, 0.0), {}, meta, updated=True)
        where_node = super().call_operator(
            where_op, (ge_node, input, mul_node), {}, meta, updated=True
        )

        return where_node
```
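The decomposition itself can be checked in plain PyTorch, independent of the pass machinery. Below is a minimal sketch; `elu_decomposed` is an illustrative helper, not part of the PR:

```python
import torch


def elu_decomposed(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Mirrors the pass: where(ge(x, 0), x, alpha * (exp(x) - 1))
    return torch.where(x >= 0, x, alpha * (torch.exp(x) - 1.0))


x = torch.linspace(-4.0, 4.0, steps=17)
for alpha in (1.0, 2.0):
    ref = torch.nn.functional.elu(x, alpha=alpha)
    assert torch.allclose(elu_decomposed(x, alpha), ref, atol=1e-6)
```

Note that the pass casts `alpha` with `int(args[1])`, so only integer-valued alphas (as in the test suite's 1.0 and 2.0 cases) round-trip exactly through the decomposition.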
@@ -0,0 +1,94 @@

```python
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
import torch.nn as nn

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineBI,
    EthosU85PipelineBI,
    TosaPipelineBI,
    TosaPipelineMI,
)

test_data_suite = {
    # (test_name, test_data)
    "zeros_default": lambda: (1.0, torch.zeros(1, 10, 10, 10)),
    "ones_default": lambda: (1.0, torch.ones(10, 10, 10)),
    "rand_default": lambda: (1.0, torch.rand(10, 10) - 0.5),
    "randn_pos_default": lambda: (1.0, torch.randn(1, 2, 3, 3) + 10),
    "randn_neg_default": lambda: (1.0, torch.randn(2, 4, 3) - 10),
    "ramp_default": lambda: (1.0, torch.arange(-16, 16, 0.2)),
    "large_pos_default": lambda: (1.0, torch.randn(3, 3) * 1e6 + 1e7),
    "large_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e5, 1e8)),
    "small_pos_default": lambda: (1.0, torch.empty(5).uniform_(1e-8, 1e-5)),
    "small_neg_default": lambda: (1.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
    "zeros_custom": lambda: (2.0, torch.zeros(1, 10, 10, 10)),
    "ones_custom": lambda: (2.0, torch.ones(10, 10, 10)),
    "rand_custom": lambda: (2.0, torch.rand(10, 10) - 0.5),
    "randn_pos_custom": lambda: (2.0, torch.randn(1, 3, 3) + 10),
    "randn_neg_custom": lambda: (2.0, torch.randn(1, 2, 4, 3) - 10),
    "ramp_custom": lambda: (2.0, torch.arange(-16, 16, 0.2)),
    "large_pos_custom": lambda: (2.0, torch.randn(3, 3) * 1e6 + 1e7),
    "large_neg_custom": lambda: (2, -torch.empty(5).uniform_(1e5, 1e8)),
    "small_pos_custom": lambda: (2.0, torch.empty(5).uniform_(1e-8, 1e-5)),
    "small_neg_custom": lambda: (2.0, -torch.empty(5).uniform_(1e-8, 1e-5)),
}


class Elu(nn.Module):
    aten_op = "torch.ops.aten.elu.default"
    exir_op = "executorch_exir_dialects_edge__ops_aten__elu_default"

    def __init__(self, input_alpha: float = 1.0):
        super().__init__()
        self.elu = torch.nn.ELU(alpha=input_alpha)

    def forward(self, input_: torch.Tensor):
        return self.elu(input_)


input_t1 = Tuple[torch.Tensor]


@common.parametrize("test_module", test_data_suite)
def test_elu_tosa_MI(test_module: input_t1):
    alpha, test_data = test_module()
    pipeline = TosaPipelineMI[input_t1](
        Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
    )
    pipeline.run()


@common.parametrize("test_module", test_data_suite)
def test_elu_tosa_BI(test_module: input_t1):
    alpha, test_data = test_module()
    pipeline = TosaPipelineBI[input_t1](
        Elu(alpha), (test_data,), aten_op=Elu.aten_op, exir_op=Elu.exir_op
    )
    pipeline.run()


@common.XfailIfNoCorstone300
@common.parametrize("test_module", test_data_suite)
def test_elu_u55_BI(test_module: input_t1):
    alpha, test_data = test_module()
    pipeline = EthosU55PipelineBI[input_t1](
        Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op
    )
    pipeline.run()


@common.XfailIfNoCorstone320
@common.parametrize("test_module", test_data_suite)
def test_elu_u85_BI(test_module: input_t1):
    alpha, test_data = test_module()
    pipeline = EthosU85PipelineBI[input_t1](
        Elu(alpha), (test_data,), aten_ops=Elu.aten_op, exir_ops=Elu.exir_op
    )
    pipeline.run()
```
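The `_default` and `_custom` halves of the suite exercise alpha=1.0 and alpha=2.0. The effect of alpha is easy to see in eager mode: zero and positive inputs pass through unchanged, while the negative branch scales linearly with alpha (the sample values below are illustrative, not taken from the suite):

```python
import torch

elu1 = torch.nn.ELU(alpha=1.0)
elu2 = torch.nn.ELU(alpha=2.0)

x = torch.tensor([-1.0, 0.0, 3.0])

# alpha only affects the negative branch: elu(x) = alpha * (exp(x) - 1) for x < 0
assert torch.equal(elu1(x)[1:], x[1:])             # 0 and positive inputs unchanged
assert torch.allclose(elu2(x)[0], 2 * elu1(x)[0])  # negative branch scales with alpha
```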
Review discussion:

Comment: I don't understand this.

Reply: The default value of input_scale is 1.0, but using the default value resulted in a type error. When passing 1 as an int, it was overridden by the default value (since both values are 1 and therefore equivalent), so input_scale had to be changed to an int that is not 1.
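The source of the confusion is the full aten ELU schema, which takes three scalar parameters beyond the input, each defaulting to 1. It can be inspected directly (a sketch; the exact printed form depends on the PyTorch version):

```python
import torch

schema = torch.ops.aten.elu.default._schema
print(schema)

# The schema declares alpha, scale, and input_scale, all defaulting to 1,
# which is why an explicit input_scale=1 is indistinguishable from the default.
names = [a.name for a in schema.arguments]
assert names == ["self", "alpha", "scale", "input_scale"]
```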