20 changes: 18 additions & 2 deletions backends/cortex_m/passes/convert_to_cortex_m_pass.py
@@ -171,6 +171,22 @@ def _get_convolution_replacement(self, node) -> int:
weight_permuted,
)

quantized_multiplier_tensor = create_constant_placeholder(
self.exported_program,
node.graph,
node.name + "_quantized_multiplier",
InputKind.PARAMETER,
torch.tensor(quantized_multipliers, dtype=torch.int32),
)

quantized_shift_tensor = create_constant_placeholder(
self.exported_program,
node.graph,
node.name + "_quantized_shift",
InputKind.PARAMETER,
torch.tensor(quantized_shifts, dtype=torch.int32),
)

new_args = (
x,
weight_nhwc,
@@ -180,8 +196,8 @@ def _get_convolution_replacement(self, node) -> int:
dilation,
-input_zero_point,
output_zero_point,
torch.tensor(quantized_multipliers, dtype=torch.int32),
torch.tensor(quantized_shifts, dtype=torch.int32),
quantized_multiplier_tensor,
quantized_shift_tensor,
output_qmin,
output_qmax,
)
136 changes: 46 additions & 90 deletions exir/passes/spec_prop_pass.py
@@ -6,14 +6,16 @@

# pyre-strict

from typing import List, Optional
import operator
from typing import Optional

import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.tensor import TensorSpec
from torch.export.exported_program import ExportGraphSignature
from torch.fx.node import Node
from torch.fx.passes.infra.pass_base import PassResult
from torch.utils import _pytree as pytree


@@ -52,12 +54,48 @@ class SpecPropPass(ExportPass):
def __init__(self) -> None:
super().__init__()

def on_attr(self, attr: ProxyValue) -> None:
attr.node.meta["spec"] = pytree.tree_map_only(
torch.Tensor,
make_spec,
attr.data,
)
def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
# Re-trace metadata to ensure it's up to date.
res = ExportPass()(graph_module)
assert res is not None
gm = res.graph_module

def get_spec(x):
if hasattr(x, "meta"):
return x.meta.get("spec", None)
else:
return None

for module in gm.modules():
if isinstance(module, torch.fx.GraphModule):
for node in module.graph.nodes:
meta_val = node.meta.get("val", None)

if node.op == "output":
node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
elif node.op == "call_function" and node.target == operator.getitem:
value_spec = pytree.tree_map(get_spec, node.args[0])
node.meta["spec"] = value_spec[node.args[1]]
elif (
node.op == "call_function"
and node.target == executorch_call_delegate
):
# Note: We currently rely on delegate node specs not being regenerated,
# as the spec is set somewhat manually when adding the call delegate node.
# If we regenerate, it can change and break lowering (it becomes a tuple?).
# Ideally, we should figure out how to make the spec regeneration not break
# things.
#
# We do need to regenerate non-call-delegate node specs, as this pass is called
# multiple times in some lowering paths (backends can and do call it).
if "spec" not in node.meta:
node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
else:
node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
Comment on lines +91 to +94

Contributor:
Consider merging these 2 conditions?

GregoryComer (Member, Author), Dec 4, 2025:
There is some weird existing behavior here that seems to need to be preserved (barring a larger update). Basically, we don't want to regenerate call_delegate node specs but do want to regenerate everything else. I'll add a comment detailing why.

Comment on lines +79 to +94

Contributor:
I meant something like:

Suggested change:
elif (
    node.op == "call_function"
    and node.target == executorch_call_delegate
):
    # Note: We currently rely on delegate node specs not being regenerated,
    # as the spec is set somewhat manually when adding the call delegate node.
    # If we regenerate, it can change and break lowering (it becomes a tuple?).
    # Ideally, we should figure out how to make the spec regeneration not break
    # things.
    #
    # We do need to regenerate non-call-delegate node specs, as this pass is called
    # multiple times in some lowering paths (backends can and do call it).
    if "spec" not in node.meta:
        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
else:
    node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
else:
    if "spec" not in node.meta:
        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)

GregoryComer (Member, Author):
Sure - the issue that I've seen is that sometimes this pass gets called multiple times (the Cadence backend does this, for example), and thus we need to regenerate the spec for most nodes to make sure they pick up on any shape changes between calls.

But if we regenerate the spec for call_delegate nodes, it breaks things. So the if "spec" not in node.meta: condition should only apply to call_delegate but not anything else. Otherwise it breaks existing backend assumptions.

Ideally, we'll do a deeper change to fix this, but this preserves the existing behavior. I could change the line to if "spec" not in node.meta or node.target != executorch_call_delegate if you'd prefer.

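For reference, a minimal sketch of the merged condition the author mentions above (hypothetical, not code from this PR; it reuses get_spec, make_spec, pytree, operator, and executorch_call_delegate as defined in spec_prop_pass.py):

# Hypothetical merged form of the branches discussed above; not part of this PR.
# Specs are regenerated for every node except call_delegate nodes that already
# carry one, preserving the existing backend behavior.
for node in module.graph.nodes:
    meta_val = node.meta.get("val", None)
    if node.op == "output":
        node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
    elif node.op == "call_function" and node.target == operator.getitem:
        node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])[node.args[1]]
    elif "spec" not in node.meta or node.target != executorch_call_delegate:
        # call_delegate nodes keep their manually assigned spec; all others are refreshed.
        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
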
return res

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
return self(graph_module)

def update_placeholder_tensor_specs(
self,
@@ -84,85 +122,3 @@ def update_placeholder_tensor_specs(
in exported_program.graph_signature.inputs_to_lifted_tensor_constants
):
spec.const = True

# pyre-ignore
def placeholder(self, name: str, arg, meta):
meta["spec"] = make_spec(arg)
return super().placeholder(name, arg, meta)

# pyre-ignore
def call_operator(self, op, args, kwargs, meta):
args_data, kwargs_data = pytree.tree_map_only(
ProxyValue, lambda x: x.data, (args, kwargs)
)
meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
return super().call_operator(op, args, kwargs, meta)

# pyre-ignore
def call_getitem(self, value, key: int, meta):
meta["spec"] = value.node.meta["spec"][key]
return super().call_getitem(value, key, meta)

# pyre-ignore
def call_cond(self, pred, true_fn, false_fn, inputs, meta):
# true_fn/false_fn return tensors of the same shape, so we can pick
# either one here.
*_, true_out_node = true_fn.graph.nodes
meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
return super().call_cond(pred, true_fn, false_fn, inputs, meta)

def call_while(
self,
cond_fn: torch.fx.GraphModule,
body_fn: torch.fx.GraphModule,
carried_inputs: List[ProxyValue],
additional_inputs: List[ProxyValue],
meta: NodeMetadata,
):
meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
return super().call_while(
cond_fn, body_fn, carried_inputs, additional_inputs, meta
)
Comment on lines -107 to -125

Contributor:
Why don't we have to handle cond and while anymore?

GregoryComer (Member, Author):
They should be handled by having the tracing logic use ExportPass to regenerate the meta values and then assigning spec values for each node correspondingly. I went ahead and added specific tests for cond and while to verify that the specs are generated correctly. As long as the cond + while outputs don't alias anything else (my understanding is that this should be the case), it should be good.

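For context on the reply above, a minimal sketch of the assumed behavior (reusing make_spec and pytree from spec_prop_pass.py; the name while_node is hypothetical): the ExportPass re-trace populates node.meta["val"] with the op's fake-tensor outputs, and the generic branch in __call__ maps make_spec over that structure.

# Assumed illustration only; not code from this PR.
# After the re-trace, a while_loop node's meta["val"] holds the carried
# outputs as fake tensors, e.g. a 2-tuple for (i, acc) in the test below.
meta_val = while_node.meta.get("val", None)
while_node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
# Result: a tuple of TensorSpecs whose shapes mirror the fake tensors, which is
# what test_spec_prop_pass_while checks (len == 2, rank-1 second element).
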

def call_map(
self,
f: torch.fx.GraphModule,
mapped_args: List[ProxyValue],
operands: List[ProxyValue],
meta: NodeMetadata,
) -> ProxyValue:
mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
*_, body_out_node = f.graph.nodes
body_out_node_fake_tensor = body_out_node.meta["val"]
map_fake_tensor = pytree.tree_map_only(
torch.Tensor,
lambda x: x.new_empty(mapped_dim_size, *x.shape),
body_out_node_fake_tensor,
)
meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
return super().call_map(f, mapped_args, operands, meta)

# pyre-ignore
def call_delegate(self, lowered_module, args, kwargs, meta):
args_data, kwargs_data = pytree.tree_map_only(
ProxyValue, lambda x: x.data, (args, kwargs)
)
# If spec is missing, re-genenrate it with args data
if "spec" not in meta:
meta["spec"] = pytree.tree_map(
make_spec,
executorch_call_delegate(lowered_module, *args_data),
)
return super().call_delegate(lowered_module, args, kwargs, meta)

# pyre-ignore
def output(self, results, meta):
# pyre-ignore
def get_spec(x):
if isinstance(x, ProxyValue):
return x.node.meta["spec"]
else:
return make_spec(x)

meta["spec"] = pytree.tree_map(get_spec, results)
return super().output(results, meta)
134 changes: 134 additions & 0 deletions exir/tests/test_passes.py
@@ -74,6 +74,7 @@
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
from executorch.exir.program._program import lift_constant_tensor_pass
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.sym_util import eval_upper_bound
from executorch.exir.tensor import TensorSpec
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
@@ -113,6 +114,7 @@ def collect_ops(gm: torch.fx.GraphModule):

lib.define("foo(Tensor self) -> (Tensor, Tensor)")
lib.define("add_relu(Tensor self, Tensor other) -> Tensor")
lib.define("unbacked(Tensor self) -> Tensor")


@impl(lib, "foo", "CompositeExplicitAutograd")
@@ -132,6 +134,29 @@ def foo_out(
return a + 1, None


@impl(lib, "unbacked", "CPU")
def unbacked(a: torch.Tensor) -> torch.Tensor:
return a[: a[0]]


@torch.library.register_fake(f"{lib.ns}::unbacked")
def meta_unbacked(x):
ctx = torch._custom_ops.get_ctx()
out_size = ctx.create_unbacked_symint()
return x.new_empty(out_size)


lib.define("unbacked.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")


@impl(lib, "unbacked.out", "CPU")
def unbacked_out(
x: torch.Tensor,
out: torch.Tensor,
) -> torch.Tensor:
return out.copy_(x[: x[0]])


def simple_promote_dtype(
dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
) -> torch.dtype:
@@ -611,6 +636,115 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:

self.assertEqual(counter, 1)

def test_spec_prop_pass_unbacked_symint(self) -> None:
# Verify that the spec prop pass picks up on guards for
# unbacked symints.
class Unbacked(torch.nn.Module):
def forward(self, x):
output = torch.ops.DO_NOT_USE_TEST_ONLY.unbacked(x)
torch._constrain_as_size(output.shape[0], max=10)
return output

model = Unbacked()
gm = (
to_edge(export(model, (torch.LongTensor([5, 4, 3, 2, 1, 0, 1, 2]),)))
.exported_program()
.graph_module
)
new_gm = SpecPropPass()(gm)
self.assertIsNotNone(new_gm)

# Check the spec for the custom op node. It should have a max size of 10.
op_node = next(
n
for n in new_gm.graph_module.graph.nodes
if n.target == exir_ops.edge.DO_NOT_USE_TEST_ONLY.unbacked.default
)
self.assertIsNotNone(op_node)

spec: TensorSpec = op_node.meta["spec"]
self.assertEqual(len(spec.shape), 1) # Should be rank 1
upper_bound = eval_upper_bound(spec.shape[0])
self.assertEqual(upper_bound, 10) # Should be a concrete value

def test_spec_prop_pass_cond(self) -> None:
class ModelWithCond(torch.nn.Module):
def true_fn(self, val):
return val * 2

def false_fn(self, val):
return val + 1

def forward(self, x):
return torch.cond(x[0] > 0, self.true_fn, self.false_fn, [x])

model = ModelWithCond()
inputs = (torch.ones(10),)
dynamic_shapes = {"x": {0: torch.export.Dim("x", min=1, max=20)}}

# Run the spec prop pass and sanity check the spec on the cond.
edge_program = to_edge(export(model, inputs, dynamic_shapes=dynamic_shapes))
gm = edge_program.exported_program().graph_module
new_gm = SpecPropPass()(gm)
self.assertIsNotNone(new_gm)

# Check the spec for the cond node. It should have a max size of 20 (matching the dynamic shape upper bound).
cond_node = next(
n
for n in new_gm.graph_module.graph.nodes
if hasattr(n.target, "name") and n.target.name() == "cond"
)
self.assertIsNotNone(cond_node)

# Spec for the cond node should be a single-element tuple
spec: tuple[TensorSpec] = cond_node.meta["spec"]
self.assertTrue(isinstance(spec, tuple))
self.assertEqual(len(spec), 1)

self.assertEqual(len(spec[0].shape), 1) # Should be rank 1
upper_bound = eval_upper_bound(spec[0].shape[0])
self.assertEqual(upper_bound, 20) # Should match dynamic shape bound

def test_spec_prop_pass_while(self) -> None:
class ModelWithWhile(torch.nn.Module):
def forward(self, i):
def loop_cond(i, acc):
return i[0] > 0

def loop_body(i, acc):
return i - 1, acc + i

_, acc = torch._higher_order_ops.while_loop(
loop_cond, loop_body, (i, torch.zeros(10))
)
return acc

model = ModelWithWhile()
inputs = (torch.Tensor([5]),)

# Run the spec prop pass and sanity check the spec on the while.
edge_program = to_edge(export(model, inputs))
gm = edge_program.exported_program().graph_module
new_gm = SpecPropPass()(gm)
self.assertIsNotNone(new_gm)

# Check the spec for the while node. It should have a max size of 10 (matching the torch.zeros(10) in the model).
while_node = next(
n
for n in new_gm.graph_module.graph.nodes
if hasattr(n.target, "name") and n.target.name() == "while_loop"
)
self.assertIsNotNone(while_node)

# Spec for the while node should be a two-element tuple
spec: tuple[TensorSpec] = while_node.meta["spec"]
self.assertTrue(isinstance(spec, tuple))
self.assertEqual(len(spec), 2)

self.assertEqual(len(spec[1].shape), 1) # Should be rank 1
upper_bound = eval_upper_bound(spec[1].shape[0])
self.assertEqual(upper_bound, 10)

def test_compile_fix_broken_ops(self) -> None:
class ExportableLoop(nn.Module):
def __init__(self, hidden_size, out_channels):