From 8f281a5c8395be1a2d88f63e8c9c6bb53ba25cdb Mon Sep 17 00:00:00 2001
From: Yufeng Shi
Date: Tue, 17 Jun 2025 13:12:30 +0100
Subject: [PATCH] Arm backend: Prevent illegal fusion in
 FuseEqualPlaceholdersPass

- Constant placeholders with the same values but different data types,
  such as int32 and fp32, must not be fused into a single placeholder.
  Otherwise, some operators end up with operands of mismatched dtypes.
- Fix the bug by adding dtype and shape checks, so that only constants
  with matching types, shapes, and values are fused.

Change-Id: Ia4668964f09010ac9416fc8c109549b7e989f724
Signed-off-by: Yufeng Shi
---
 .../_passes/fuse_equal_placeholders_pass.py   |  6 ++-
 .../test_fuse_equal_placeholders_ops_pass.py  | 45 ++++++++++++++++++-
 2 files changed, 49 insertions(+), 2 deletions(-)

diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py
index cd8cce1b3ea..664a0f8ea6c 100644
--- a/backends/arm/_passes/fuse_equal_placeholders_pass.py
+++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py
@@ -49,7 +49,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 if tensor2 is None:
                     continue
 
-                if torch.equal(tensor1, tensor2):
+                if (
+                    tensor1.dtype == tensor2.dtype
+                    and tensor1.shape == tensor2.shape
+                    and torch.allclose(tensor1, tensor2, atol=1e-08)
+                ):
                     eq_nodes.append(node2)
 
         if len(eq_nodes) > 1:
diff --git a/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py
index 49626eefb71..9a26157ed7e 100644
--- a/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py
+++ b/backends/arm/test/passes/test_fuse_equal_placeholders_ops_pass.py
@@ -10,7 +10,10 @@
 from executorch.backends.arm._passes.fuse_equal_placeholders_pass import (
     FuseEqualPlaceholdersPass,
 )
-from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+from executorch.backends.arm.test.tester.test_pipeline import (
+    PassPipeline,
+    TosaPipelineMI,
+)
 
 
 input_t = Tuple[torch.Tensor]  # Input x
@@ -54,6 +57,25 @@ def forward(self, x):
         return self.fc1(x) + self.fc2(x)
 
 
+class NotFuseTensorWithDifferentType(torch.nn.Module):
+
+    ops_before_pass = {}
+    ops_after_pass = {}
+    ops_not_after_pass = []
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        """
+        Args:
+            x: A float tensor (dtype=torch.float32)
+            y: An int tensor (dtype=torch.int32)
+        """
+        a = torch.tensor(1.0, dtype=torch.float32)
+        b = torch.tensor(1, dtype=torch.int32)
+        m = x < a
+        n = y > b
+        return m, n
+
+
 def test_fuse_equal_placeholders_constants_tosa_MI():
     module = FuseWeightsConstants()
     data = (torch.rand(1, 2, 8),)
@@ -94,3 +116,24 @@ def test_fuse_equal_placeholders_state_dict_tosa_MI():
     assert len(state_dict_keys) == 2, "FuseEqualPlaceholders state_dict failed"
     assert "_common" in state_dict_keys[0], "FuseEqualPlaceholders state_dict failed"
     assert "_common" in state_dict_keys[1], "FuseEqualPlaceholders state_dict failed"
+
+
+def test_not_fuse_tensor_with_different_type_MI():
+    module = NotFuseTensorWithDifferentType()
+    data = (
+        torch.rand(
+            1,
+        ),
+        torch.randint(
+            0,
+            10,
+            (1,),
+            dtype=torch.int,
+        ),
+    )
+    pipeline = TosaPipelineMI[input_t](
+        module,
+        data,
+        aten_op=[],
+    )
+    pipeline.run()
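
For reviewers, a minimal standalone sketch of the predicate this patch adds,
runnable outside the pass. The can_fuse helper name is illustrative only and
is not part of the ExecuTorch API; because `and` short-circuits, the
torch.allclose value comparison never runs on tensors of mismatched dtype.

    import torch

    def can_fuse(tensor1: torch.Tensor, tensor2: torch.Tensor) -> bool:
        # Mirror the patched check: dtype and shape must match before
        # the element-wise value comparison is attempted.
        return (
            tensor1.dtype == tensor2.dtype
            and tensor1.shape == tensor2.shape
            and torch.allclose(tensor1, tensor2, atol=1e-08)
        )

    a = torch.tensor(1.0, dtype=torch.float32)
    b = torch.tensor(1, dtype=torch.int32)
    c = torch.tensor(1.0, dtype=torch.float32)

    print(can_fuse(a, b))  # False: equal values, different dtypes -> kept separate
    print(can_fuse(a, c))  # True: dtype, shape, and values all match -> fusable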