From 4c0473225bf421c126e0ef4662685c40f32e0044 Mon Sep 17 00:00:00 2001
From: Hardik Sharma
Date: Thu, 26 Jun 2025 20:14:04 -0700
Subject: [PATCH] Fix memory planning algo for blocked mem IDs. (#11969)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/11969

Fixes handling of blocked mem IDs for specs: the blocklist check moves out of
the spec filter in MemoryPlanningAlgo and into is_valid_placement, so a spec
with block-listed mem IDs is still planned and simply skips the blocked
memories instead of being dropped from planning. Also introduces a shared
ConstraintsGenPass type alias for additional constraint-generation passes and
adds a test that blocks mem IDs per op and checks the resulting placements.

Reviewed By: mcremon-meta

Differential Revision: D77310021
---
 backends/cadence/aot/memory_constraints.py   | 30 +++++---
 backends/cadence/aot/memory_planning.py      | 20 ++---
 backends/cadence/aot/memory_planning_algo.py | 34 ++++----
 .../cadence/aot/tests/test_memory_passes.py  | 77 ++++++++++++++++++-
 4 files changed, 118 insertions(+), 43 deletions(-)

diff --git a/backends/cadence/aot/memory_constraints.py b/backends/cadence/aot/memory_constraints.py
index 1c0ae9f31a4..62eeb80fd65 100644
--- a/backends/cadence/aot/memory_constraints.py
+++ b/backends/cadence/aot/memory_constraints.py
@@ -11,7 +11,7 @@
 import typing
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import cast, DefaultDict, Iterable, Optional, Sequence
+from typing import Callable, cast, DefaultDict, Iterable, Optional, Sequence, TypeAlias
 
 import torch
 import torch.fx
@@ -573,23 +573,34 @@ def compute_slice_and_select_loc_constraints(
     graph_module.recompile()
 
 
+ConstraintsGenPass: TypeAlias = Callable[
+    [MemConstraints],
+    Callable[[torch.fx.GraphModule], Optional[PassResult]],
+]
+
+
 # The class to generate all the constraints that will be passed on to the memory
 # planning algorithm.
 class GenerateMemConstraints:
     def __init__(
         self,
         mem_constraints: MemConstraints,
-        additional_constraint_gen_passes: list | None = None,
+        additional_constraint_gen_passes: Sequence[ConstraintsGenPass] | None = None,
     ) -> None:
-        self.mem_constraints = mem_constraints
-        self.additional_constraint_gen_passes = additional_constraint_gen_passes or []
+        self.mem_constraints: MemConstraints = mem_constraints
+        self.additional_constraint_gen_passes: Sequence[ConstraintsGenPass] = (
+            additional_constraint_gen_passes or []
+        )
 
     def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        constraint_gen_passes: list = [
-            GenerateMemoryViewConstraints,
-            GenerateSliceAndSelectNopConstraints,
-            GenerateCatNopConstraints,
-        ] + self.additional_constraint_gen_passes
+        constraint_gen_passes: Sequence[ConstraintsGenPass] = cast(
+            list[ConstraintsGenPass],
+            [
+                GenerateMemoryViewConstraints,
+                GenerateSliceAndSelectNopConstraints,
+                GenerateCatNopConstraints,
+            ],
+        ) + list(self.additional_constraint_gen_passes)
         # Create a filter using the opt level in mem_constraints, and filter
         # the relevant passes.
         pass_filter = create_cadence_pass_filter(self.mem_constraints.opt_level)
@@ -602,6 +613,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
                         typing.Callable[[torch.fx.GraphModule], Optional[PassResult]],
                     ]
                 ],
+                # pyre-ignore[6]: Incompatible parameter type.
                 list(filter(pass_filter, constraint_gen_passes)),
             )
         ]
diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py
index 5a7f6e936fb..8baaaa203d0 100644
--- a/backends/cadence/aot/memory_planning.py
+++ b/backends/cadence/aot/memory_planning.py
@@ -9,11 +9,12 @@
 import collections
 import itertools
 import logging
-from typing import Callable, Iterable, List, Optional, Set, Tuple, TypeAlias
+from typing import Iterable, List, Optional, Sequence, Set, Tuple
 
 import torch
 from executorch.backends.cadence.aot.memory_constraints import MemConstraints
 from executorch.backends.cadence.aot.memory_planning_algo import (
+    ConstraintsGenPass,
     get_aligned_offset,
     MemoryPlanningAlgo,
     MemoryPlanningState,
 )
@@ -126,10 +127,9 @@ def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None:
                 prev_offset,
             )
         if spec.mem_offset is None:
-            if get_aligned_offset(
-                prev_offset + spec.allocated_memory,
-                self.get_alignment(spec.mem_id),
-            ) > self.get_size(spec.mem_id):
+            spec.mem_offset = prev_offset
+            if not self.is_valid_placement(spec):
+                spec.mem_offset = None
                 continue
             else:
                 spec.mem_offset = prev_offset
@@ -344,12 +344,6 @@ def print_memory_planning_info(
     )
 
 
-ConstraintGenPassType: TypeAlias = Callable[
-    [MemConstraints],
-    Callable[[torch.fx.GraphModule], Optional[PassResult]],
-]
-
-
 class CadenceMemoryPlanning:
     def __init__(
         self,
@@ -358,7 +352,7 @@ def __init__(
         mem_algo: int,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
-        additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]] = None,
+        additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None,
     ) -> None:
         self.memory_config = memory_config
         self.opt_level = opt_level
@@ -379,7 +373,7 @@ def get_mem_algos(
         opt_level: int,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
-        additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]],
+        additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]],
     ) -> list[MemoryPlanningAlgo]:
         return [
             PositionBasedGreedyWithHierarchy(
diff --git a/backends/cadence/aot/memory_planning_algo.py b/backends/cadence/aot/memory_planning_algo.py
index 5b67cc6c5fd..ffff2e6aab1 100644
--- a/backends/cadence/aot/memory_planning_algo.py
+++ b/backends/cadence/aot/memory_planning_algo.py
@@ -5,16 +5,16 @@
 import logging
 import math
 from abc import ABC, abstractmethod
-from typing import Callable, Optional
+from typing import Optional, Sequence
 
 import torch
 from executorch.backends.cadence.aot.memory_constraints import (
+    ConstraintsGenPass,
     GenerateMemConstraints,
     MemConstraints,
 )
 from executorch.backends.cadence.aot.utils import MemoryConfig
 from executorch.exir.memory_planning import Verifier
-from executorch.exir.pass_base import PassResult
 from executorch.exir.tensor import TensorSpec
 from torch.export.exported_program import ExportGraphSignature
 
@@ -68,18 +68,13 @@ def __init__(
         self,
         memory_config: MemoryConfig,
         placement_constraints: MemConstraints,
-        additional_constraint_gen_passes: Optional[
-            list[
-                Callable[
-                    [MemConstraints],
-                    Callable[[torch.fx.GraphModule], Optional[PassResult]],
-                ]
-            ]
-        ] = None,
+        additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None,
     ) -> None:
-        self.memory_config = memory_config
-        self.placement_constraints = placement_constraints
-        self.additional_constraint_gen_passes = additional_constraint_gen_passes
+        self.memory_config: MemoryConfig = memory_config
+        self.placement_constraints: MemConstraints = placement_constraints
+        self.additional_constraint_gen_passes: Optional[
+            Sequence[ConstraintsGenPass]
+        ] = additional_constraint_gen_passes
 
     def get_num_memories(self) -> int:
         """Get num memories indexed from 1..N, compatible with EXIR's spec.mem_id."""
@@ -102,10 +97,14 @@ def populate_constraints(self, graph_module: torch.fx.GraphModule) -> None:
         )(graph_module)
 
     def is_valid_placement(self, spec: TensorSpec) -> bool:
-        return get_aligned_offset(
+        """Returns true if the spec can be placed at the given memory id."""
+        end_of_allocation = get_aligned_offset(
             spec.mem_offset + spec.allocated_memory,
             self.get_alignment(spec.mem_id),
-        ) <= self.get_size(spec.mem_id)
+        )
+        return end_of_allocation <= self.get_size(
+            spec.mem_id
+        ) and not self.placement_constraints.is_mem_id_in_blocklist(spec, spec.mem_id)
 
     @abstractmethod
     def plan(
@@ -133,10 +132,7 @@ def __call__(
         # First plan the memory allocation for specs without relative constraints.
         specs_without_relative_constraints = set(
             filter(
-                lambda spec: not self.placement_constraints.skipped_spec(spec)
-                and not self.placement_constraints.is_mem_id_in_blocklist(
-                    spec, spec.mem_id
-                ),
+                lambda spec: not self.placement_constraints.skipped_spec(spec),
                 specs,
             )
         )
diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py
index 73b0cba65ce..df44ded8516 100644
--- a/backends/cadence/aot/tests/test_memory_passes.py
+++ b/backends/cadence/aot/tests/test_memory_passes.py
@@ -8,17 +8,22 @@
 import math
 import unittest
-from typing import cast, List, Optional
+from typing import cast, List, Optional, Sequence
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
+from executorch.backends.cadence.aot.memory_constraints import ConstraintsGenPass, MemConstraints
 from executorch.backends.cadence.aot.memory_planning import (
     CadenceMemoryPlanning,
     find_peak_memory_usage,
 )
-from executorch.backends.cadence.aot.pass_utils import count_node
+from executorch.backends.cadence.aot.pass_utils import (
+    CadencePassAttribute,
+    count_node,
+    register_cadence_pass,
+)
 from executorch.backends.cadence.aot.typing_stubs import expand
 from executorch.backends.cadence.aot.utils import (
     get_default_memory_config,
 )
@@ -26,8 +31,10 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.memory_planning import collect_specs_from_nodes
+from executorch.exir.pass_base import PassBase, PassResult
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from executorch.exir.tests.models import MultiLayerPerceptron
+from parameterized import parameterized
 from torch.fx import GraphModule
 
@@ -230,6 +237,7 @@ def run_memory_planning(
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
         memory_config: Optional[MemoryConfig] = None,
+        additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None,
     ) -> GraphModule:
         if memory_config is None:
             memory_config = get_default_memory_config()
@@ -240,6 +248,7 @@
             mem_algo=mem_algo,
             alloc_graph_input=alloc_graph_input,
             alloc_graph_output=alloc_graph_output,
+            additional_constraint_gen_passes=additional_constraint_gen_passes,
         )(graph_module).graph_module
 
     @expand(
@@ -984,3 +993,67 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         ):
             if spec and spec.mem_offset:
                 self.assertEqual(spec.mem_offset % 37, 0)
+
+    @parameterized.expand([0, 1])
+    def test_block_mem_id(self, mem_algo: int) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(16))
+        add = builder.call_operator(
+            op=torch.ops.aten.add.Scalar,
+            args=(x, 2.0),
+        )
+        mul = builder.call_operator(
+            op=torch.ops.aten.mul.Scalar,
+            args=(add, 2.0),
+        )
+        builder.output([mul])
+        original = builder.get_graph_module()
+
+        dummy_memory_config = MemoryConfig([1024, 1024, 1024, 1024])
+
+        add_scalar_block_mem_ids = [2, 3]
+        mul_scalar_block_mem_ids = [1, 3]
+
+        @register_cadence_pass(CadencePassAttribute(opt_level=0))
+        class DummyMemIdBlockConstraintGen(PassBase):
+            """Blocks placement based on op type.
+            add: blocks 2, 3
+            mul: blocks 1, 3
+
+            """
+
+            def __init__(self, memory_constraints: MemConstraints):
+                self.memory_constraints = memory_constraints
+
+            def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+                for node in graph_module.graph.find_nodes(
+                    op="call_function", target=torch.ops.aten.add.Scalar
+                ):
+                    spec = node.meta["spec"]
+                    for mem_id in add_scalar_block_mem_ids:
+                        self.memory_constraints.add_mem_id_to_blocklist(spec, mem_id)
+                for node in graph_module.graph.find_nodes(
+                    op="call_function", target=torch.ops.aten.mul.Scalar
+                ):
+                    spec = node.meta["spec"]
+                    for mem_id in mul_scalar_block_mem_ids:
+                        self.memory_constraints.add_mem_id_to_blocklist(spec, mem_id)
+
+        graph_module = self.run_memory_planning(
+            original,
+            mem_algo=mem_algo,
+            memory_config=dummy_memory_config,
+            additional_constraint_gen_passes=[DummyMemIdBlockConstraintGen],
+        )
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.add.Scalar
+        ):
+            spec = node.meta["spec"]
+            self.assertIsNotNone(spec.mem_id)
+            self.assertNotIn(spec.mem_id, add_scalar_block_mem_ids)
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.mul.Scalar
+        ):
+            spec = node.meta["spec"]
+            self.assertIsNotNone(spec.mem_id)
+            self.assertNotIn(spec.mem_id, mul_scalar_block_mem_ids)
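
The core of the fix, shown in isolation: blocklist enforcement moves out of the
spec filter and into is_valid_placement, so a spec with blocked mem IDs is still
planned and simply falls through to the next memory. The toy planner below is a
minimal standalone sketch of that behavior; Spec, plan, is_valid_placement and
mem_sizes here are illustrative stand-ins rather than the executorch API, and
alignment plus the real constraint-generation passes are omitted.

# Toy first-fit planner sketching the fixed behavior. All names here
# (Spec, plan, is_valid_placement, mem_sizes) are hypothetical stand-ins,
# not the executorch API; alignment and constraint passes are omitted.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set


@dataclass
class Spec:
    size: int
    blocked_mem_ids: Set[int] = field(default_factory=set)
    mem_id: Optional[int] = None
    mem_offset: Optional[int] = None


def is_valid_placement(spec: Spec, mem_sizes: Dict[int, int], offset: int) -> bool:
    # A placement is valid only if it fits in the memory *and* that memory is
    # not block-listed for this spec (the check the patch adds).
    return (
        offset + spec.size <= mem_sizes[spec.mem_id]
        and spec.mem_id not in spec.blocked_mem_ids
    )


def plan(specs: List[Spec], mem_sizes: Dict[int, int]) -> None:
    # Next free offset per memory; memories are indexed 1..N like spec.mem_id.
    used: Dict[int, int] = {mem_id: 0 for mem_id in mem_sizes}
    for spec in specs:
        for mem_id in sorted(mem_sizes):
            spec.mem_id = mem_id
            if is_valid_placement(spec, mem_sizes, used[mem_id]):
                spec.mem_offset = used[mem_id]
                used[mem_id] += spec.size
                break
            # Blocked or full: reset and try the next memory instead of
            # dropping the spec from planning (the old behavior).
            spec.mem_id = None


if __name__ == "__main__":
    specs = [Spec(size=64, blocked_mem_ids={1}), Spec(size=64)]
    plan(specs, mem_sizes={1: 1024, 2: 1024})
    assert specs[0].mem_id == 2  # blocked from memory 1, placed in memory 2
    assert specs[1].mem_id == 1

Running the sketch places the first spec in memory 2 because memory 1 is
block-listed for it, which mirrors what test_block_mem_id asserts for the
add and mul ops above.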