From c91a6c077880aebec89bbd97c87db9671d776443 Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Thu, 20 Mar 2025 15:28:20 +0100
Subject: [PATCH 1/2] Fix retracing in FuseViewCopyTransform

Since the pass can change shapes of ops, the graph needs to be retraced
to show this in node.meta["val"].

Signed-off-by: Erik Lundell
Change-Id: Ief24fe9d11384a2d0f64f0d91070eca7b0caf18e
---
 backends/transforms/fuse_view_copy.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/backends/transforms/fuse_view_copy.py b/backends/transforms/fuse_view_copy.py
index 22e20d1c88b..c740515cdcc 100644
--- a/backends/transforms/fuse_view_copy.py
+++ b/backends/transforms/fuse_view_copy.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,13 +12,14 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 
 
-def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph:
+def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
     """
     Find chains of view_copy nodes and merge them into one view_copy node.
     Only merges view_copy nodes that are not used by any other nodes.
     """
     ops = exir_ops.edge
     view_op = ops.aten.view_copy.default
+    modified = False
     for node in graph.nodes:
         if node.op == "call_function" and node.target == view_op:
             # find ending view_copy node in chain
@@ -35,29 +37,36 @@
             new_args = (node.args[0], end_node.args[1])
             node.args = new_args
             end_node.replace_all_uses_with(node)
+            modified = True
     graph.eliminate_dead_code()
-    return graph
+    return graph, modified
 
 
-def remove_noop_view_copy(graph: torch.fx.Graph) -> torch.fx.Graph:
+def remove_noop_view_copy(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
     """
     Remove view_copy nodes that are no-ops.
     """
     ops = exir_ops.edge
     view_op = ops.aten.view_copy.default
+    modified = False
     for node in graph.nodes:
         if node.op == "call_function" and node.target == view_op:
             input_shape = list(node.args[0].meta["val"].shape)
             target_shape = node.args[1]
             if input_shape == target_shape:
                 node.replace_all_uses_with(node.args[0])
+                modified = True
     graph.eliminate_dead_code()
-    return graph
+    return graph, modified
 
 
 class FuseViewCopyTransform(ExportPass):
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        graph_module.graph = merge_view_copy_chains(graph_module.graph)
-        graph_module.graph = remove_noop_view_copy(graph_module.graph)
-        return PassResult(graph_module, True)
+        graph_module.graph, merge_modified = merge_view_copy_chains(graph_module.graph)
+        graph_module.graph, noop_modified = remove_noop_view_copy(graph_module.graph)
+        modified = merge_modified or noop_modified
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, modified)

From 9467d4f50882bd4f9c1136e78c26b1b02d5d59e9 Mon Sep 17 00:00:00 2001
From: Erik Lundell
Date: Fri, 7 Mar 2025 17:16:50 +0100
Subject: [PATCH 2/2] Arm backend: Add ComputeConstantOpsAOT pass

Operators that output tensors based on constant args are pre-computed
and added as buffers.

- The pass currently supports full, arange, linspace, and eye.
- Remove some logic for full that is now handled by the pass
- Rename FuseConstantOpsPass to FuseConstantArgsPass and make minor
  improvements

Signed-off-by: Erik Lundell
Change-Id: I744e2583a9ed011e350cfaa43410902bd9e54292
---
 backends/arm/_passes/arm_pass_manager.py      |  20 +-
 backends/arm/_passes/cast_int64_pass.py       |  55 +++---
 .../fold_qdq_with_annotated_qparams_pass.py   |  17 +-
 .../arm/_passes/fuse_constant_ops_pass.py     | 174 ++++++++++++------
 .../tosa_supported_operators.py               |  18 +-
 backends/arm/test/models/test_conformer.py    |   1 -
 .../arm/test/models/test_nn_functional.py     |  10 +-
 backends/arm/test/ops/test_arange.py          | 135 ++++++++++++++
 backends/arm/test/ops/test_full.py            |   6 -
 .../passes/test_fuse_constant_ops_pass.py     |  11 +-
 backends/arm/tosa_partitioner.py              |   2 +
 backends/arm/tosa_quant_utils.py              |   2 +-
 12 files changed, 322 insertions(+), 129 deletions(-)
 create mode 100644 backends/arm/test/ops/test_arange.py

diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 7ec04ea0844..b2a9ddb710d 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -55,7 +55,10 @@
     RetraceFoldedDtypesPass,
 )
 from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
+from executorch.backends.arm._passes.fuse_constant_ops_pass import (
+    ComputeConstantOpsAOT,
+    FuseConstantArgsPass,
+)
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
@@ -121,21 +124,23 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
+
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())
@@ -166,21 +171,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(CastInt64ToInt32Pass(exported_program)) - self.add_pass(MatchArgRanksPass(exported_program)) self.add_pass(KeepDimsFalseToSqueezePass()) self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSelectPass()) self.add_pass(ConvertSqueezesToViewPass()) self.add_pass(FuseViewCopyTransform()) - self.add_pass(FuseConstantOpsPass(exported_program)) + self.add_pass(FuseConstantArgsPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(AnnotateChannelsLastDimOrder()) self.add_pass(InsertRescalePass()) diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index dffa4c199a4..199323e363d 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -8,7 +8,6 @@ import logging import torch -from executorch.backends.arm._passes.arm_pass_utils import is_param_node from executorch.exir.pass_base import ExportPass, PassResult from torch._export.utils import is_buffer @@ -25,35 +24,37 @@ def __init__(self, exported_program: torch.export.ExportedProgram): super(CastInt64ToInt32Pass, self).__init__() self.exported_program = exported_program + def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node): + if torch.min(tensor) < torch.iinfo(torch.int32).min: + raise RuntimeError( + f"Node {node.name} has value < {torch.iinfo(torch.int32).min}" + ) + if torch.max(tensor) > torch.iinfo(torch.int32).max: + raise RuntimeError( + f"Node {node.name} has value > {torch.iinfo(torch.int32).max}" + ) + def _to_int32(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: fake_tensor = node.meta["val"] - if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): - if node.meta["val"].dtype == torch.int64 and is_param_node( - self.exported_program, node - ): - if is_buffer(self.exported_program, node): - node.meta["val"] = node.meta["val"].to(torch.int32) - buffer_name = ( - self.exported_program.graph_signature.inputs_to_buffers[ - node.name - ] - ) - buffer = self.exported_program.state_dict[node.name] - logger.warning( - f"Casting buffer {node.name} from torch.int64 to torch.int32" - f" defined in {node.meta['stack_trace']}" - ) - if torch.min(buffer) < torch.iinfo(torch.int32).min: - raise RuntimeError( - f"Buffer {node.name} has value < {torch.iinfo(torch.int32).min}" - ) - if torch.max(buffer) > torch.iinfo(torch.int32).max: - raise RuntimeError( - f"Buffer {node.name} has value > {torch.iinfo(torch.int32).max}" - ) - buffer_int32 = buffer.to(torch.int32) - self.exported_program.state_dict[buffer_name] = buffer_int32 + if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): + continue + if fake_tensor.dtype != torch.int64: + continue + if is_buffer(self.exported_program, node): + node.meta["val"] = fake_tensor.to(torch.int32) + buffer_name = self.exported_program.graph_signature.inputs_to_buffers[ + node.name + ] + buffer = self.exported_program.state_dict[node.name] + self._assert_within_int32(buffer, node) + logger.warning( + f"Casting buffer {node.name} from torch.int64 to torch.int32" + f" defined in {node.meta.get('stack_trace','[no stack trace found]')}" + ) + buffer_int32 = buffer.to(torch.int32) + self.exported_program.state_dict[buffer_name] = buffer_int32 + continue def call(self, graph_module: torch.fx.GraphModule): self._to_int32(graph_module) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 
7a965539f82..963759bfa6d 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -174,11 +174,8 @@ def call(self, graph_module: GraphModule) -> PassResult: class QuantizeOperatorArguments(ExportPass): """ - This pass makes sure that the arguments to full.default and clamp.default are quantized correctly. + This pass makes sure that the arguments to clamp.default are quantized correctly. More specifically, this pass: - - Makes sure the fill_value for full.default is quantized. This pass needs to be run before - the folding pass above to make sure that the retraced output of the full.default op is - the right dtype. - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator. """ @@ -189,7 +186,6 @@ def call(self, graph_module: GraphModule) -> PassResult: n = cast(Node, n) if n.target not in { exir_ops.edge.aten.clamp.default, - exir_ops.edge.aten.full.default, }: continue @@ -200,16 +196,7 @@ def call(self, graph_module: GraphModule) -> PassResult: qargs = QuantArgs.from_operator(user.target, user.args) - if n.target == exir_ops.edge.aten.full.default: - if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype: - # replace the node arg with a quantized dito and also set dtype - # to get the right output according to the Edge IR specification: - # exir/dialects/edge/edge.yaml:3596 - quantized_full_value = qargs.quantize_value(n.args[1]).item() - n.update_arg(1, quantized_full_value) - n.update_kwarg("dtype", qargs.dtype) - modified = True - elif n.target == exir_ops.edge.aten.clamp.default: + if n.target == exir_ops.edge.aten.clamp.default: # Quantize the min and max arguments of clamp, if they are not None min_val = n.args[1] max_val = None if len(n.args) <= 2 else n.args[2] diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 1fff7d76dfc..f37bc06d16a 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -8,6 +8,7 @@ import torch._export.utils from executorch.backends.arm._passes.arm_pass_utils import ( get_constant_placeholder_kind, + get_first_fake_tensor, get_param_tensor, is_persistent_buffer, ) @@ -18,11 +19,12 @@ from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult +from torch.export.graph_signature import InputKind logger = logging.getLogger(__name__) -class FuseConstantOpsPass(ExportPass): +class FuseConstantArgsPass(ExportPass): """ Fuses ops with only placeholder parameters into one placeholder parameter node with the op pre-calulcated on its data. @@ -42,67 +44,38 @@ def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() self.exported_program = exported_program - def fuse_nodes(self, node) -> bool: + def _fuse_nodes(self, node) -> bool: """ Takes a node with only parameter inputs and replaces it with one constant tensor node with the operations already carried out on the data. 
""" - if node.target == exir_ops.edge.aten.full.default: - # Create data from args - size, fill_value = node.args - dtype = node.kwargs["dtype"] - data = torch.full(size, float(fill_value), dtype=dtype) + # Extract tensors and args from the node + data_list = [ + get_param_tensor(self.exported_program, input_node) + for input_node in node.all_input_nodes + ] - insert_pos = list(node.graph.nodes)[0] - else: - # Extract tensors and args from the node - - if len(node.all_input_nodes) == 0: - raise RuntimeError("No inputs found") + args = node.args[len(node.all_input_nodes) :] + kwargs = node.kwargs - data_list = [ - get_param_tensor(self.exported_program, input_node) - for input_node in node.all_input_nodes - ] + if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0: + for i in range(len(node.all_input_nodes)): + q_params = node.meta["input_qparams"][i] + data_list[i] = q_params.dequantize_value(data_list[i]) - args = node.args[len(node.all_input_nodes) :] - kwargs = node.kwargs + # Run the op on the extracted tensor + data = node.target(*data_list, *args, **kwargs) - if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0: - dequantize_op = ( - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default - ) + # Only fuse if the tensor does not get bigger. + if data.numel() > get_first_fake_tensor(node).numel(): + return False - for i in range(len(node.all_input_nodes)): - q_params = node.meta["input_qparams"][i] - data_list[i] = dequantize_op( - data_list[i], - q_params.scale, - q_params.zp, - q_params.qmin, - q_params.qmax, - q_params.dtype, - ) - - # Run the op on the extracted tensor - data = node.target(*data_list, *args, **kwargs) - - if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0: - quantize_op = ( - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default - ) - q_params = node.meta["output_qparams"][0] - data = quantize_op( - data, - q_params.scale, - q_params.zp, - q_params.qmin, - q_params.qmax, - q_params.dtype, - ) + if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0: + q_params = node.meta["output_qparams"][0] + data = q_params.quantize_value(data) - insert_pos = list(node.all_input_nodes)[0] + insert_pos = list(node.all_input_nodes)[0] # Make new node the same kind as the first constant input input_kind = get_constant_placeholder_kind(self.exported_program, insert_pos) @@ -124,20 +97,17 @@ def fuse_nodes(self, node) -> bool: return True def call(self, graph_module): - modified = True + modified = False input_nodes_to_delete = [] for node in graph_module.graph.nodes: if node.op != "call_function": continue if node.target == torch.ops.tosa._table.default: continue - if node.target == exir_ops.edge.aten.repeat.default: - _, multiples = node.args - # Do not fuse if the repeat creates a larger output, i.e. 
any multiple > 1
-            if any((multiple > 1 for multiple in multiples)):
-                continue
             input_nodes = node.all_input_nodes
+            if len(input_nodes) == 0:
+                continue
             input_nodes_constant = (
                 torch._export.utils.is_param(self.exported_program, input_node)
                 or torch._export.utils.is_lifted_tensor_constant(
@@ -152,9 +122,11 @@
 
             if all(input_nodes_constant) and all(input_nodes_single_users):
                 try:
-                    self.fuse_nodes(node)
-                    graph_module.recompile()  # Recompile needed to catch chains of constant ops
-                    input_nodes_to_delete.extend(input_nodes)
+                    did_fuse = self._fuse_nodes(node)
+                    modified |= did_fuse
+                    if did_fuse:
+                        graph_module.recompile()  # Recompile needed to catch chains of constant ops
+                        input_nodes_to_delete.extend(input_nodes)
                 except Exception as e:
                     logger.warning(
                         f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}"
@@ -168,3 +140,85 @@
         graph_module = super().call(graph_module).graph_module
 
         return PassResult(graph_module, True)
+
+
+class ComputeConstantOpsAOT(ExportPass):
+    """
+    Evaluates call_functions that produce constant tensor outputs and replaces them with placeholders.
+
+    Original:
+    state_dict = {}
+    def f():
+        return torch.arange(0,10)
+    After pass:
+    state_dict = {node_name_pre_computed : torch.arange(0,10)}
+    def f(node_name_pre_computed):
+        return node_name_pre_computed
+    """
+
+    targeted_ops = [
+        exir_ops.edge.aten.full.default,
+        exir_ops.edge.aten.arange.start_step,
+        exir_ops.edge.aten.eye.default,
+        exir_ops.edge.aten.linspace.default,
+    ]
+
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        super().__init__()
+        self.exported_program = exported_program
+
+    def compute_node_aot(self, node: torch.fx.Node) -> bool:
+        """
+        Takes a node whose output is constant given its args and replaces it with a
+        constant placeholder holding the pre-computed output.
+        """
+
+        # Create data from args
+        output_qparams = node.meta.get("output_qparams", None)
+        if output_qparams:
+            # If we have output_qparams, compute data in fp and quantize
+            data = node.target(*node.args)  # type: ignore
+            output_qparams = output_qparams[0]
+            data = output_qparams.quantize_value(data)
+        else:
+            # If we don't have output_qparams, compute data using the kwarg-specified dtype
+            data = node.target(*node.args, **node.kwargs)  # type: ignore
+
+        # Create new node
+        insert_pos = list(node.graph.nodes)[0]
+        input_kind = InputKind.BUFFER
+        persistent_buffer = True
+
+        with node.graph.inserting_before(insert_pos):
+            const_node = create_constant_placeholder(
+                exp_program=self.exported_program,
+                graph=node.graph,
+                kind=input_kind,
+                name=node.name + "_pre_computed",
+                data=data,
+                persistent_buffer=persistent_buffer,
+            )
+        node.replace_all_uses_with(const_node)
+
+        return True
+
+    def call(self, graph_module):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target not in self.targeted_ops:
+                continue
+            try:
+                modified |= self.compute_node_aot(node)
+            except Exception as e:
+                logger.warning(
+                    f"\nFailed to pre-compute op {node.name} due to exception:\n{str(e)}"
+                )
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index dfd8024e4b3..6734b91a9ef 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -14,6 +14,7 @@
 import torch.fx as fx
 
 from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
     FuseQuantizedActivationPass,
 )
@@ -142,6 +143,7 @@ def is_node_supported(
             exir_ops.edge.aten.logical_or.default,
             exir_ops.edge.aten.logical_xor.default,
             exir_ops.edge.aten.logical_not.default,
+            exir_ops.edge.aten.arange.start_step,
             exir_ops.edge.aten.bitwise_and.Tensor,
             exir_ops.edge.aten.bitwise_or.Tensor,
             exir_ops.edge.aten.bitwise_xor.Tensor,
@@ -200,6 +202,8 @@ def is_node_supported(
             exir_ops.edge.aten.constant_pad_nd.default,
             exir_ops.edge.aten.amax.default,
             exir_ops.edge.aten.amin.default,
+            exir_ops.edge.aten.eye.default,
+            exir_ops.edge.aten.linspace.default,
         ]
 
         return supported
@@ -441,16 +445,18 @@ def is_node_supported(
     ) -> bool:
 
         for input_node in node.all_input_nodes:
-            # We can cast constant placeholders AOT, not call_functions.
+            # We can cast constant placeholders and constant ops AOT, so such int64 inputs are ok.
+            # Otherwise, don't partition if one or more inputs are int64.
if ( input_node.name in self.input_names or not input_node.op == "placeholder" ): tensor = get_first_fake_tensor(input_node) if tensor.dtype == torch.int64: - self.reporter.report_reject( - node, - f"Had int64 input {input_node.name} that couldn't be handled.", - ) - return False + if input_node.target not in ComputeConstantOpsAOT.targeted_ops: + self.reporter.report_reject( + node, + f"Had int64 input {input_node.name} that couldn't be handled.", + ) + return False return True diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index 4ed203a964e..3d32454f8de 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -30,7 +30,6 @@ class TestConformer(unittest.TestCase): # for that is some assert ops are removed by passes in the # .to_executorch step, i.e. after Arm partitioner. ops_after_partitioner = { - "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1, "executorch_exir_dialects_edge__ops_aten_max_default": 1, "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2, "executorch_exir_dialects_edge__ops_aten_where_self": 4, diff --git a/backends/arm/test/models/test_nn_functional.py b/backends/arm/test/models/test_nn_functional.py index b0a1e543ed3..40e83aa05c4 100644 --- a/backends/arm/test/models/test_nn_functional.py +++ b/backends/arm/test/models/test_nn_functional.py @@ -79,14 +79,20 @@ def forward(self, *args): @parametrize( - "test_data", module_tests, xfails={"max_pool1d": "ValueError: Invalid TOSA graph"} + "test_data", + module_tests, + xfails={ + "max_pool1d": "ValueError: Invalid TOSA graph", + "affine_grid": "Int64 input. Partition handling fails since arange int64 output is split between 2 partitions.", + }, ) def test_nn_functional_MI(test_data): module, inputs = test_data pipeline = TosaPipelineMI[input_t]( - module, inputs, "", use_to_edge_transform_and_lower=True + module, inputs, "", use_to_edge_transform_and_lower=False ) pipeline.pop_stage("check.aten") + pipeline.dump_artifact("to_edge") pipeline.pop_stage("check_count.exir") try: pipeline.run() diff --git a/backends/arm/test/ops/test_arange.py b/backends/arm/test/ops/test_arange.py new file mode 100644 index 00000000000..124f3ee597e --- /dev/null +++ b/backends/arm/test/ops/test_arange.py @@ -0,0 +1,135 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from typing import Callable
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+input_t = tuple[torch.Tensor]
+test_data_t = tuple[Callable[[], input_t], tuple[float, float, float, torch.dtype]]
+
+
+class ArangeAdd(torch.nn.Module):
+    aten_op: str = "torch.ops.aten.arange.start_step"
+    exir_op: str = "executorch_exir_dialects_edge__ops_aten_arange_start_step"
+
+    def __init__(self, start: float, stop: float, step: float, dtype: torch.dtype):
+        super().__init__()
+        self.args = (start, stop, step)
+        self.dtype = dtype
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.arange(*self.args, dtype=self.dtype) + x
+
+    test_data: dict[str, test_data_t] = {
+        "10": (lambda: (torch.randn(10, 1),), (0.0, 10.0, 1.0, torch.float32)),
+        "15": (lambda: (torch.randn(10),), (0.0, 15.0, 1.5, torch.float32)),
+        "100": (lambda: (torch.randn(10, 1),), (0.0, 10.0, 0.1, torch.float32)),
+    }
+
+    test_data_dtypes: dict[str, test_data_t] = {
+        "fp32_int32": (lambda: (torch.randn(10),), (0.0, 10.0, 1.0, torch.int32)),
+        "fp32_int64": (lambda: (torch.randn(10),), (0.0, 10.0, 1.0, torch.int64)),
+        "int32_int32": (
+            lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
+            (0.0, 10.0, 1.0, torch.int32),
+        ),
+        "int32_int64": (
+            lambda: (torch.randint(0, 10, [10], dtype=torch.int32),),
+            (0.0, 10.0, 1.0, torch.int64),
+        ),
+    }
+
+
+@common.parametrize("test_data", ArangeAdd.test_data)
+def test_arange_start_step_tosa_MI(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineMI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op, ArangeAdd.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", ArangeAdd.test_data_dtypes)
+def test_arange_start_step_dtypes_tosa_MI(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineMI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op, ArangeAdd.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", ArangeAdd.test_data)
+def test_arange_start_step_tosa_BI(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineBI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op, ArangeAdd.exir_op
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
+
+
+@common.parametrize("test_data", ArangeAdd.test_data)
+def test_arange_start_step_tosa_u55(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = EthosU55PipelineBI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
+
+
+@common.parametrize("test_data", ArangeAdd.test_data)
+def test_arange_start_step_tosa_u85(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = EthosU85PipelineBI[input_t](
+        ArangeAdd(*init_data), input_data(), ArangeAdd.aten_op
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
+
+
+class LinspaceAdd(torch.nn.Module):
+    aten_op: str = "torch.ops.aten.linspace.default"
+    exir_op: str = "executorch_exir_dialects_edge__ops_aten_linspace_default"
+
+    def __init__(self, start: float, stop: float, step: int, dtype: torch.dtype):
+        super().__init__()
+        self.args = (start, stop, step)
+        self.dtype = dtype
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.linspace(*self.args, dtype=self.dtype) + x
+
+    test_data: dict[str, test_data_t] = {
+        "10": (lambda: (torch.randn(10, 1),), (0.0, 10.0, 100, torch.float32)),
+        "15": (lambda: (torch.randn(20),), (0.0, 15.0, 20, torch.float32)),
+    }
+
+
+@common.parametrize("test_data", LinspaceAdd.test_data)
+def test_linspace_tosa_MI(test_data):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineMI[input_t](
+        LinspaceAdd(*init_data), input_data(), LinspaceAdd.aten_op, LinspaceAdd.exir_op
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", LinspaceAdd.test_data)
+def test_linspace_tosa_BI(test_data: test_data_t):
+    input_data, init_data = test_data
+    pipeline = TosaPipelineBI[input_t](
+        LinspaceAdd(*init_data), input_data(), LinspaceAdd.aten_op, LinspaceAdd.exir_op
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
diff --git a/backends/arm/test/ops/test_full.py b/backends/arm/test/ops/test_full.py
index 8347d01be4c..193ed632ed0 100644
--- a/backends/arm/test/ops/test_full.py
+++ b/backends/arm/test/ops/test_full.py
@@ -141,10 +141,6 @@ def test_const_full_tosa_MI(self):
     def test_full_like_tosa_MI(self, test_tensor: Tuple):
         self._test_full_tosa_MI_pipeline(self.FullLike(), test_tensor)
 
-    def test_const_full_nhwc_tosa_BI(self):
-        _input = torch.rand((2, 2, 3, 3)) * 10
-        self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,))
-
     @parameterized.expand(AddVariableFull.test_parameters)
     def test_full_tosa_MI(self, test_tensor: Tuple):
         self._test_full_tosa_MI_pipeline(
@@ -175,8 +171,6 @@ def test_full_u85_BI(self, test_tensor: Tuple):
             test_tensor,
         )
 
-    # This fails since full outputs int64 by default if 'fill_value' is integer, which our backend doesn't support.
-    @unittest.expectedFailure
     def test_integer_value(self):
         _input = torch.ones((2, 2))
         integer_fill_value = 1
diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py
index c6ad4420327..12d85054f79 100644
--- a/backends/arm/test/passes/test_fuse_constant_ops_pass.py
+++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py
@@ -8,7 +8,10 @@
 from typing import Tuple
 
 import torch
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
+from executorch.backends.arm._passes.fuse_constant_ops_pass import (
+    ComputeConstantOpsAOT,
+    FuseConstantArgsPass,
+)
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     PassPipeline,
@@ -95,7 +98,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 @common.parametrize("module", modules)
-def test_fuse_batchnorm_tosa_MI(module):
+def test_fuse_const_ops_tosa_MI(module):
     pipeline = PassPipeline[input_t](
         module=module,
         test_data=(torch.rand(1),),
@@ -103,14 +106,14 @@ def test_fuse_batchnorm_tosa_MI(module):
         ops_before_pass=module.ops_before_pass,
         ops_after_pass=module.ops_after_pass,
         ops_not_after_pass=module.ops_not_after_pass,
-        passes_with_exported_program=[FuseConstantOpsPass],
+        passes_with_exported_program=[ComputeConstantOpsAOT, FuseConstantArgsPass],
     )
     pipeline.run()
 
 
 @unittest.skip("Test failing on internal CI")
 @common.parametrize("module", modules)
-def test_fuse_batchnorm_tosa_BI(module):
+def test_fuse_const_ops_tosa_BI(module):
     pipeline = TosaPipelineBI[input_t](
         module, (torch.rand(10, 10),), [], [], use_to_edge_transform_and_lower=True
     )
diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py
index a53bf6fc725..48254e034e8 100644
--- a/backends/arm/tosa_partitioner.py
+++ b/backends/arm/tosa_partitioner.py
@@ -177,6 +177,8 @@ def filter_fn(node: torch.fx.Node) -> bool:
         ops_to_not_decompose = [
             torch.ops.aten.linear.default,
             torch.ops.aten.upsample_nearest2d.vec,
+            torch.ops.aten.eye.default,
+            torch.ops.aten.linspace.default,
         ] + ops_to_not_decompose_if_quant_op
 
         return (ops_to_not_decompose, filter_fn)
diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py
index 3028ecce923..008b0448e73 100644
--- a/backends/arm/tosa_quant_utils.py
+++ b/backends/arm/tosa_quant_utils.py
@@ -130,7 +130,7 @@ def quantize_value(self, x: torch.Tensor | float) -> Tensor:
         ).to(self.dtype)
 
     def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
-        return (qx - self.zp) * self.scale
+        return (qx.to(torch.int64) - self.zp) * self.scale
 
     @classmethod
     def from_operator(cls, op, args):
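
A minimal sketch of the pattern ComputeConstantOpsAOT targets, using plain
torch.export (the module mirrors the ArangeAdd test added above; this is an
illustration of the intended behavior, not part of the patch):

    import torch

    class ArangeAdd(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # arange here takes only constant args, so its output is a constant
            # tensor (int64 for integer args). When ComputeConstantOpsAOT runs in
            # the ArmPassManager pipelines, the arange call_function is replaced
            # by a persistent-buffer placeholder named "<node_name>_pre_computed",
            # so the backend never sees an int64-producing arange at runtime.
            return torch.arange(0, 10) + x

    # The exported graph still contains the aten.arange node at this point; the
    # pre-computation happens later, inside the _tosa_080_MI/BI pipelines.
    ep = torch.export.export(ArangeAdd(), (torch.randn(10),))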