Skip to content

Commit 0fc70cb

Browse files
abhinaykukkadapuhinriksnaer
authored andcommitted
Fix bug in sub op to ignore alpha != 1 (fixes: pytorch#11684) (pytorch#11796)
### Summary Sub node of XNNPack backend doesn't consider alpha value, this fixes the bug by falling back to portable ops and avoid partitioning the node. (fixes: pytorch#11684) ### Test plan ``` $ python -m unittest backends/xnnpack/test/ops/test_sub.py Ran 3 tests in 7.262s OK ``` ### Model run ``` import torch from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import to_edge_transform_and_lower from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer class Model(torch.nn.Module): def forward(self, x, y): return torch.sub(x, y, alpha=10) inputs = ( torch.randn(10), torch.randn(10), ) model = Model() ep = torch.export.export(model, inputs) lowered = to_edge_transform_and_lower( ep, partitioner=[XnnpackPartitioner()], ).to_executorch() et_model = _load_for_executorch_from_buffer(lowered.buffer) eager_output = model(*inputs) et_output = et_model([*inputs])[0] tolerance=1e-5 if torch.allclose(eager_output, et_output, atol=tolerance): print("Outputs are within the tolerance level.") else: print("Outputs differ beyond the tolerance level.") ``` ### output ``` Outputs are within the tolerance level. ```
1 parent fb05eeb commit 0fc70cb

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

backends/xnnpack/partition/config/generic_node_configs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
from typing import cast, List, Optional
1111

12+
import numpy as np
1213
import torch
1314
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
1415
ConfigPrecisionType,
@@ -523,6 +524,17 @@ class SubConfig(GenericNodePartitionerConfig):
523524
def supported_precision_types(self) -> List[ConfigPrecisionType]:
524525
return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT]
525526

527+
def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
528+
if not self.check_common_constraints(node, ep):
529+
return False
530+
# No support for sub nodes with alpha != 1
531+
if "alpha" in node.kwargs and not np.isclose(
532+
node.kwargs["alpha"], 1.0, atol=1e-9, rtol=1e-9
533+
):
534+
why(node, reason="Sub node doesn't support alpha != 1")
535+
return False
536+
return True
537+
526538

527539
class BMMConfig(GenericNodePartitionerConfig):
528540
"""

backends/xnnpack/test/ops/test_sub.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,27 @@ def forward(self, x, y):
152152
.serialize()
153153
.run_method_and_compare_outputs()
154154
)
155+
156+
class SubWithAlpha(torch.nn.Module):
157+
def forward(self, x, y):
158+
# node with alpha = 1.0 will be partitioned
159+
out1 = torch.sub(x, y, alpha=1)
160+
# node with alpha != 1.0 will not be partitioned
161+
out2 = torch.sub(x, y, alpha=2)
162+
return out1, out2
163+
164+
def test_add_with_alpha(self):
165+
inputs = (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 4))
166+
(
167+
Tester(self.SubWithAlpha(), inputs)
168+
.export()
169+
.check_count({"torch.ops.aten.sub.Tensor": 2})
170+
.to_edge_transform_and_lower()
171+
# unpartitioned node
172+
.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1})
173+
# partitioned node
174+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
175+
.to_executorch()
176+
.serialize()
177+
.run_method_and_compare_outputs()
178+
)

0 commit comments

Comments
 (0)