From 6cfaae83ccacc5d7403a75505f4364b22d347784 Mon Sep 17 00:00:00 2001 From: Eashan Garg Date: Mon, 18 Nov 2024 10:14:59 -0800 Subject: [PATCH] Port memory planning to Cadence (#6716) Summary: Porting memory planning over to Cadence OSS Reviewed By: zonglinpeng, hsharma35 Differential Revision: D64406681 --- backends/cadence/aot/memory_planning.py | 365 ++++++++++++++++++++++++ backends/cadence/aot/utils.py | 28 +- 2 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 backends/cadence/aot/memory_planning.py diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py new file mode 100644 index 00000000000..fa5204723cd --- /dev/null +++ b/backends/cadence/aot/memory_planning.py @@ -0,0 +1,365 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import collections +import itertools +import logging +from functools import partial +from typing import Iterable, List, Optional, Tuple + +import torch +from executorch.backends.cadence.aot.utils import MemoryConfig + +from executorch.exir import ExecutorchProgramManager +from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier +from executorch.exir.passes import MemoryPlanningPass +from executorch.exir.tensor import TensorSpec +from tabulate import tabulate +from torch.export.exported_program import ExportGraphSignature +from torch.fx.passes.infra.pass_base import PassResult + + +# get num memories indexed from 1..N, compatible with EXIR's spec.mem_id +def get_num_memories(memory_config: MemoryConfig) -> int: + return len(memory_config.memory_sizes) + 1 + + +# memory_space module provides num_memories indexed 0..num_memories-1. +def get_size(memory_config: MemoryConfig, exir_id: int) -> int: + return memory_config.memory_sizes[exir_id - 1] + + +def collect_specs_from_graph_module( + graph_module: torch.fx.GraphModule, + alloc_graph_input: bool, + alloc_graph_output: bool, +) -> Iterable[TensorSpec]: + """ + Return the specs for all the nodes in the graph module in + topological order. 
+    """
+    # Collect the specs from all the nodes in the graph module and return them
+    return collect_specs_from_nodes(
+        graph_module.graph.nodes,
+        ignore_graph_input=not alloc_graph_input,
+        ignore_graph_output=not alloc_graph_output,
+    )
+
+
+# Baseline tensor placement algorithm that greedily tries to place each tensor in
+# the fastest memory available.
+def position_based_greedy_with_hierarchy(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: ExportGraphSignature,
+    alloc_graph_input: bool,
+    alloc_graph_output: bool,
+    *,
+    memory_config: MemoryConfig,
+) -> List[int]:
+    num_memories = get_num_memories(memory_config)
+    bufsizes = [0] * num_memories
+    allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)]
+
+    def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
+        for allocated_spec in allocated_buffers[spec.mem_id]:
+            if Verifier.lifetime_overlap(
+                spec, allocated_spec
+            ) and Verifier.storage_overlap(spec, allocated_spec):
+                return allocated_spec
+        return None
+
+    def memory_available(spec: TensorSpec) -> bool:
+        return spec.mem_offset + spec.allocated_memory <= get_size(
+            memory_config, spec.mem_id
+        )
+
+    # Iterate over all the specs in sorted order
+    for spec in sorted(
+        collect_specs_from_graph_module(
+            graph_module, alloc_graph_input, alloc_graph_output
+        ),
+        key=lambda spec: spec.allocated_memory,
+        reverse=True,
+    ):
+        for spec.mem_id in range(1, num_memories):
+            spec.mem_offset = 0
+            while memory_available(spec) and (overlapped := overlap(spec)):
+                spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
+            if memory_available(spec):
+                allocated_buffers[spec.mem_id].append(spec)
+                bufsizes[spec.mem_id] = max(
+                    spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
+                )
+                break
+        if (
+            not allocated_buffers[spec.mem_id]
+            or allocated_buffers[spec.mem_id][-1] is not spec
+        ):
+            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")
+
+    logging.debug(
+        f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}"
+    )
+    return bufsizes
+
+
+# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf
+def greedy_by_size_for_offset_calculation_with_hierarchy(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: ExportGraphSignature,
+    alloc_graph_input: bool,
+    alloc_graph_output: bool,
+    *,
+    memory_config: MemoryConfig,
+) -> List[int]:
+    num_memories = get_num_memories(memory_config)
+    bufsizes = [0] * num_memories
+    allocated_buffers = [[] for _ in range(num_memories)]
+
+    # Iterate over all the specs in sorted order
+    for spec in sorted(
+        collect_specs_from_graph_module(
+            graph_module, alloc_graph_input, alloc_graph_output
+        ),
+        key=lambda spec: spec.allocated_memory,
+        reverse=True,
+    ):
+        for spec.mem_id in range(1, num_memories):
+            prev_offset, smallest_gap = 0, float("inf")
+            for allocated_spec in allocated_buffers[spec.mem_id]:
+                if Verifier.lifetime_overlap(spec, allocated_spec):
+                    if (
+                        gap := allocated_spec.mem_offset - prev_offset
+                    ) >= spec.allocated_memory and gap < smallest_gap:
+                        smallest_gap = gap
+                        spec.mem_offset = prev_offset
+                    # Note that, unlike the paper, which updates prev_offset for all
+                    # allocated tensors, we only update prev_offset for tensors with an
+                    # overlapping lifetime. Updating prev_offset outside the if statement
+                    # would include tensors without overlapping lifetimes, wasting memory
+                    # and making the gap calculation incorrect.
+                    # Moving it out will make the algorithm degenerate to the
+                    # naive one, reusing no tensors. The paper may have a typo here.
+                    prev_offset = max(
+                        allocated_spec.mem_offset + allocated_spec.allocated_memory,
+                        prev_offset,
+                    )
+            if spec.mem_offset is None:
+                if prev_offset + spec.allocated_memory > get_size(
+                    memory_config, spec.mem_id
+                ):
+                    continue
+                else:
+                    spec.mem_offset = prev_offset
+            bufsizes[spec.mem_id] = max(
+                spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id]
+            )
+            allocated_buffers[spec.mem_id].append(spec)
+            allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset)
+            # allocated_buffers[spec.mem_id] is kept sorted by offset; it plays the
+            # role of ordered_allocated_ids in the paper
+            break
+        if spec not in allocated_buffers[spec.mem_id]:
+            raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")
+
+    logging.debug(
+        f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}"
+    )
+    return bufsizes
+
+
+def find_peak_memory_usages_per_memory(
+    graph_module: torch.fx.GraphModule,
+    alloc_graph_input: bool,
+    alloc_graph_output: bool,
+) -> List[int]:
+    """
+    Given a GraphModule with a memory plan, find the peak memory usage for each memory
+    in the memory hierarchy.
+    """
+    # Create a defaultdict to keep track of memory usages: {mem_id: mem_usage}
+    # Use a defaultdict here because we don't know in advance how many unique
+    # mem_ids the memory plan uses.
+    usages = collections.defaultdict(int)
+
+    # Go through all the nodes in the graph and collect the memory usage per spec.mem_id
+    for spec in collect_specs_from_graph_module(
+        graph_module, alloc_graph_input, alloc_graph_output
+    ):
+        usages[spec.mem_id] = max(
+            usages[spec.mem_id], spec.mem_offset + spec.allocated_memory
+        )
+
+    # Convert the usages dictionary into a list whose i-th entry is the usage of
+    # mem_id i + 1. Ex: {1: 20, 3: 30} -> [20, 0, 30] (mem_ids 1, 2, 3).
+    max_mem_id = max(usages.keys(), default=0)
+    usages = [usages[i] for i in range(1, max_mem_id + 1)]
+
+    return usages
+
+
+def find_peak_memory_usage(
+    graph_module: torch.fx.GraphModule,
+    alloc_graph_input: bool,
+    alloc_graph_output: bool,
+) -> Tuple[int, int]:
+    """
+    Given a GraphModule with a memory plan, find the peak usage over time across all
+    memories in the memory hierarchy. The resulting peak memory usage should be:
+    1. >= min(find_peak_memory_usages_per_memory(graph_module))
+    2. <= sum(find_peak_memory_usages_per_memory(graph_module))
+    """
+    # memory allocations over time (measured in node index)
+    byte_allocated = [0] * (len(graph_module.graph.nodes) + 1)
+
+    # Iterate over all the node specs
+    for spec in collect_specs_from_graph_module(
+        graph_module, alloc_graph_input, alloc_graph_output
+    ):
+        if spec.lifetime[0] is None:
+            continue
+
+        # lifetime is [start, end], both ends inclusive
+        start, end = spec.lifetime
+        byte_allocated[start] += spec.allocated_memory
+        byte_allocated[end + 1] -= spec.allocated_memory
+
+    # accumulate the bytes allocated/deallocated to get memory usages
+    memory_usages = list(itertools.accumulate(byte_allocated))
+
+    # find the peak memory usage and the node index at which it occurs
+    peak_memory_usage = max(memory_usages, default=0)
+    peak_memory_usage_node_idx = (
+        memory_usages.index(peak_memory_usage) if memory_usages else 0
+    )
+
+    return peak_memory_usage, peak_memory_usage_node_idx
+
+
+# Print two tables with relevant memory planning information
+#
+# Per Memory Space Usage Table:
+# +--------------------------------------+----------------+-----------------------+-----------------------------+
+# | Memory Space                         |   Base Address |   Memory Size (Bytes) |   Peak Memory Usage (Bytes) |
+# +======================================+================+=======================+=============================+
+# | MEMORY SPACE A                       |     0x57be0000 |                 65213 |                       64544 |
+# | MEMORY SPACE B                       |     0x57bf0000 |                 65521 |                       36864 |
+# | MEMORY SPACE ...                     |            ... |                   ... |                         ... |
+# +--------------------------------------+----------------+-----------------------+-----------------------------+
+#
+# Total Memory Space Usage Table:
+# +-------------------------------------+---------------+---------+
+# | Peak memory usage across all spaces | 2380032 bytes | Node 86 |
+# +-------------------------------------+---------------+---------+
+def print_memory_planning_info(
+    # pyre-fixme[11]: Annotation `ExecutorchProgramManager` is not defined as a type.
+ executorch_prog: ExecutorchProgramManager, + memory_config: MemoryConfig, + alloc_graph_input: bool, + alloc_graph_output: bool, +) -> None: + # Get the peak memory usages per memory space + peak_memory_usages_per_memory = find_peak_memory_usages_per_memory( + executorch_prog.exported_program().graph_module, + alloc_graph_input, + alloc_graph_output, + ) + + # Create a table of memory spaces and their base addresses, total memory sizes, and peak memory usage + memory_names, base_addrs = memory_config.memory_names, memory_config.base_addrs + memory_usage_table = [ + [ + f"{(i + 1) if memory_names is None else memory_names[i]}", + None if base_addrs is None else hex(base_addrs[i]), + memory_config.memory_sizes[i], + peak_memory_usages_per_memory[i], + ] + for i in range(len(peak_memory_usages_per_memory)) + ] + + # Print the memory usage per memory space as a table + logging.info( + tabulate( + memory_usage_table, + headers=[ + "Memory Space", + "Base Address", + "Memory Size (Bytes)", + "Peak Memory Usage (Bytes)", + ], + tablefmt="outline", + ) + ) + + # Get the total peak memory usage across all memory spaces + total_peak_memory_usage = find_peak_memory_usage( + executorch_prog.exported_program().graph_module, + alloc_graph_input, + alloc_graph_output, + ) + + # Create a table with total peak memory usage and node at which this occurs + total_memory_usage_table = [ + [ + "Peak memory usage across all spaces", + f"{total_peak_memory_usage[0]} bytes", + f"Node {total_peak_memory_usage[1]}", + ] + ] + + # Print the total memory usage as a table + logging.info( + tabulate( + total_memory_usage_table, + tablefmt="outline", + ) + ) + + +class CadenceMemoryPlanning: + def __init__( + self, + memory_config: MemoryConfig, + mem_algo: int, + alloc_graph_input: bool = True, + alloc_graph_output: bool = True, + ) -> None: + self._init_mem_algos() + + self.memory_config = memory_config + self.mem_algo = mem_algo + self.alloc_graph_input = alloc_graph_input + self.alloc_graph_output = alloc_graph_output + + def _init_mem_algos(self) -> None: + self.available_mem_algos = [ + position_based_greedy_with_hierarchy, + greedy_by_size_for_offset_calculation_with_hierarchy, + ] + + def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult: + algo = partial( + self.available_mem_algos[self.mem_algo], + memory_config=self.memory_config, + ) + # Create the memory planning pass. We allocate memory for input + # (output) tensors if alloc_graph_input (alloc_graph_output) is + # True. 
+ mem_planning = MemoryPlanningPass( + algo, + allow_lifetime_and_storage_overlap=False, + alloc_graph_input=self.alloc_graph_input, + alloc_graph_output=self.alloc_graph_output, + ) + mem_planning(graph_module) + + return PassResult(graph_module, True) diff --git a/backends/cadence/aot/utils.py b/backends/cadence/aot/utils.py index 12eb899d9d8..d0d77bbfb60 100644 --- a/backends/cadence/aot/utils.py +++ b/backends/cadence/aot/utils.py @@ -6,10 +6,12 @@ # pyre-strict +import enum import logging import operator import os -from typing import Dict, List, Tuple +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple import torch @@ -227,3 +229,27 @@ def save_bpte_program( logging.info(f"Saved exported program to {filename}") except Exception as e: logging.error(f"Error while saving to {output_dir}: {e}") + + +@dataclass +class MemoryConfig: + memory_sizes: List[int] + + # Optional fields for logs + memory_names: Optional[List[str]] = None + base_addrs: Optional[List[int]] = None + memory_xml_path: Optional[str] = None + MemorySpace: Optional[enum.Enum] = None + + # get num memories indexed from 1..N, compatible with EXIR's spec.mem_id + def get_num_memories(self) -> int: + return len(self.memory_sizes) + 1 + + # memory_space module provides num_memories indexed 0..num_memories-1. + def get_size(self, exir_id: int) -> int: + return self.memory_sizes[exir_id - 1] + + +# Return default memory config for the backend +def get_default_memory_config() -> MemoryConfig: + return MemoryConfig(memory_sizes=[0x1000000000])
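
A minimal sketch of how the new pieces might be wired together. It only uses the APIs added by this patch (MemoryConfig, CadenceMemoryPlanning); graph_module is assumed to come from an earlier export/to_edge step, and the memory sizes and names are placeholders rather than values for any real Cadence target.

    import torch

    from executorch.backends.cadence.aot.memory_planning import CadenceMemoryPlanning
    from executorch.backends.cadence.aot.utils import MemoryConfig


    def plan_memory(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Hypothetical two-level hierarchy: a small fast scratch memory and a
        # larger backing memory. Sizes and names are made-up placeholders.
        memory_config = MemoryConfig(
            memory_sizes=[0x10000, 0x1000000],
            memory_names=["fast scratch", "main memory"],
        )
        # mem_algo indexes CadenceMemoryPlanning.available_mem_algos:
        # 0 -> position_based_greedy_with_hierarchy,
        # 1 -> greedy_by_size_for_offset_calculation_with_hierarchy.
        planning = CadenceMemoryPlanning(memory_config, mem_algo=1)
        result = planning(graph_module)
        return result.graph_module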
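For reference, the gap-search step used by greedy_by_size_for_offset_calculation_with_hierarchy can be illustrated in isolation. The sketch below is not part of the diff: Block, lifetimes_overlap, and place are invented stand-ins for TensorSpec and the Verifier helpers, and offsets are planned within a single memory rather than a hierarchy.

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class Block:
        name: str
        size: int
        start: int  # first node index at which the tensor is live (inclusive)
        end: int    # last node index at which the tensor is live (inclusive)
        offset: Optional[int] = None


    def lifetimes_overlap(a: Block, b: Block) -> bool:
        return a.start <= b.end and b.start <= a.end


    def place(block: Block, allocated: List[Block], memory_size: int) -> bool:
        """Put `block` into the smallest gap between lifetime-overlapping, already
        placed blocks that fits it; otherwise append it after the last such block."""
        prev_offset, smallest_gap = 0, float("inf")
        block.offset = None
        for other in sorted(allocated, key=lambda b: b.offset):
            if lifetimes_overlap(block, other):
                gap = other.offset - prev_offset
                if gap >= block.size and gap < smallest_gap:
                    smallest_gap = gap
                    block.offset = prev_offset
                # Only blocks with overlapping lifetimes advance prev_offset,
                # mirroring the comment in the pass above.
                prev_offset = max(prev_offset, other.offset + other.size)
        if block.offset is None:
            if prev_offset + block.size > memory_size:
                return False  # does not fit in this memory
            block.offset = prev_offset
        allocated.append(block)
        return True


    allocated: List[Block] = []
    # Largest tensors first, as in the pass.
    for blk in sorted(
        [Block("a", 64, 0, 3), Block("b", 32, 2, 5), Block("c", 16, 5, 7)],
        key=lambda b: b.size,
        reverse=True,
    ):
        assert place(blk, allocated, memory_size=128), f"cannot fit {blk.name}"
        print(blk.name, "->", blk.offset)

Running the sketch prints a -> 0, b -> 64, c -> 0: block c reuses a's space because their lifetimes do not overlap, which is exactly the reuse the smallest-gap search is after.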