From a01264a2db117dc78301b0a499114d13f333c71d Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Wed, 20 Nov 2024 14:29:32 +0100 Subject: [PATCH] Fix bug in ScalarsToAttributePass The pass should not modify the scalar argument if output is non-float. Signed-off-by: Oscar Andersson Change-Id: I36f6975e8d6f33e5834e44959f6e426808452de1 --- backends/arm/_passes/cast_int64_pass.py | 43 ++++++++++++++----- .../arm/_passes/scalars_to_attribute_pass.py | 5 +++ backends/arm/test/ops/test_scalars.py | 9 ++++ 3 files changed, 46 insertions(+), 11 deletions(-) diff --git a/backends/arm/_passes/cast_int64_pass.py b/backends/arm/_passes/cast_int64_pass.py index a9952edec3c..aab6ed8eb42 100644 --- a/backends/arm/_passes/cast_int64_pass.py +++ b/backends/arm/_passes/cast_int64_pass.py @@ -5,8 +5,15 @@ # pyre-unsafe +import logging + import torch +from executorch.backends.arm._passes.arm_pass_utils import is_param_node from executorch.exir.pass_base import ExportPass, PassResult +from torch._export.utils import is_buffer + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) class CastInt64ToInt32Pass(ExportPass): @@ -18,17 +25,31 @@ def _to_int32(self, graph_module: torch.fx.GraphModule): for node in graph_module.graph.nodes: fake_tensor = node.meta["val"] if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): - if node.meta["val"].dtype == torch.int64: - node.meta["val"] = node.meta["val"].to(torch.int32) - buffer_name = ( - self.exported_program.graph_signature.inputs_to_buffers[ - node.name - ] - ) - new_tensor = self.exported_program.state_dict[buffer_name].to( - torch.int32 - ) - self.exported_program.state_dict[buffer_name] = new_tensor + if node.meta["val"].dtype == torch.int64 and is_param_node( + self.exported_program, node + ): + if is_buffer(self.exported_program, node): + node.meta["val"] = node.meta["val"].to(torch.int32) + buffer_name = ( + self.exported_program.graph_signature.inputs_to_buffers[ + node.name + ] + ) + buffer = 
self.exported_program.state_dict[buffer_name] logger.warning( f"Casting buffer {node.name} from torch.int64 to torch.int32" f" defined in {node.meta['stack_trace']}" ) if torch.min(buffer) < torch.iinfo(torch.int32).min: raise RuntimeError( f"Buffer {node.name} has value < {torch.iinfo(torch.int32).min}" ) if torch.max(buffer) > torch.iinfo(torch.int32).max: raise RuntimeError( f"Buffer {node.name} has value > {torch.iinfo(torch.int32).max}" ) buffer_int32 = buffer.to(torch.int32) self.exported_program.state_dict[buffer_name] = buffer_int32 def call(self, graph_module: torch.fx.GraphModule): self._to_int32(graph_module) diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index a689799ed6e..f6fe02b6ebc 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -51,6 +51,11 @@ def call(self, graph_module: GraphModule) -> PassResult: if isinstance(arg, Node): new_args.append(arg) continue + if isinstance(arg, int) and not torch.is_floating_point( + get_first_fake_tensor(n) + ): + new_args.append(arg) + continue prefix = "_tensor_constant_" get_new_attr_name = get_new_attr_name_with_prefix(prefix) diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py index 86433745a63..cd3dd72f608 100644 --- a/backends/arm/test/ops/test_scalars.py +++ b/backends/arm/test/ops/test_scalars.py @@ -75,6 +75,12 @@ def forward(self, x): x = 1.0 + x return x + class ShiftInplaceSub(torch.nn.Module): + def forward(self, x): + x = x >> 4 + x -= 10 + return x + # Inplace ops end with '_' (from aten naming) ops = [ ("Add", Add()), @@ -160,3 +166,6 @@ def test_MI_const(self, test_name: str, op: torch.nn.Module, x): @parameterized.expand(tensor_scalar_tests) def test_BI(self, test_name: str, op: torch.nn.Module, x, y): self._test_add_tosa_BI_pipeline(op, (x, y)) + + def test_shift_sub_inplace_tosa_MI(self): 
self._test_add_tosa_MI_pipeline(self.ShiftInplaceSub(), (torch.IntTensor(5),))