30 changes: 21 additions & 9 deletions backends/cadence/aot/memory_constraints.py
@@ -11,7 +11,7 @@
import typing
from collections import defaultdict
from dataclasses import dataclass
from typing import cast, DefaultDict, Iterable, Optional, Sequence
from typing import Callable, cast, DefaultDict, Iterable, Optional, Sequence, TypeAlias

import torch
import torch.fx
@@ -573,23 +573,34 @@ def compute_slice_and_select_loc_constraints(
graph_module.recompile()


ConstraintsGenPass: TypeAlias = Callable[
[MemConstraints],
Callable[[torch.fx.GraphModule], Optional[PassResult]],
]


# The class to generate all the constraints that will be passed on to the memory
# planning algorithm.
class GenerateMemConstraints:
def __init__(
self,
mem_constraints: MemConstraints,
additional_constraint_gen_passes: list | None = None,
additional_constraint_gen_passes: Sequence[ConstraintsGenPass] | None = None,
) -> None:
self.mem_constraints = mem_constraints
self.additional_constraint_gen_passes = additional_constraint_gen_passes or []
self.mem_constraints: MemConstraints = mem_constraints
self.additional_constraint_gen_passes: Sequence[ConstraintsGenPass] = (
additional_constraint_gen_passes or []
)

def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
constraint_gen_passes: list = [
GenerateMemoryViewConstraints,
GenerateSliceAndSelectNopConstraints,
GenerateCatNopConstraints,
] + self.additional_constraint_gen_passes
constraint_gen_passes: Sequence[ConstraintsGenPass] = cast(
list[ConstraintsGenPass],
[
GenerateMemoryViewConstraints,
GenerateSliceAndSelectNopConstraints,
GenerateCatNopConstraints,
],
) + list(self.additional_constraint_gen_passes)
# Create a filter using the opt level in mem_constraints, and filter
# the relevant passes.
pass_filter = create_cadence_pass_filter(self.mem_constraints.opt_level)
@@ -602,6 +613,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
typing.Callable[[torch.fx.GraphModule], Optional[PassResult]],
]
],
# pyre-ignore[6]: Incompatible parameter type.
list(filter(pass_filter, constraint_gen_passes)),
)
]
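
For readers extending this, here is a minimal sketch of an additional constraint-generation pass that conforms to the new ConstraintsGenPass alias. The class name and the 1 KiB threshold are hypothetical; register_cadence_pass, CadencePassAttribute, and add_mem_id_to_blocklist are the same APIs exercised by the test added in this PR.

import torch
import torch.fx

from executorch.backends.cadence.aot.memory_constraints import MemConstraints
from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
)
from executorch.exir.pass_base import PassBase, PassResult


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class BlockMemId1ForLargeSpecs(PassBase):
    """Hypothetical example pass: keep any spec larger than 1 KiB out of mem_id 1."""

    def __init__(self, memory_constraints: MemConstraints) -> None:
        self.memory_constraints = memory_constraints

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Block mem_id 1 for every large tensor produced by a call_function node.
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            spec = node.meta.get("spec")
            if spec is not None and spec.allocated_memory > 1024:
                self.memory_constraints.add_mem_id_to_blocklist(spec, 1)
        return PassResult(graph_module, False)

Because the class takes a MemConstraints and its instances are callable on a torch.fx.GraphModule, it matches Callable[[MemConstraints], Callable[[torch.fx.GraphModule], Optional[PassResult]]] and can be handed to GenerateMemConstraints via additional_constraint_gen_passes.
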
20 changes: 7 additions & 13 deletions backends/cadence/aot/memory_planning.py
@@ -9,11 +9,12 @@
import collections
import itertools
import logging
from typing import Callable, Iterable, List, Optional, Set, Tuple, TypeAlias
from typing import Iterable, List, Optional, Sequence, Set, Tuple

import torch
from executorch.backends.cadence.aot.memory_constraints import MemConstraints
from executorch.backends.cadence.aot.memory_planning_algo import (
ConstraintsGenPass,
get_aligned_offset,
MemoryPlanningAlgo,
MemoryPlanningState,
@@ -126,10 +127,9 @@ def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None:
prev_offset,
)
if spec.mem_offset is None:
if get_aligned_offset(
prev_offset + spec.allocated_memory,
self.get_alignment(spec.mem_id),
) > self.get_size(spec.mem_id):
spec.mem_offset = prev_offset
if not self.is_valid_placement(spec):
spec.mem_offset = None
continue
else:
spec.mem_offset = prev_offset
@@ -344,12 +344,6 @@ def print_memory_planning_info(
)


ConstraintGenPassType: TypeAlias = Callable[
[MemConstraints],
Callable[[torch.fx.GraphModule], Optional[PassResult]],
]


class CadenceMemoryPlanning:
def __init__(
self,
@@ -358,7 +352,7 @@ def __init__(
mem_algo: int,
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]] = None,
additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None,
) -> None:
self.memory_config = memory_config
self.opt_level = opt_level
@@ -379,7 +373,7 @@ def get_mem_algos(
opt_level: int,
alloc_graph_input: bool,
alloc_graph_output: bool,
additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]],
additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]],
) -> list[MemoryPlanningAlgo]:
return [
PositionBasedGreedyWithHierarchy(
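
A brief usage sketch of how such a pass is threaded through memory planning. The keyword names memory_config and opt_level are assumed from the constructor body shown above; the concrete values, the graph_module variable, and BlockMemId1ForLargeSpecs (from the sketch after the previous file) are illustrative.

from executorch.backends.cadence.aot.memory_planning import CadenceMemoryPlanning
from executorch.backends.cadence.aot.utils import get_default_memory_config

planner = CadenceMemoryPlanning(
    memory_config=get_default_memory_config(),
    opt_level=1,
    mem_algo=0,  # selects one of the algorithms returned by get_mem_algos()
    alloc_graph_input=True,
    alloc_graph_output=True,
    additional_constraint_gen_passes=[BlockMemId1ForLargeSpecs],
)
planned_gm = planner(graph_module).graph_module
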
34 changes: 15 additions & 19 deletions backends/cadence/aot/memory_planning_algo.py
@@ -5,16 +5,16 @@
import logging
import math
from abc import ABC, abstractmethod
from typing import Callable, Optional
from typing import Optional, Sequence

import torch
from executorch.backends.cadence.aot.memory_constraints import (
ConstraintsGenPass,
GenerateMemConstraints,
MemConstraints,
)
from executorch.backends.cadence.aot.utils import MemoryConfig
from executorch.exir.memory_planning import Verifier
from executorch.exir.pass_base import PassResult
from executorch.exir.tensor import TensorSpec
from torch.export.exported_program import ExportGraphSignature

@@ -68,18 +68,13 @@ def __init__(
self,
memory_config: MemoryConfig,
placement_constraints: MemConstraints,
additional_constraint_gen_passes: Optional[
list[
Callable[
[MemConstraints],
Callable[[torch.fx.GraphModule], Optional[PassResult]],
]
]
] = None,
additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None,
) -> None:
self.memory_config = memory_config
self.placement_constraints = placement_constraints
self.additional_constraint_gen_passes = additional_constraint_gen_passes
self.memory_config: MemoryConfig = memory_config
self.placement_constraints: MemConstraints = placement_constraints
self.additional_constraint_gen_passes: Optional[
Sequence[ConstraintsGenPass]
] = additional_constraint_gen_passes

def get_num_memories(self) -> int:
"""Get num memories indexed from 1..N, compatible with EXIR's spec.mem_id."""
@@ -102,10 +97,14 @@ def populate_constraints(self, graph_module: torch.fx.GraphModule) -> None:
)(graph_module)

def is_valid_placement(self, spec: TensorSpec) -> bool:
return get_aligned_offset(
"""Returns true if the spec can be placed at the given memory id."""
end_of_allocation = get_aligned_offset(
spec.mem_offset + spec.allocated_memory,
self.get_alignment(spec.mem_id),
) <= self.get_size(spec.mem_id)
)
return end_of_allocation <= self.get_size(
spec.mem_id
) and not self.placement_constraints.is_mem_id_in_blocklist(spec, spec.mem_id)

@abstractmethod
def plan(
@@ -133,10 +132,7 @@ def __call__(
# First plan the memory allocation for specs without relative constraints.
specs_without_relative_constraints = set(
filter(
lambda spec: not self.placement_constraints.skipped_spec(spec)
and not self.placement_constraints.is_mem_id_in_blocklist(
spec, spec.mem_id
),
lambda spec: not self.placement_constraints.skipped_spec(spec),
specs,
)
)
77 changes: 75 additions & 2 deletions backends/cadence/aot/tests/test_memory_passes.py
@@ -8,26 +8,33 @@

import math
import unittest
from typing import cast, List, Optional
from typing import cast, List, Optional, Sequence

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.memory_constraints import ConstraintsGenPass
from executorch.backends.cadence.aot.memory_planning import (
CadenceMemoryPlanning,
find_peak_memory_usage,
)
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
count_node,
register_cadence_pass,
)
from executorch.backends.cadence.aot.typing_stubs import expand
from executorch.backends.cadence.aot.utils import (
get_default_memory_config,
MemoryConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.memory_planning import collect_specs_from_nodes
from executorch.exir.pass_base import PassBase, PassResult
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.tests.models import MultiLayerPerceptron
from parameterized import parameterized
from torch.fx import GraphModule


@@ -230,6 +237,7 @@ def run_memory_planning(
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
memory_config: Optional[MemoryConfig] = None,
additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None,
) -> GraphModule:
if memory_config is None:
memory_config = get_default_memory_config()
@@ -240,6 +248,7 @@
mem_algo=mem_algo,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=alloc_graph_output,
additional_constraint_gen_passes=additional_constraint_gen_passes,
)(graph_module).graph_module

@expand(
@@ -984,3 +993,67 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
):
if spec and spec.mem_offset:
self.assertEqual(spec.mem_offset % 37, 0)

@parameterized.expand([0, 1])
def test_block_mem_id(self, mem_algo: int) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(16))
add = builder.call_operator(
op=torch.ops.aten.add.Scalar,
args=(x, 2.0),
)
mul = builder.call_operator(
op=torch.ops.aten.mul.Scalar,
args=(add, 2.0),
)
builder.output([mul])
original = builder.get_graph_module()

dummy_memory_config = MemoryConfig([1024, 1024, 1024, 1024])

add_scalar_block_mem_ids = [2, 3]
mul_scalar_block_mem_ids = [1, 3]

@register_cadence_pass(CadencePassAttribute(opt_level=0))
class DummyMemIdBlockConstraintGen(PassBase):
"""Blocks placement based on op type.
add: blocks 2, 3
mul: blocks 1, 3
"""

def __init__(self, memory_constraints: MemoryConfig):
self.memory_constraints = memory_constraints

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten.add.Scalar
):
spec = node.meta["spec"]
for mem_id in add_scalar_block_mem_ids:
self.memory_constraints.add_mem_id_to_blocklist(spec, mem_id)
for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten.mul.Scalar
):
spec = node.meta["spec"]
for mem_id in mul_scalar_block_mem_ids:
self.memory_constraints.add_mem_id_to_blocklist(spec, mem_id)

graph_module = self.run_memory_planning(
original,
mem_algo=mem_algo,
memory_config=dummy_memory_config,
additional_constraint_gen_passes=[DummyMemIdBlockConstraintGen],
)
for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten.add.Scalar
):
spec = node.meta["spec"]
self.assertIsNotNone(spec.mem_id)
self.assertNotIn(spec.mem_id, add_scalar_block_mem_ids)
for node in graph_module.graph.find_nodes(
op="call_function", target=torch.ops.aten.mul.Scalar
):
spec = node.meta["spec"]
self.assertIsNotNone(spec.mem_id)
self.assertNotIn(spec.mem_id, mul_scalar_block_mem_ids)