Commit 7da6e25

Fix double-tracing in SpecPropPass (#15485)
Summary: Our current SpecPropPass doesn't properly capture the effect of guards in the shape environment because it double-traces certain ops. The problem looks like this:

* Every time we trace through the graph, we generate new symints.
* That's fine on its own, since shape_env picks up guards during the retrace.
* The problem is that SpecPropPass traces twice: once to generate the spec, and once by calling super().call_operator(...) (https://github.com/pytorch/executorch/blob/11f752cf84b296a39c0b74b889d618f279bc8186/exir/passes/spec_prop_pass.py#L98).
* The tensor spec gets the symint from the first trace, but the graph and guards use the second.
* Hence the tensor spec doesn't pick up on guards.

To resolve this, I've updated SpecPropPass to re-trace the graph and then generate specs based on the meta values, not the traced ProxyValues (thanks to angelayi for the suggestion). This resolves the issue.

I originally saw this issue with the NMS torchvision op, but to avoid adding a new dep to the core EXIR tests, I've written a test with a custom op that uses an unbacked symint in the meta kernel output shape to replicate the bug in the same way.

Differential Revision: D85913581

Pulled By: GregoryComer
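For readers skimming the diffs below, here is a minimal sketch of the core idea, not the exact ExecuTorch code (make_spec is taken as a parameter so the snippet stands alone): after a single re-trace, every node's spec is rebuilt from its meta["val"] FakeTensor, so the symints in the specs are the same ones the graph and its guards use.

import operator

import torch
from torch.utils import _pytree as pytree


def regenerate_specs(gm: torch.fx.GraphModule, make_spec) -> None:
    # Rebuild node.meta["spec"] from the meta["val"] produced by one trace,
    # so spec symints and guard symints come from the same shape environment.
    def get_spec(x):
        return x.meta.get("spec", None) if hasattr(x, "meta") else None

    for node in gm.graph.nodes:
        meta_val = node.meta.get("val", None)
        if node.op == "output":
            # Outputs reuse the specs of the nodes that produce them.
            node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
        elif node.op == "call_function" and node.target == operator.getitem:
            # getitem projects one element out of the producer's spec.
            node.meta["spec"] = get_spec(node.args[0])[node.args[1]]
        else:
            # Everything else gets one spec per tensor in meta["val"].
            node.meta["spec"] = pytree.tree_map(make_spec, meta_val)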
1 parent 18c1c5b commit 7da6e25

File tree

3 files changed: +198, -92 lines


backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 18 additions & 2 deletions
@@ -171,6 +171,22 @@ def _get_convolution_replacement(self, node) -> int:
             weight_permuted,
         )
 
+        quantized_multiplier_tensor = create_constant_placeholder(
+            self.exported_program,
+            node.graph,
+            node.name + "_quantized_multiplier",
+            InputKind.PARAMETER,
+            torch.tensor(quantized_multipliers, dtype=torch.int32),
+        )
+
+        quantized_shift_tensor = create_constant_placeholder(
+            self.exported_program,
+            node.graph,
+            node.name + "_quantized_shift",
+            InputKind.PARAMETER,
+            torch.tensor(quantized_shifts, dtype=torch.int32),
+        )
+
         new_args = (
             x,
             weight_nhwc,
@@ -180,8 +196,8 @@ def _get_convolution_replacement(self, node) -> int:
             dilation,
             -input_zero_point,
             output_zero_point,
-            torch.tensor(quantized_multipliers, dtype=torch.int32),
-            torch.tensor(quantized_shifts, dtype=torch.int32),
+            quantized_multiplier_tensor,
+            quantized_shift_tensor,
             output_qmin,
             output_qmax,
         )

exir/passes/spec_prop_pass.py

Lines changed: 46 additions & 90 deletions
@@ -6,14 +6,16 @@
 
 # pyre-strict
 
-from typing import List, Optional
+import operator
+from typing import Optional
 
 import torch
 from executorch.exir.delegate import executorch_call_delegate
-from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+from executorch.exir.pass_base import ExportPass, ProxyValue
 from executorch.exir.tensor import TensorSpec
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx.node import Node
+from torch.fx.passes.infra.pass_base import PassResult
 from torch.utils import _pytree as pytree
 
 
@@ -52,12 +54,48 @@ class SpecPropPass(ExportPass):
     def __init__(self) -> None:
         super().__init__()
 
-    def on_attr(self, attr: ProxyValue) -> None:
-        attr.node.meta["spec"] = pytree.tree_map_only(
-            torch.Tensor,
-            make_spec,
-            attr.data,
-        )
+    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        # Re-trace metadata to ensure it's up to date.
+        res = ExportPass()(graph_module)
+        assert res is not None
+        gm = res.graph_module
+
+        def get_spec(x):
+            if hasattr(x, "meta"):
+                return x.meta.get("spec", None)
+            else:
+                return None
+
+        for module in gm.modules():
+            if isinstance(module, torch.fx.GraphModule):
+                for node in module.graph.nodes:
+                    meta_val = node.meta.get("val", None)
+
+                    if node.op == "output":
+                        node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
+                    elif node.op == "call_function" and node.target == operator.getitem:
+                        value_spec = pytree.tree_map(get_spec, node.args[0])
+                        node.meta["spec"] = value_spec[node.args[1]]
+                    elif (
+                        node.op == "call_function"
+                        and node.target == executorch_call_delegate
+                    ):
+                        # Note: We currently rely on delegate node specs not being regenerated,
+                        # as the spec is set somewhat manually when adding the call delegate node.
+                        # If we regenerate, it can change and break lowering (it becomes a tuple?).
+                        # Ideally, we should figure out how to make the spec regeneration not break
+                        # things.
+                        #
+                        # We do need to regenerate non-call-delegate node specs, as this pass is called
+                        # multiple times in some lowering paths (backends can and do call it).
+                        if "spec" not in node.meta:
+                            node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
+                    else:
+                        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
+        return res
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        return self(graph_module)
 
     def update_placeholder_tensor_specs(
         self,
@@ -84,85 +122,3 @@ def update_placeholder_tensor_specs(
                 in exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 spec.const = True
-
-    # pyre-ignore
-    def placeholder(self, name: str, arg, meta):
-        meta["spec"] = make_spec(arg)
-        return super().placeholder(name, arg, meta)
-
-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
-        return super().call_operator(op, args, kwargs, meta)
-
-    # pyre-ignore
-    def call_getitem(self, value, key: int, meta):
-        meta["spec"] = value.node.meta["spec"][key]
-        return super().call_getitem(value, key, meta)
-
-    # pyre-ignore
-    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
-        # true_fn/false_fn return tensors of the same shape, so we can pick
-        # either one here.
-        *_, true_out_node = true_fn.graph.nodes
-        meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
-        return super().call_cond(pred, true_fn, false_fn, inputs, meta)
-
-    def call_while(
-        self,
-        cond_fn: torch.fx.GraphModule,
-        body_fn: torch.fx.GraphModule,
-        carried_inputs: List[ProxyValue],
-        additional_inputs: List[ProxyValue],
-        meta: NodeMetadata,
-    ):
-        meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
-        return super().call_while(
-            cond_fn, body_fn, carried_inputs, additional_inputs, meta
-        )
-
-    def call_map(
-        self,
-        f: torch.fx.GraphModule,
-        mapped_args: List[ProxyValue],
-        operands: List[ProxyValue],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
-        *_, body_out_node = f.graph.nodes
-        body_out_node_fake_tensor = body_out_node.meta["val"]
-        map_fake_tensor = pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: x.new_empty(mapped_dim_size, *x.shape),
-            body_out_node_fake_tensor,
-        )
-        meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
-        return super().call_map(f, mapped_args, operands, meta)
-
-    # pyre-ignore
-    def call_delegate(self, lowered_module, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        # If spec is missing, re-genenrate it with args data
-        if "spec" not in meta:
-            meta["spec"] = pytree.tree_map(
-                make_spec,
-                executorch_call_delegate(lowered_module, *args_data),
-            )
-        return super().call_delegate(lowered_module, args, kwargs, meta)
-
-    # pyre-ignore
-    def output(self, results, meta):
-        # pyre-ignore
-        def get_spec(x):
-            if isinstance(x, ProxyValue):
-                return x.node.meta["spec"]
-            else:
-                return make_spec(x)
-
-        meta["spec"] = pytree.tree_map(get_spec, results)
-        return super().output(results, meta)
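Since the pass now overrides __call__ and returns a PassResult (see above), a caller reads the regenerated specs off the result's graph_module. A minimal usage sketch, mirroring the new tests below (the import path is assumed from the file location, and gm is assumed to be an edge-dialect GraphModule):

from executorch.exir.passes.spec_prop_pass import SpecPropPass

res = SpecPropPass()(gm)  # gm: an edge-dialect torch.fx.GraphModule
assert res is not None
for node in res.graph_module.graph.nodes:
    spec = node.meta.get("spec")  # a TensorSpec, or a pytree of them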

exir/tests/test_passes.py

Lines changed: 134 additions & 0 deletions
@@ -74,6 +74,7 @@
 from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
 from executorch.exir.program._program import lift_constant_tensor_pass
 from executorch.exir.schema import TensorShapeDynamism
+from executorch.exir.sym_util import eval_upper_bound
 from executorch.exir.tensor import TensorSpec
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
@@ -113,6 +114,7 @@ def collect_ops(gm: torch.fx.GraphModule):
 
 lib.define("foo(Tensor self) -> (Tensor, Tensor)")
 lib.define("add_relu(Tensor self, Tensor other) -> Tensor")
+lib.define("unbacked(Tensor self) -> Tensor")
 
 
 @impl(lib, "foo", "CompositeExplicitAutograd")
@@ -132,6 +134,29 @@ def foo_out(
     return a + 1, None
 
 
+@impl(lib, "unbacked", "CPU")
+def unbacked(a: torch.Tensor) -> torch.Tensor:
+    return a[: a[0]]
+
+
+@torch.library.register_fake(f"{lib.ns}::unbacked")
+def meta_unbacked(x):
+    ctx = torch._custom_ops.get_ctx()
+    out_size = ctx.create_unbacked_symint()
+    return x.new_empty(out_size)
+
+
+lib.define("unbacked.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
+
+
+@impl(lib, "unbacked.out", "CPU")
+def unbacked_out(
+    x: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return out.copy_(x[: x[0]])
+
+
 def simple_promote_dtype(
     dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
 ) -> torch.dtype:
@@ -611,6 +636,115 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
 
         self.assertEqual(counter, 1)
 
+    def test_spec_prop_pass_unbacked_symint(self) -> None:
+        # Verify that the spec prop pass picks up on guards for
+        # unbacked symints.
+        class Unbacked(torch.nn.Module):
+            def forward(self, x):
+                output = torch.ops.DO_NOT_USE_TEST_ONLY.unbacked(x)
+                torch._constrain_as_size(output.shape[0], max=10)
+                return output
+
+        model = Unbacked()
+        gm = (
+            to_edge(export(model, (torch.LongTensor([5, 4, 3, 2, 1, 0, 1, 2]),)))
+            .exported_program()
+            .graph_module
+        )
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the custom op node. It should have a max size of 10.
+        op_node = next(
+            n
+            for n in new_gm.graph_module.graph.nodes
+            if n.target == exir_ops.edge.DO_NOT_USE_TEST_ONLY.unbacked.default
+        )
+        self.assertIsNotNone(op_node)
+
+        spec: TensorSpec = op_node.meta["spec"]
+        self.assertEqual(len(spec.shape), 1)  # Should be rank 1
+        upper_bound = eval_upper_bound(spec.shape[0])
+        self.assertEqual(upper_bound, 10)  # Should be a concrete value
+
+    def test_spec_prop_pass_cond(self) -> None:
+        class ModelWithCond(torch.nn.Module):
+            def true_fn(self, val):
+                return val * 2
+
+            def false_fn(self, val):
+                return val + 1
+
+            def forward(self, x):
+                return torch.cond(x[0] > 0, self.true_fn, self.false_fn, [x])
+
+        model = ModelWithCond()
+        inputs = (torch.ones(10),)
+        dynamic_shapes = {"x": {0: torch.export.Dim("x", min=1, max=20)}}
+
+        # Run the spec prop pass and sanity check the spec on the cond.
+        edge_program = to_edge(export(model, inputs, dynamic_shapes=dynamic_shapes))
+        gm = edge_program.exported_program().graph_module
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the cond node. It should have a max size of 20 (matching the dynamic shape upper bound).
+        cond_node = next(
+            n
+            for n in new_gm.graph_module.graph.nodes
+            if hasattr(n.target, "name") and n.target.name() == "cond"
+        )
+        self.assertIsNotNone(cond_node)
+
+        # Spec for the cond node should be a single-element tuple
+        spec: tuple[TensorSpec] = cond_node.meta["spec"]
+        self.assertTrue(isinstance(spec, tuple))
+        self.assertEqual(len(spec), 1)
+
+        self.assertEqual(len(spec[0].shape), 1)  # Should be rank 1
+        upper_bound = eval_upper_bound(spec[0].shape[0])
+        self.assertEqual(upper_bound, 20)  # Should match dynamic shape bound
+
+    def test_spec_prop_pass_while(self) -> None:
+        class ModelWithWhile(torch.nn.Module):
+            def forward(self, i):
+                def loop_cond(i, acc):
+                    return i[0] > 0
+
+                def loop_body(i, acc):
+                    return i - 1, acc + i
+
+                _, acc = torch._higher_order_ops.while_loop(
+                    loop_cond, loop_body, (i, torch.zeros(10))
+                )
+                return acc
+
+        model = ModelWithWhile()
+        inputs = (torch.Tensor([5]),)
+
+        # Run the spec prop pass and sanity check the spec on the while.
+        edge_program = to_edge(export(model, inputs))
+        gm = edge_program.exported_program().graph_module
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the while node. It should have a max size of 10 (matching the torch.zeros(10) in the model).
+        while_node = next(
+            n
+            for n in new_gm.graph_module.graph.nodes
+            if hasattr(n.target, "name") and n.target.name() == "while_loop"
+        )
+        self.assertIsNotNone(while_node)
+
+        # Spec for the while node should be a two-element tuple
+        spec: tuple[TensorSpec] = while_node.meta["spec"]
+        self.assertTrue(isinstance(spec, tuple))
+        self.assertEqual(len(spec), 2)
+
+        self.assertEqual(len(spec[1].shape), 1)  # Should be rank 1
+        upper_bound = eval_upper_bound(spec[1].shape[0])
+        self.assertEqual(upper_bound, 10)
+
     def test_compile_fix_broken_ops(self) -> None:
         class ExportableLoop(nn.Module):
             def __init__(self, hidden_size, out_channels):
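As context for the unbacked-symint test above: the meta kernel allocates a fresh unbacked symint for the data-dependent output size, and torch._constrain_as_size then bounds it, which is exactly the guard the regenerated spec must reflect. The same op could also be written with the newer torch.library.custom_op API; a hedged sketch under that assumption (test_ns is a made-up namespace, not part of the commit):

import torch


@torch.library.custom_op("test_ns::unbacked", mutates_args=())
def unbacked(a: torch.Tensor) -> torch.Tensor:
    # Data-dependent output size: the first element decides the slice length.
    return a[: a[0]]


@unbacked.register_fake
def _(x):
    # Allocate a fresh unbacked symint for the unknown output size; callers
    # then bound it with torch._constrain_as_size, producing the guard that
    # the spec prop pass must pick up.
    ctx = torch.library.get_ctx()
    out_size = ctx.new_dynamic_size()
    return x.new_empty(out_size)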
