diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index ed558e2cb4b..d983b9bfffc 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -84,6 +84,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: common_persistent, ) + # TBD: Find a principled way to merge node.meta across all fused node + # For now, i specifically transfer over the TosaSpecialDtype.meta_key() of the rep_node + if TosaSpecialDtype.meta_key() in rep_node.meta: + common_node.meta[TosaSpecialDtype.meta_key()] = rep_node.meta[ + TosaSpecialDtype.meta_key() + ] + # Replace uses and delete duplicates for node, _ in nodes_tensors: node.replace_all_uses_with(common_node) diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index b94ea23c6c1..01cafad13d0 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -18,6 +18,7 @@ def define_arm_tests(): "ops/test_addmm.py", "ops/test_avg_pool2d.py", "ops/test_cat.py", + "ops/test_conv2d.py", "ops/test_linear.py", "ops/test_mul.py", "ops/test_permute.py",