diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 6b2b61729ed..1613cfb28ca 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -433,6 +433,7 @@ python_unittest( srcs = [ "tests/test_memory_passes.py", ], + supports_static_listing = False, typing = True, deps = [ ":compiler", @@ -441,7 +442,9 @@ python_unittest( ":pass_utils", "//caffe2:torch", "//executorch/exir:memory", + "fbsource//third-party/pypi/parameterized:parameterized", "//executorch/exir/dialects:lib", + "//executorch/backends/cadence/aot:graph_builder", "//executorch/exir/tests:models", ], ) diff --git a/backends/cadence/aot/memory_constraints.py b/backends/cadence/aot/memory_constraints.py index 3de140e4647..377e6fc81e6 100644 --- a/backends/cadence/aot/memory_constraints.py +++ b/backends/cadence/aot/memory_constraints.py @@ -350,14 +350,28 @@ def is_slice_view(self, node: torch.fx.Node) -> bool: def is_cat_along_outermost_dim( self, graph_module: torch.fx.GraphModule, cat_node: torch.fx.Node ) -> bool: + assert len(cat_node.args) > 0 + cat_tensors = cat_node.args[0] + if not isinstance(cat_tensors, Sequence) or not all( + isinstance(t, torch.fx.Node) for t in cat_tensors + ): + raise ValueError("cat_tensors must be a sequence of torch.fx.Node objects.") + + if len(cat_node.args) > 1: + cat_dim = cat_node.args[1] + else: + cat_dim = cat_node.kwargs.get("dim", 0) + if not isinstance(cat_dim, int): + raise ValueError("cat_dim must be an integer.") + # If the cat op has default dim, then the concat dim is 0 - if len(cat_node.args) == 1 or cat_node.args[1] == 0: + if len(cat_tensors) == 1 or cat_dim == 0: return True - # Get the concatenation dimension and concatenated tensors - (cat_tensors, cat_dim) = cast( - tuple[Sequence[torch.fx.Node], int], cat_node.args - ) + + # Make sure all dims before cat_dim are 1.
for tensor in cat_tensors: + if not isinstance(tensor, torch.fx.Node): + continue shape = get_shape(graph_module, tensor) if shape is None or not all(dim == 1 for dim in shape[0:cat_dim]): return False diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py index c32809c2bff..d220007e227 100644 --- a/backends/cadence/aot/tests/test_memory_passes.py +++ b/backends/cadence/aot/tests/test_memory_passes.py @@ -14,13 +14,23 @@ import executorch.backends.cadence.aot.ops_registrations # noqa import torch from executorch.backends.cadence.aot import compiler -from executorch.backends.cadence.aot.memory_planning import find_peak_memory_usage +from executorch.backends.cadence.aot.graph_builder import GraphBuilder +from executorch.backends.cadence.aot.memory_planning import ( + CadenceMemoryPlanning, + find_peak_memory_usage, +) from executorch.backends.cadence.aot.pass_utils import count_node -from executorch.backends.cadence.aot.utils import MemoryConfig +from executorch.backends.cadence.aot.utils import ( + get_default_memory_config, + MemoryConfig, +) from executorch.exir import memory from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import collect_specs_from_nodes +from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.tests.models import MultiLayerPerceptron +from parameterized.parameterized import parameterized +from torch.fx import GraphModule class TestMemPlanningPasses(unittest.TestCase): @@ -120,24 +130,27 @@ def forward(self, x): class TestMemTransform(unittest.TestCase): def _verify_cat_nop_memory_alloc(self, node: torch.fx.Node) -> None: - spec = node.meta.get("spec", None) - self.assertIsNotNone(spec) - dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0 - outer_size = math.prod(spec.shape[:dim]) + node_spec = node.meta.get("spec", None) + self.assertIsNotNone(node_spec) + dim: int = cast(int, node.kwargs["dim"]) if "dim" in node.kwargs else 0 + outer_size = math.prod(node_spec.shape[:dim]) self.assertEqual( outer_size, 1, f"{node=} has wrong outer size: {outer_size=}, expected 1.", ) - inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize + inner_dim_elements = ( + math.prod(node_spec.shape[dim + 1 :]) * node_spec.dtype.itemsize + ) dim_offset = 0 for arg in cast(list[torch.fx.Node], node.args[0]): arg_spec = arg.meta.get("spec", None) - self.assertEqual(arg_spec.mem_id, spec.mem_id) + self.assertEqual(arg_spec.mem_id, node_spec.mem_id) + expected_offset = node_spec.mem_offset + dim_offset * inner_dim_elements self.assertEqual( arg_spec.mem_offset, - spec.mem_offset + dim_offset * inner_dim_elements, - f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}", + expected_offset, + f"{arg=} of node {node=} has wrong memory offset: expected {expected_offset=} = {node_spec.mem_offset=} + {dim_offset=} * {inner_dim_elements=}, but got {arg_spec.mem_offset=}", ) dim_offset += arg_spec.shape[dim] @@ -209,23 +222,45 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None: ): self._verify_select_nop_memory_alloc(node) - def test_optimize_cat_on_placeholders(self) -> None: - class Cat(torch.nn.Module): - def forward(self, x, y): - return torch.ops.aten.cat((x, y)) - - x = torch.ones(3, 6) - y = torch.ones(2, 6) - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run: - graph_module = ( -
compiler.export_to_executorch_gen_etrecord( - Cat(), (x, y), opt_level=2, mem_algo=1 - ) - .exported_program() - .graph_module - ) - logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}") + # Initializes the nodes metadata and runs the GenerateMemoryViewConstraints, + # GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes. + def run_memory_planning(self, original, alloc_graph_input=True) -> GraphModule: + graph_module = SpecPropPass().call(original).graph_module + return CadenceMemoryPlanning( + get_default_memory_config(), + opt_level=2, + mem_algo=1, # greedy_by_size_for_offset_calculation_with_hierarchy + alloc_graph_input=alloc_graph_input, + )(graph_module).graph_module + + @parameterized.expand( + [ + [ + [3, 6], # x_shape + [2, 6], # y_shape + 0, # concat dim + ], + ] + ) + def test_optimize_cat_on_placeholders(self, x_shape, y_shape, concat_dim) -> None: + concat_shape = [x_shape[concat_dim] + y_shape[concat_dim], x_shape[1]] + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(*x_shape)) + y = builder.placeholder("y", torch.ones(*y_shape)) + pre_created_output = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(concat_shape, 0.0), + kwargs={"dtype": torch.float32}, + ) + graph_output = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([x, y],), + kwargs={"dim": concat_dim, "out": pre_created_output}, + ) + builder.output([graph_output]) + original = builder.get_graph_module() + + graph_module = self.run_memory_planning(original) graph_module.graph.eliminate_dead_code() # Assert that cat op is optimized away self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) @@ -233,53 +268,88 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_optimize_cat_outermost(self) -> None: - class OptimizeCatFeasible1(torch.nn.Module): - def forward(self, x, y): - x1 = torch.add(x, 2.4, 3.1) - y1 = torch.add(y, 1, 2) - # Cat along the outermost dimension can be optimized away after - # adding constraints on the locations of x1 and y1. 
- return torch.ops.aten.cat((x1, y1)) - - x = torch.ones(3, 6) - y = torch.ones(2, 6) - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run: - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - OptimizeCatFeasible1(), (x, y), opt_level=2, mem_algo=1 + # Returns a GraphModule with the following structure: + # "add_add_cat_model" : cat(x + 123, y + 456) + # "add_add_cat_add_model": cat(x + 123, y + 456) + 789 + def get_graph_module( + self, model_name, x_shape, y_shape, concated_shape, concat_dim + ) -> GraphModule: + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(*x_shape, dtype=torch.float32)) + y = builder.placeholder("y", torch.ones(*y_shape, dtype=torch.float32)) + to_add_to_x = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(x_shape, 123.0), + kwargs={"dtype": torch.float32}, + ) + add_x = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x, to_add_to_x), + ) + to_add_to_y = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(y_shape, 456.0), + kwargs={"dtype": torch.float32}, + ) + add_y = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(y, to_add_to_y), + ) + pre_created_output = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(concated_shape, 0.0), + kwargs={"dtype": torch.float32}, + ) + cat = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([add_x, add_y],), + kwargs={"dim": concat_dim, "out": pre_created_output}, + ) + if model_name == "add_add_cat_model": + builder.output([cat]) + return builder.get_graph_module() + + if model_name == "add_add_cat_add_model": + to_add_to_cat = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(concated_shape, 789.0), + kwargs={"dtype": torch.float32}, ) - .exported_program() - .graph_module - ) - graph_module.graph.eliminate_dead_code() - # Assert that cat op is optimized away - self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) - # Assert that cat op is replaced by its nop version post optimization - self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) - self.verify_nop_memory_alloc(graph_module) - - def test_optimize_cat_non_outermost(self) -> None: - class OptimizeCatFeasible2(torch.nn.Module): - def forward(self, x, y): - x1 = torch.add(x, 2.4, 3.1) - y1 = torch.add(y, 1, 2) - # Cat along the outermost dimension can be optimized away after - # adding constraints on the locations of x1 and y1. 
- return torch.ops.aten.cat((x1, y1), 1) - - x = torch.ones(1, 3, 6) - y = torch.ones(1, 2, 6) - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run: - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - OptimizeCatFeasible2(), (x, y), opt_level=2, mem_algo=1 + graph_output = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(cat, to_add_to_cat), ) - .exported_program() - .graph_module - ) + builder.output([graph_output]) + return builder.get_graph_module() + + raise ValueError(f"Unknown model name {model_name}") + + @parameterized.expand( + [ + ( + "outermost", + [3, 6], # x_shape + [2, 6], # y_shape + [5, 6], # concated_shape + 0, # concat dim + ), + ( + "non_outermost", + [1, 3, 6], # x_shape + [1, 2, 6], # y_shape + [1, 5, 6], # concated_shape + 1, # concat dim + ), + ], + name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}", + ) + def test_cat_optimized( + self, _, x_shape, y_shape, concated_shape, concat_dim + ) -> None: + original = self.get_graph_module( + "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim + ) + graph_module = self.run_memory_planning(original) graph_module.graph.eliminate_dead_code() # Assert that cat op is optimized away self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) @@ -287,111 +357,181 @@ def forward(self, x, y): self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_no_optimize_cat_non_outermost(self) -> None: - class OptimizeCatInfeasible1(torch.nn.Module): - def forward(self, x, y): - x1 = torch.add(x, 2.4, 3.1) - y1 = torch.add(y, 1, 2) - # Cat along the outermost dimension can be optimized away after - # adding constraints on the locations of x1 and y1. - return torch.ops.aten.cat((x1, y1), 1) - - x = torch.ones(2, 4, 5) - y = torch.ones(2, 2, 5) - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - OptimizeCatInfeasible1(), (x, y), opt_level=2, mem_algo=1 - ) - .exported_program() - .graph_module - ) + @parameterized.expand( + [ + ( + "non_outermost", + [2, 4, 5], # x_shape + [2, 2, 5], # y_shape + [2, 6, 5], # concated_shape + 1, # concat dim + ), + ], + name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}", + ) + def test_cat_not_optimized( + self, _, x_shape, y_shape, concated_shape, concat_dim + ) -> None: + original = self.get_graph_module( + "add_add_cat_model", x_shape, y_shape, concated_shape, concat_dim + ) + graph_module = self.run_memory_planning(original) graph_module.graph.eliminate_dead_code() - # Assert that cat op is not optimized away, since the concat is not - # along the outermost dim + # Assert that cat op is not optimized away, since the concat is not along the outermost dim. + # The first dimension is 2, but all dims before cat_dim should be == 1. self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_no_optimize_cat_non_outermost1(self) -> None: - class OptimizeCatInfeasible2(torch.nn.Module): - def forward(self, x, y): - x1 = torch.add(x, 2.4, 3.1) - y1 = torch.add(y, 1, 2) - # Cat along the outermost dimension can be optimized away after - # adding constraints on the locations of x1 and y1. 
- return torch.ops.aten.cat((x1, y1), 0) + 2 + @parameterized.expand( + [ + ( + "aligned", + [5, 8], # x_shape + [3, 8], # y_shape + [8, 8], # concated_shape + 0, # concat dim + 0, # expected cat nodes + ), + ( + "unaligned", # 5 * 5 * 4 % 8 != 0 + [5, 5], # x_shape + [3, 5], # y_shape + [8, 5], # concated_shape + 0, # concat dim + 1, # expected cat nodes + ), + ], + name_func=lambda f, _, param: f"{f.__name__}_{param.args[0]}", + ) + def test_cat_not_graph_output( + self, _, x_shape, y_shape, concated_shape, concat_dim, expected_cat_nodes + ) -> None: + original = self.get_graph_module( + "add_add_cat_add_model", x_shape, y_shape, concated_shape, concat_dim + ) + graph_module = self.run_memory_planning(original) + graph_module.graph.eliminate_dead_code() - x = torch.ones(5, 5) - y = torch.ones(3, 5) - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run: - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - OptimizeCatInfeasible2(), (x, y), opt_level=2, mem_algo=1 - ) - .exported_program() - .graph_module + # Assert that cat op is optimized away only if its arguments offsets are multiple of 8 bytes. + self.assertEqual( + count_node(graph_module, torch.ops.aten.cat.out), expected_cat_nodes ) - graph_module.graph.eliminate_dead_code() - # Assert that cat op is not optimized away, since the concat relative - # offsets are not multiple of 8 bytes, and the cat is not the output - # of the graph. - self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.verify_nop_memory_alloc(graph_module) def test_optimize_cat_with_slice(self) -> None: - class OptimizeCatSliceFeasible(torch.nn.Module): - def forward(self, x): - x1 = torch.add(x, 2.4, 3.1) - x2 = torch.ops.aten.slice(x, 0, 0, 1) - x3 = torch.ops.aten.cat((x1, x2)) - return torch.add(x3, x3) - - x = torch.randn(5, 6) - # Compile, and set alloc_graph_input to False so that slice op is not - # optimized away. - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run: - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - OptimizeCatSliceFeasible(), - (x,), - opt_level=2, - mem_algo=1, - alloc_graph_input=False, - ) - .exported_program() - .graph_module - ) + x_shape = [5, 6] + concated_shape = [6, 6] + concat_dim = 0 + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(*x_shape, dtype=torch.float32)) + to_add_to_x = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(x_shape, 123.0), + kwargs={"dtype": torch.float32}, + ) + add_x = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x, to_add_to_x), + ) + slice_x = builder.call_operator( + op=exir_ops.edge.aten.slice.Tensor, + args=( + x, + 0, # dim + 0, # start + 1, # end + 1, # step + ), + ) + pre_created_output = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(concated_shape, 0.0), + kwargs={"dtype": torch.float32}, + ) + cat = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([add_x, slice_x],), + kwargs={"dim": concat_dim, "out": pre_created_output}, + ) + graph_output = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(cat, cat), + ) + builder.output([graph_output]) + original = builder.get_graph_module() + + graph_module = self.run_memory_planning(original, alloc_graph_input=False) graph_module.graph.eliminate_dead_code() - # Assert that cat op is optimized away + + # Assert that cat op is optimized away. 
+ self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) + # Assert that cat op is replaced by its nop version post optimization. self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + # Assert that slice op was not optimized away. + self.assertEqual(count_node(graph_module, exir_ops.edge.aten.slice.Tensor), 1) self.verify_nop_memory_alloc(graph_module) def test_optimize_cat_with_slice_infeasible(self) -> None: - class OptimizeCatSliceInfeasible(torch.nn.Module): - def forward(self, x, y): - x1 = torch.add(x, 2.4, 3.1) - y1 = torch.add(y, 1, 2) - y2 = torch.ops.aten.slice(y1, 0, 0, 1) - # Cat can't be optimized away if any of the tensor (e.g., y1) - # is slice_nop - return torch.ops.aten.cat((y2, x1)) - - x = torch.ones(3, 5) - y = torch.ones(2, 5) - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run: - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - OptimizeCatSliceInfeasible(), (x, y), opt_level=2, mem_algo=1 - ) - .exported_program() - .graph_module - ) + x_shape = [5, 6] + y_shape = [3, 6] + concated_shape = [8, 6] + concat_dim = 0 + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(*x_shape, dtype=torch.float32)) + y = builder.placeholder("y", torch.ones(*y_shape, dtype=torch.float32)) + to_add_to_x = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(x_shape, 123.0), + kwargs={"dtype": torch.float32}, + ) + add_x = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x, to_add_to_x), + ) + to_add_to_y = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(y_shape, 123.0), + kwargs={"dtype": torch.float32}, + ) + add_y = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(y, to_add_to_y), + ) + slice_out = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(y_shape, 0.0), + kwargs={"dtype": torch.float32}, + ) + slice_y = builder.call_operator( + op=torch.ops.aten.slice_copy.Tensor_out, + args=( + add_y, + 0, # dim + 0, # start + 1, # end + 1, # step + ), + kwargs={"out": slice_out}, + ) + pre_created_output = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=(concated_shape, 0.0), + kwargs={"dtype": torch.float32}, + ) + cat = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([slice_y, add_x],), + kwargs={"dim": concat_dim, "out": pre_created_output}, + ) + builder.output([cat]) + original = builder.get_graph_module() + graph_module = self.run_memory_planning(original, alloc_graph_input=False) graph_module.graph.eliminate_dead_code() - # Assert that cat op is not optimized away + # Assert that slice op is optimized away. + self.assertEqual( + count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1 + ) + # Assert that cat op is not optimized away. self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.verify_nop_memory_alloc(graph_module)
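Note for reviewers: the new tests exercise two conditions: (1) the concat is effectively along the outermost dimension, i.e. every dim before cat_dim equals 1 for every input, and (2) each input starts at an 8-byte-aligned offset inside the concatenated output buffer. The sketch below is a minimal, standalone illustration of those checks; the helper names and the 8-byte alignment constant are assumptions taken from the test comments, not APIs from this patch or the pass.

import math

def cat_is_along_outermost_dim(shapes: list[list[int]], cat_dim: int) -> bool:
    # cat can be replaced by a nop only if, for every input, all dims before
    # cat_dim are 1, i.e. the concatenation is effectively outermost.
    return all(all(d == 1 for d in shape[:cat_dim]) for shape in shapes)

def cat_offsets_are_aligned(shapes: list[list[int]], itemsize: int, alignment: int = 8) -> bool:
    # Each input must start at a byte offset inside the concatenated output
    # that is a multiple of `alignment`.
    offset = 0
    for shape in shapes:
        if offset % alignment != 0:
            return False
        offset += math.prod(shape) * itemsize
    return True

# Mirrors the parameterized cases above (float32 inputs, itemsize 4):
assert cat_is_along_outermost_dim([[3, 6], [2, 6]], cat_dim=0)            # optimized
assert not cat_is_along_outermost_dim([[2, 4, 5], [2, 2, 5]], cat_dim=1)  # not optimized
assert cat_offsets_are_aligned([[5, 8], [3, 8]], itemsize=4)              # "aligned" case
assert not cat_offsets_are_aligned([[5, 5], [3, 5]], itemsize=4)          # "unaligned" case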