Skip to content

Commit c709e67

Browse files
committed
Arm backend: Support conditional operator
- Add partition check to make sure that the submodules with the if/else codepaths are fully delegated. - Fix some partitioning issues with submodule nodes: since they point to a submodule rather than a tensor, they don't have a fake tensor. - Add node visitor. - Add tests. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I00dbfdedb04c686ce04b4fb1d682816038b7e1bf
1 parent 57773ff commit c709e67

File tree

11 files changed

+446
-19
lines changed

11 files changed

+446
-19
lines changed

backends/arm/_passes/arm_pass_utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,25 @@
3131
from torch.export.graph_signature import InputKind
3232

3333

34+
def is_submodule_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node refers to a submodule of its owning module.

    Only "get_attr" and "placeholder" nodes are considered. The node's target
    is looked up with ``get_submodule`` on the module owning the node's graph;
    the node is a submodule node only if that lookup succeeds.
    """
    if node.op not in ("get_attr", "placeholder"):
        return False
    owning_module = node.graph.owning_module
    if owning_module is None:
        # A graph that is not attached to a module cannot reference a submodule.
        return False
    try:
        owning_module.get_submodule(node.target)
    except AttributeError:
        # get_submodule raises AttributeError when the target is missing or
        # does not resolve to an nn.Module (e.g. it is a parameter or buffer).
        return False
    return True
42+
43+
3444
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node is a get attr node for a tensor of the model.

    Get attr nodes that point to a submodule are excluded.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    if node.op != "get_attr":
        return False
    return not is_submodule_node(node)
3953

4054

4155
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:

backends/arm/_passes/cast_int64_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def _to_int32(self, graph_module: torch.fx.GraphModule):
4141
for node in graph_module.graph.nodes:
4242
if len(node.users) == 0:
4343
continue
44+
if "val" not in node.meta:
45+
continue
4446
fake_tensor = node.meta["val"]
4547
if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
4648
continue

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,8 @@ def remove_dim_order_kwargs(
299299

300300
def call(self, graph_module: torch.fx.GraphModule):
301301
for node in graph_module.graph.nodes:
302+
if "val" not in node.meta:
303+
continue
302304
node_data = get_first_fake_tensor(node).data
303305

304306
self.remove_dim_order_kwargs(graph_module, node)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 139 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
import itertools
88
import operator
99
import typing
10-
from typing import final, Optional, Sequence, Type
10+
from typing import cast, final, Optional, Sequence, Type
1111

1212
import torch
1313
import torch.fx as fx
1414

15-
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
15+
from executorch.backends.arm._passes.arm_pass_utils import (
16+
get_first_fake_tensor,
17+
is_submodule_node,
18+
)
1619
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
1720
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
1821
FuseQuantizedActivationPass,
@@ -31,6 +34,7 @@
3134
TOSA_PRO_INT_SupportList,
3235
)
3336
from executorch.backends.arm.tosa import TosaSpecification
37+
from executorch.backends.arm.tosa.specification import Tosa_1_00
3438
from executorch.exir import ExportedProgram
3539
from executorch.exir.backend.utils import WhyNoPartitionReporter
3640
from executorch.exir.dialects._ops import ops as exir_ops
@@ -110,7 +114,9 @@ def tosa_support_factory(
110114
Additional checks can be supplied to avoid partitioning additional nodes.
111115
"""
112116
# Positive checks: Add nodes to partitioning
113-
positive_checks: list[OperatorSupportBase] = []
117+
positive_checks: list[OperatorSupportBase] = [
118+
CondSupported(exported_program, tosa_spec, reporter)
119+
]
114120

115121
if tosa_spec.support_integer():
116122
positive_checks.append(TOSAProINTSupportList())
@@ -350,7 +356,8 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
350356
def is_node_supported(
351357
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
352358
) -> bool:
353-
359+
if is_submodule_node(node):
360+
return True
354361
vals = node.meta["val"]
355362
tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]
356363

@@ -390,7 +397,11 @@ def is_node_supported(
390397

391398
# Ops with int64 inputs are only partitioned if input nodes are constant and will be partitioned.
392399
# If it is not partitioned, the partition will get an int64 input and fail.
393-
for input_node in node.all_input_nodes:
400+
for input_node in (
401+
input_node
402+
for input_node in node.all_input_nodes
403+
if input_node.op != "get_attr"
404+
):
394405
tensor_in = get_first_fake_tensor(input_node)
395406
if tensor_in.dtype != torch.int64:
396407
continue
@@ -426,8 +437,13 @@ def __init__(
426437
def is_node_supported(
427438
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
428439
) -> bool:
429-
430-
for input_node in node.all_input_nodes:
440+
if is_submodule_node(node):
441+
return True
442+
for input_node in (
443+
input_node
444+
for input_node in node.all_input_nodes
445+
if input_node.op != "get_attr"
446+
):
431447
tensor = get_first_fake_tensor(input_node)
432448
if tensor.dtype == torch.float64:
433449
self.reporter.report_reject(
@@ -449,7 +465,13 @@ def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int):
449465
def is_node_supported(
450466
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
451467
) -> bool:
452-
input_nodes = node.all_input_nodes
468+
if is_submodule_node(node):
469+
return True
470+
input_nodes = (
471+
input_node
472+
for input_node in node.all_input_nodes
473+
if input_node.op != "get_attr"
474+
)
453475
# check if any input node has an unsupported rank
454476
for input_node in input_nodes:
455477
input_node_shape = get_first_fake_tensor(input_node).shape
@@ -484,3 +506,112 @@ def is_node_supported(
484506
)
485507
return False
486508
return True
509+
510+
511+
class CondSupported(OperatorSupportBase):
    """Checks whether the cond operator, and its submodule args, should be partitioned.

    A ``torch.ops.higher_order.cond`` node, and the submodule nodes passed to
    it, are accepted only when the TOSA specification is a ``Tosa_1_00``
    instance supporting the "cf" extension and the branch submodules of the
    cond node are fully partitioned.
    """

    def __init__(
        self,
        exported_program: ExportedProgram,
        tosa_spec: TosaSpecification,
        reporter: WhyNoPartitionReporter,
    ):
        self.exported_program = exported_program
        self.reporter = reporter
        self.tosa_spec = tosa_spec
        super().__init__()

    @staticmethod
    def _is_boundary_quant_node(node: fx.Node) -> bool:
        """Returns whether node is a Q op fed by a placeholder or a DQ op feeding the output."""
        if node.target in Q_OPS:
            # next(..., None) instead of indexing: a node without inputs must
            # not raise, it is simply not a boundary node.
            first_input = next(iter(node.all_input_nodes), None)
            if first_input is not None and first_input.op == "placeholder":
                return True
        if node.target in DQ_OPS:
            first_user = next(iter(node.users), None)
            if first_user is not None and first_user.op == "output":
                return True
        return False

    def _fully_partitioned(self, submodule: fx.GraphModule) -> bool:
        """Returns whether all call_function nodes of the submodule share one delegation tag."""
        partition_tag = None
        for submodule_node in submodule.graph.nodes:
            if submodule_node.op != "call_function":
                continue
            # Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported.
            if self._is_boundary_quant_node(submodule_node):
                continue
            if "delegation_tag" not in submodule_node.meta:
                return False
            node_tag = submodule_node.meta["delegation_tag"]
            if partition_tag is None:
                partition_tag = node_tag
            elif node_tag != partition_tag:
                return False
        return True

    def _cond_submodules_fully_partitioned(self, node: fx.Node) -> bool:
        """Returns whether the submodule arguments to a cond node were fully partitioned.
        Updates "val" meta of the submodules if they are.
        """
        root_module = self.exported_program.graph_module
        branches = []
        for arg in node.args[1:3]:
            branch_node = cast(torch.fx.Node, arg)
            branch = cast(
                torch.fx.GraphModule,
                root_module.get_submodule(str(branch_node.target)),
            )
            if not self._fully_partitioned(branch):
                return False
            branches.append((branch_node, branch))

        # Only touch the meta once every branch has passed the check, so that a
        # rejected cond node leaves its submodule nodes unmodified.
        for branch_node, branch in branches:
            branch_node.meta["val"] = branch.graph.output_node().meta["val"]
        return True

    def _control_flow_supported(self, node: fx.Node) -> bool:
        """Returns whether the TOSA spec supports control flow, reporting a reject for node if not."""
        if not isinstance(self.tosa_spec, Tosa_1_00):
            self.reporter.report_reject(
                node, "Control flow extension not supported for TOSA version <1.0"
            )
            return False
        if not self.tosa_spec.support_extension("cf"):
            self.reporter.report_reject(
                node,
                f"TOSA spec {self.tosa_spec} does not support control flow extension.",
            )
            return False
        return True

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """Returns whether node is a cond node, or a submodule node used by cond nodes, that can be partitioned.

        All other nodes are rejected by this check.
        """
        if is_submodule_node(node):
            if not self._control_flow_supported(node):
                return False
            for user in node.users:
                if user.target != torch.ops.higher_order.cond:
                    self.reporter.report_reject(
                        node, f"Submodule had unsupported user {user}"
                    )
                    return False
                if not self._cond_submodules_fully_partitioned(user):
                    self.reporter.report_reject(
                        node, "One submodule was not fully partitioned"
                    )
                    return False
            return True

        if node.target == torch.ops.higher_order.cond:
            if not self._control_flow_supported(node):
                return False
            if not self._cond_submodules_fully_partitioned(node):
                self.reporter.report_reject(
                    node, "Submodule was not fully partitioned."
                )
                return False
            return True

        return False

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
op_cat,
1717
op_ceil,
1818
op_clamp,
19+
op_cond_if,
1920
op_constant_pad_nd,
2021
op_cos,
2122
op_eq,
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import Any, cast, List
8+
9+
import tosa_serializer as ts
10+
11+
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.operators.operator_validation_utils import (
16+
validate_num_inputs,
17+
validate_valid_dtype,
18+
)
19+
from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore
20+
from executorch.backends.arm.tosa.specification import Tosa_1_00
21+
from torch.fx import Node
22+
23+
24+
@register_node_visitor
class CondVisitor(NodeVisitor):
    """Node visitor that serializes a "cond" node as a TOSA COND_IF operator."""

    target = "cond"

    tosa_specs = NodeVisitor.tosa_specs

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Validates the cond node and adds a COND_IF operator for it to tosa_graph.

        Raises:
            ValueError: If the TOSA specification is not a Tosa_1_00 instance
                or does not support the "cf" extension.
        """
        validate_num_inputs(self.target, inputs, 4)
        validate_valid_dtype(self.target, [inputs[0]], ts.DType.BOOL, self.tosa_spec)

        if not isinstance(self.tosa_spec, Tosa_1_00):
            raise ValueError("Trying to lower cond, but TOSA version is <1.0.")
        if not self.tosa_spec.support_extension("cf"):
            raise ValueError(
                f"Trying to lower cond, but TOSA specification {self.tosa_spec} does not support the cf extension."
            )

        # args[1] and args[2] are the nodes of the two branch submodules; their
        # targets are passed to the COND_IF attribute.
        then_node, else_node = node.args[1:3]
        cond_attr = ts.TosaSerializerAttribute()
        cond_attr.CondIfAttribute(
            cast(Node, then_node).target, cast(Node, else_node).target
        )

        # Operator inputs: the condition first, then the entries of the last
        # argument's `special` list.
        operand_names = [inputs[0].name]
        operand_names.extend(operand.name for operand in inputs[-1].special)

        self._serialize_operator(
            node,
            tosa_graph,
            ts.Op.COND_IF,
            operand_names,
            [output.name],
            cond_attr,
        )

backends/arm/operators/ops_identity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def define_node(
4040
inputs: List[TosaArg],
4141
output: TosaArg,
4242
) -> None:
43-
validate_num_inputs(self.target, inputs, 1)
44-
validate_same_dtype(self.target, [*inputs, output], ts)
43+
validate_num_inputs(self.target, inputs, [1, 2])
44+
validate_same_dtype(self.target, [inputs[0], output], ts)
4545

4646
# Simply add an identityOp
4747
attr = ts.TosaSerializerAttribute()

backends/arm/scripts/parse_test_names.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here.
88
CUSTOM_EDGE_OPS = [
99
"linspace.default",
10+
"cond.default",
1011
"eye.default",
1112
"expm1.default",
1213
"vector_norm.default",

0 commit comments

Comments
 (0)