diff --git a/backends/cortex_m/passes/convert_to_cortex_m_pass.py b/backends/cortex_m/passes/convert_to_cortex_m_pass.py index 5924f65cdf0..5a142efd639 100644 --- a/backends/cortex_m/passes/convert_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/convert_to_cortex_m_pass.py @@ -171,6 +171,22 @@ def _get_convolution_replacement(self, node) -> int: weight_permuted, ) + quantized_multiplier_tensor = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_quantized_multiplier", + InputKind.PARAMETER, + torch.tensor(quantized_multipliers, dtype=torch.int32), + ) + + quantized_shift_tensor = create_constant_placeholder( + self.exported_program, + node.graph, + node.name + "_quantized_shift", + InputKind.PARAMETER, + torch.tensor(quantized_shifts, dtype=torch.int32), + ) + new_args = ( x, weight_nhwc, @@ -180,8 +196,8 @@ def _get_convolution_replacement(self, node) -> int: dilation, -input_zero_point, output_zero_point, - torch.tensor(quantized_multipliers, dtype=torch.int32), - torch.tensor(quantized_shifts, dtype=torch.int32), + quantized_multiplier_tensor, + quantized_shift_tensor, output_qmin, output_qmax, ) diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index ab5367d1b20..637cc0013f0 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -6,14 +6,16 @@ # pyre-strict -from typing import List, Optional +import operator +from typing import Optional import torch from executorch.exir.delegate import executorch_call_delegate -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue +from executorch.exir.pass_base import ExportPass, ProxyValue from executorch.exir.tensor import TensorSpec from torch.export.exported_program import ExportGraphSignature from torch.fx.node import Node +from torch.fx.passes.infra.pass_base import PassResult from torch.utils import _pytree as pytree @@ -52,12 +54,48 @@ class SpecPropPass(ExportPass): def __init__(self) -> None: 
super().__init__()
 
-    def on_attr(self, attr: ProxyValue) -> None:
-        attr.node.meta["spec"] = pytree.tree_map_only(
-            torch.Tensor,
-            make_spec,
-            attr.data,
-        )
+    def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        # Re-trace with a plain ExportPass so node metadata is up to date before specs are set.
+        res = ExportPass()(graph_module)
+        assert res is not None
+        gm = res.graph_module
+
+        def get_spec(x):  # tree_map leaf: stored spec of a node, else None
+            if hasattr(x, "meta"):
+                return x.meta.get("spec", None)
+            else:
+                return None
+
+        for module in gm.modules():
+            if isinstance(module, torch.fx.GraphModule):
+                for node in module.graph.nodes:
+                    meta_val = node.meta.get("val", None)
+                    # output/getitem reuse specs already set on their inputs; the rest use meta["val"].
+                    if node.op == "output":
+                        node.meta["spec"] = pytree.tree_map(get_spec, node.args[0])
+                    elif node.op == "call_function" and node.target == operator.getitem:
+                        value_spec = pytree.tree_map(get_spec, node.args[0])
+                        node.meta["spec"] = value_spec[node.args[1]]
+                    elif (
+                        node.op == "call_function"
+                        and node.target == executorch_call_delegate
+                    ):
+                        # Note: An existing spec on a delegate node is deliberately left untouched;
+                        # it is set somewhat manually when the call delegate node is added.
+                        # Regenerating it can change it and break lowering (it may become a tuple;
+                        # unconfirmed). Ideally spec regeneration would be made safe for delegates.
+                        # Only a delegate node with no spec at all gets one built from meta["val"].
+                        #
+                        # All other nodes always get their spec regenerated, as this pass is called
+                        # multiple times in some lowering paths (backends can and do call it).
+                        if "spec" not in node.meta:
+                            node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
+                    else:
+                        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
+        return res
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        return self(graph_module)  # same work as __call__
 
     def update_placeholder_tensor_specs(
         self,
@@ -84,85 +122,3 @@ def update_placeholder_tensor_specs(
                 in exported_program.graph_signature.inputs_to_lifted_tensor_constants
             ):
                 spec.const = True
-
-    # pyre-ignore
-    def placeholder(self, name: str, arg, meta):
-        meta["spec"] = make_spec(arg)
-        return super().placeholder(name, arg, meta)
-
-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
-        return super().call_operator(op, args, kwargs, meta)
-
-    # pyre-ignore
-    def call_getitem(self, value, key: int, meta):
-        meta["spec"] = value.node.meta["spec"][key]
-        return super().call_getitem(value, key, meta)
-
-    # pyre-ignore
-    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
-        # true_fn/false_fn return tensors of the same shape, so we can pick
-        # either one here.
- *_, true_out_node = true_fn.graph.nodes - meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"]) - return super().call_cond(pred, true_fn, false_fn, inputs, meta) - - def call_while( - self, - cond_fn: torch.fx.GraphModule, - body_fn: torch.fx.GraphModule, - carried_inputs: List[ProxyValue], - additional_inputs: List[ProxyValue], - meta: NodeMetadata, - ): - meta["spec"] = pytree.tree_map(make_spec, carried_inputs) - return super().call_while( - cond_fn, body_fn, carried_inputs, additional_inputs, meta - ) - - def call_map( - self, - f: torch.fx.GraphModule, - mapped_args: List[ProxyValue], - operands: List[ProxyValue], - meta: NodeMetadata, - ) -> ProxyValue: - mapped_dim_size = [arg.data for arg in mapped_args][0].size(0) - *_, body_out_node = f.graph.nodes - body_out_node_fake_tensor = body_out_node.meta["val"] - map_fake_tensor = pytree.tree_map_only( - torch.Tensor, - lambda x: x.new_empty(mapped_dim_size, *x.shape), - body_out_node_fake_tensor, - ) - meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor) - return super().call_map(f, mapped_args, operands, meta) - - # pyre-ignore - def call_delegate(self, lowered_module, args, kwargs, meta): - args_data, kwargs_data = pytree.tree_map_only( - ProxyValue, lambda x: x.data, (args, kwargs) - ) - # If spec is missing, re-genenrate it with args data - if "spec" not in meta: - meta["spec"] = pytree.tree_map( - make_spec, - executorch_call_delegate(lowered_module, *args_data), - ) - return super().call_delegate(lowered_module, args, kwargs, meta) - - # pyre-ignore - def output(self, results, meta): - # pyre-ignore - def get_spec(x): - if isinstance(x, ProxyValue): - return x.node.meta["spec"] - else: - return make_spec(x) - - meta["spec"] = pytree.tree_map(get_spec, results) - return super().output(results, meta) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 14f105e8205..d398b81ee8f 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -74,6 +74,7 @@ 
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
 from executorch.exir.program._program import lift_constant_tensor_pass
 from executorch.exir.schema import TensorShapeDynamism
+from executorch.exir.sym_util import eval_upper_bound
 from executorch.exir.tensor import TensorSpec
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
@@ -113,6 +114,7 @@ def collect_ops(gm: torch.fx.GraphModule):
 
 lib.define("foo(Tensor self) -> (Tensor, Tensor)")
 lib.define("add_relu(Tensor self, Tensor other) -> Tensor")
+lib.define("unbacked(Tensor self) -> Tensor")
 
 
 @impl(lib, "foo", "CompositeExplicitAutograd")
@@ -132,6 +134,32 @@ def foo_out(
     return a + 1, None
 
 
+@impl(lib, "unbacked", "CPU")
+def unbacked(a: torch.Tensor) -> torch.Tensor:
+    """Return the leading ``a[0]`` elements of ``a``."""
+    return a[: a[0]]
+
+
+@torch.library.register_fake(f"{lib.ns}::unbacked")
+def meta_unbacked(x):
+    """Fake kernel: the output length is a fresh unbacked symint."""
+    ctx = torch._custom_ops.get_ctx()
+    out_size = ctx.create_unbacked_symint()
+    return x.new_empty(out_size)
+
+
+lib.define("unbacked.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)")
+
+
+@impl(lib, "unbacked.out", "CPU")
+def unbacked_out(
+    x: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    """Out variant: copy the leading ``x[0]`` elements of ``x`` into ``out``."""
+    return out.copy_(x[: x[0]])
+
+
 def simple_promote_dtype(
     dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
 ) -> torch.dtype:
@@ -611,6 +639,126 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
 
         self.assertEqual(counter, 1)
 
+    def test_spec_prop_pass_unbacked_symint(self) -> None:
+        # Verify that the spec prop pass picks up on guards for
+        # unbacked symints.
+        class Unbacked(torch.nn.Module):
+            def forward(self, x):
+                output = torch.ops.DO_NOT_USE_TEST_ONLY.unbacked(x)
+                torch._constrain_as_size(output.shape[0], max=10)
+                return output
+
+        model = Unbacked()
+        gm = (
+            to_edge(export(model, (torch.LongTensor([5, 4, 3, 2, 1, 0, 1, 2]),)))
+            .exported_program()
+            .graph_module
+        )
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the custom op node. It should have a max size of 10.
+        op_node = next(
+            (
+                n
+                for n in new_gm.graph_module.graph.nodes
+                if n.target == exir_ops.edge.DO_NOT_USE_TEST_ONLY.unbacked.default
+            ),
+            None,
+        )
+        self.assertIsNotNone(op_node)
+
+        spec: TensorSpec = op_node.meta["spec"]
+        self.assertEqual(len(spec.shape), 1)  # Should be rank 1
+        upper_bound = eval_upper_bound(spec.shape[0])
+        self.assertEqual(upper_bound, 10)  # Should be a concrete value
+
+    def test_spec_prop_pass_cond(self) -> None:
+        # Verify that the spec on a torch.cond node carries the dynamic-shape upper bound.
+        class ModelWithCond(torch.nn.Module):
+            def true_fn(self, val):
+                return val * 2
+
+            def false_fn(self, val):
+                return val + 1
+
+            def forward(self, x):
+                return torch.cond(x[0] > 0, self.true_fn, self.false_fn, [x])
+
+        model = ModelWithCond()
+        inputs = (torch.ones(10),)
+        dynamic_shapes = {"x": {0: torch.export.Dim("x", min=1, max=20)}}
+
+        # Run the spec prop pass and sanity check the spec on the cond.
+        edge_program = to_edge(export(model, inputs, dynamic_shapes=dynamic_shapes))
+        gm = edge_program.exported_program().graph_module
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the cond node. It should have a max size of 20 (matching the dynamic shape upper bound).
+        cond_node = next(
+            (
+                n
+                for n in new_gm.graph_module.graph.nodes
+                if hasattr(n.target, "name") and n.target.name() == "cond"
+            ),
+            None,
+        )
+        self.assertIsNotNone(cond_node)
+
+        # Spec for the cond node should be a single-element tuple
+        spec: tuple[TensorSpec] = cond_node.meta["spec"]
+        self.assertIsInstance(spec, tuple)
+        self.assertEqual(len(spec), 1)
+
+        self.assertEqual(len(spec[0].shape), 1)  # Should be rank 1
+        upper_bound = eval_upper_bound(spec[0].shape[0])
+        self.assertEqual(upper_bound, 20)  # Should match dynamic shape bound
+
+    def test_spec_prop_pass_while(self) -> None:
+        # Verify that the spec on a while_loop node reflects the size of the carried tensor.
+        class ModelWithWhile(torch.nn.Module):
+            def forward(self, i):
+                def loop_cond(i, acc):
+                    return i[0] > 0
+
+                def loop_body(i, acc):
+                    return i - 1, acc + i
+
+                _, acc = torch._higher_order_ops.while_loop(
+                    loop_cond, loop_body, (i, torch.zeros(10))
+                )
+                return acc
+
+        model = ModelWithWhile()
+        inputs = (torch.Tensor([5]),)
+
+        # Run the spec prop pass and sanity check the spec on the while.
+        edge_program = to_edge(export(model, inputs))
+        gm = edge_program.exported_program().graph_module
+        new_gm = SpecPropPass()(gm)
+        self.assertIsNotNone(new_gm)
+
+        # Check the spec for the while node. It should have a max size of 10 (matching the torch.zeros(10) in the model).
+        while_node = next(
+            (
+                n
+                for n in new_gm.graph_module.graph.nodes
+                if hasattr(n.target, "name") and n.target.name() == "while_loop"
+            ),
+            None,
+        )
+        self.assertIsNotNone(while_node)
+
+        # Spec for the while node should be a two-element tuple
+        spec: tuple[TensorSpec, ...] = while_node.meta["spec"]
+        self.assertIsInstance(spec, tuple)
+        self.assertEqual(len(spec), 2)
+
+        self.assertEqual(len(spec[1].shape), 1)  # Should be rank 1
+        upper_bound = eval_upper_bound(spec[1].shape[0])
+        self.assertEqual(upper_bound, 10)
+
     def test_compile_fix_broken_ops(self) -> None:
         class ExportableLoop(nn.Module):
             def __init__(self, hidden_size, out_channels):