
Commit 55a9ea8

tarun292 authored and facebook-github-bot committed
Adding new memory planning algorithm heap_optimized_greedy
Summary: This diff adds a new memory planning algorithm called heap_optimized_greedy. It implements a memory allocation strategy that uses a greedy approach with a heap (priority queue) to manage memory blocks. The algorithm sorts the buffer specifications by size and then by start time, and tries to fit each buffer into an existing memory block whose lifetime does not overlap with it. If no suitable block is found, a new block is created. Heap entries have the format (end_time, max_size, block_intervals, block_id), so popping an entry always yields the block with the earliest end time.

This differs from the existing greedy memory planning algorithm: greedy mainly sorts the specs by size and then applies greedy planning directly (and also allows overlapping allocations), whereas heap_optimized_greedy also sorts by size but, when popping entries from the heap, always picks the block that ended the earliest.

Since support was added in the previous diff to iterate through multiple memory planning algorithms, we now run both greedy and heap_optimized_greedy for every model and pick the one that gives the best result. Some example results:

| Model | Greedy (current) | Heap optimized greedy |
| -- | -- | -- |
| Llama with spin quant, 2k context length | 697192000 | 659462496 |
| ASR encoder | 1834640 | 1361808 |
| Machine translation encoder | 203104 | 184864 |

Differential Revision: D70204407
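To make the strategy concrete, here is a minimal standalone sketch of the packing loop described above, with the ExecuTorch-specific spec handling (mem_ids, alignment, offset assignment) stripped out. The `plan_blocks` helper and `Buffer` alias are illustrative names only, not part of this diff; the actual implementation is the `heap_optimized_greedy` function in exir/memory_planning.py below.

```python
import heapq
from typing import List, Tuple

# A buffer is (size_in_bytes, start_time, end_time); times are node indices.
Buffer = Tuple[int, int, int]


def plan_blocks(buffers: List[Buffer]) -> int:
    """Pack buffers into reusable blocks and return the total planned size.

    Mirrors the strategy above: sort by (-size, start), keep a heap of blocks
    keyed by (end_time, block_size, intervals, block_id), and always pop the
    block that ended earliest when looking for a home for the next buffer.
    """
    ordered = sorted(buffers, key=lambda b: (-b[0], b[1]))
    heap: List[Tuple[int, int, List[Tuple[int, int]], int]] = []
    next_block_id = 0

    for size, start, end in ordered:
        assigned = False
        skipped = []
        while heap:
            block_end, block_size, intervals, block_id = heapq.heappop(heap)
            fits = block_size >= size and not any(
                s < end and start < e for (s, e) in intervals
            )
            if fits:
                intervals.append((start, end))
                heapq.heappush(heap, (max(block_end, end), block_size, intervals, block_id))
                assigned = True
                break
            skipped.append((block_end, block_size, intervals, block_id))
        # Put back any blocks we popped but could not use.
        for item in skipped:
            heapq.heappush(heap, item)
        if not assigned:
            heapq.heappush(heap, (end, size, [(start, end)], next_block_id))
            next_block_id += 1

    # Total memory is the sum of block sizes; the real implementation then
    # walks the blocks per mem_id to assign offsets.
    return sum(block_size for _, block_size, _, _ in heap)


# The 64- and 32-byte buffers have disjoint lifetimes and share one block;
# the 16-byte buffer overlaps both and gets its own block: 64 + 16 = 80.
print(plan_blocks([(64, 0, 2), (32, 3, 5), (16, 1, 4)]))
```

Popping from the heap always returns the block with the earliest end time, so a block that has just become free is the first candidate considered for reuse.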
1 parent fa54844 commit 55a9ea8

2 files changed: +138 −1 lines

exir/memory_planning.py

Lines changed: 133 additions & 1 deletion
@@ -8,6 +8,7 @@
 
 import collections
 import functools
+import heapq
 import itertools
 import logging
 import operator
@@ -707,6 +708,137 @@ def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool:
                 return True
     return False
 
+def heap_optimized_greedy(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True
+) -> MemoryAlgoResult:
+    """
+    This function implements a memory allocation strategy using a greedy approach
+    with a priority queue (heap) to manage memory blocks.
+    The algorithm processes each buffer specification by sorting them based on size and
+    start time. It attempts to fit each buffer into an existing memory block without
+    overlapping in time. If no suitable block is found, a new block is created.
+    The format of the entries in the heap is (end_time, max_size, block_intervals, block_id),
+    so whenever we pop an entry from the heap we're popping the block with the earliest end time.
+    """
+
+    greedy_result = MemoryAlgoResult({}, [])
+
+    extra_padded_bytes = 0
+    if _contains_xnnpack_delegate(graph_module):
+        extra_padded_bytes = 64
+
+    # Don't do assertion in collect_specs_from_nodes if we have already encountered
+    # and ignored some to_out_variant errors.
+    do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)
+
+    specs_list = collect_specs_from_nodes(
+        graph_module.graph.nodes,
+        graph_signature,
+        do_assertion=do_assertion,
+        ignore_graph_input=not alloc_graph_input,
+        ignore_graph_output=not alloc_graph_output,
+    )
+    # Sort based on the size of the spec, and then the starting lifetime of the spec if
+    # the size is same.
+    specs_list = sorted(specs_list, key=lambda x: (-x.allocated_memory, x.lifetime[0]))
+
+    # This is a dict where the key is the memory id and the value is the heap
+    # for that memory id.
+    # Format of priority queue entries: (end_time, max_size, block_intervals, block_id)
+    heap_dict = defaultdict(list)
+    # In this dict we store a mapping from memory id to the max block id that has been
+    # assigned to that memory id. This is used to assign unique block ids to each block
+    # in the heap.
+    block_ids = defaultdict(int)
+    spec_to_block_id = defaultdict(list)
+
+    for spec in specs_list:
+        spec.realign(alignment)
+        size, start, end = spec.allocated_memory, spec.lifetime[0], spec.lifetime[1]
+        assigned = False
+
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
+        if spec.mem_id is None:
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
+
+        # Get the heap for the memory id of the spec.
+        heap = heap_dict[spec_alloc_result.mem_id]
+
+        # Check the heap for compatible blocks
+        temp = []
+        while heap:
+            block_end, block_size, block_intervals, block_id = heapq.heappop(heap)
+            # Block can fit the buffer if:
+            # 1. Its max_size >= buffer size
+            # 2. No overlap with existing intervals
+            if block_size >= size and not any(
+                s < end and start < e for (s, e) in block_intervals
+            ):
+                # Add buffer to the block
+                block_intervals.append((start, end))
+                new_block_end = max(block_end, end)
+                heapq.heappush(heap, (new_block_end, block_size, block_intervals, block_id))
+                # Keep track of all the specs that are assigned to this block id.
+                spec_to_block_id[block_id] += [spec]
+                assigned = True
+                break
+            else:
+                # If the block is not compatible, add it to a temporary list so that
+                # we can restore it to the heap later.
+                temp.append((block_end, block_size, block_intervals, block_id))
+
+        # Restore popped blocks to the heap
+        for item in temp:
+            heapq.heappush(heap, item)
+
+        # Create a new block if no existing block fits
+        if not assigned:
+            # Get max block id assigned till now for this memory id.
+            block_id = block_ids.get(spec_alloc_result.mem_id, 0)
+            new_block = (end, size, [(start, end)], block_id)
+            # Add this spec to the list of specs assigned to this block id.
+            spec_to_block_id[block_id] += [spec]
+            # Increment the max block id assigned for this memory id.
+            block_ids[spec_alloc_result.mem_id] += 1
+            heapq.heappush(heap, new_block)
+
+    # Now that we have the heap for each memory id, we can assign offsets to each
+    # spec based on the heap.
+    # Format of priority queue entries: (end_time, max_size, block_intervals, block_id)
+    if len(heap_dict) == 0:
+        # Cannot find any tensor in the graph that needs to be allocated.
+        # Return [0, 0] to be consistent with default behavior of naive.
+        bufsize = [0, 0]
+    else:
+        bufsize = [0] * (max(heap_dict.keys()) + 1)
+        for mem_id, heap in heap_dict.items():
+            input_total_size = 0
+            total_size = 0
+            if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
+                # pyre-fixme[6]: For 1st argument expected
+                #  `pyre_extensions.ReadOnly[Sized]` but got `Union[Tensor, Module]`.
+                if len(bufsizes) > mem_id:
+                    # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.Ten...
+                    input_total_size = bufsizes[mem_id]
+            while heap:
+                block_end, block_size, block_intervals, block_id = heapq.heappop(heap)
+                spec_list = spec_to_block_id[block_id]
+                for spec in spec_list:
+                    spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                    assert spec_alloc_result is not None, f"Spec {spec} not found."
+                    spec_alloc_result.mem_offset = total_size
+                total_size += block_size
+            bufsize[mem_id] = input_total_size + total_size + extra_padded_bytes
+
+    greedy_result.bufsizes = bufsize
+    return greedy_result
 
 def greedy(
     graph_module: torch.fx.GraphModule,
@@ -818,7 +950,7 @@ def memory_planning_algorithm_suite(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-    algo_list: List[Callable[..., MemoryAlgoResult]] = [greedy],
+    algo_list: List[Callable[..., MemoryAlgoResult]] = [greedy, heap_optimized_greedy],
 ) -> List[int]:
     r"""
     Memory planning algorithm suite that runs a list of memory planning algorithms

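Since heap_optimized_greedy takes the same arguments as the other entries in algo_list, it can also be run on its own and compared against greedy. The sketch below is illustrative only: it assumes a graph_module and graph_signature obtained from an already-exported program, assumes greedy shares this calling convention (as their shared slot in algo_list suggests), assumes these names are importable from the module touched by this diff, and uses the sum of bufsizes as a stand-in for the selection logic that actually lives in memory_planning_algorithm_suite (added in the previous diff). pick_best_plan is a hypothetical helper, not part of this change.

```python
from typing import Tuple

import torch
from executorch.exir.memory_planning import (
    MemoryAlgoResult,
    greedy,
    heap_optimized_greedy,
)


def pick_best_plan(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature=None,
) -> Tuple[str, MemoryAlgoResult]:
    """Run both planners and keep the one with the smaller total planned size.

    Illustrative helper only; the real selection happens inside
    memory_planning_algorithm_suite and its criterion may differ.
    """
    results = {
        "greedy": greedy(graph_module, alignment, graph_signature=graph_signature),
        "heap_optimized_greedy": heap_optimized_greedy(
            graph_module, alignment, graph_signature=graph_signature
        ),
    }
    # Smaller total buffer size wins, mirroring the comparison in the table above.
    return min(results.items(), key=lambda kv: sum(kv[1].bufsizes))
```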
exir/tests/test_memory_planning.py

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,7 @@
     filter_nodes,
     get_node_tensor_specs,
     greedy,
+    heap_optimized_greedy,
    memory_planning_algorithm_suite,
    MemoryAlgoResult,
    naive,
@@ -254,6 +255,7 @@ def wrapper(self: "TestMemoryPlanning") -> None:
                 (naive, False),
                 # greedy algorithm should reuse tensor storages in the testing model
                 (greedy, True),
+                (heap_optimized_greedy, True),
             ]
 
         for algo, expect_reuse in criteria:
@@ -383,13 +385,15 @@ def verify_overlap_placeholders(
         criteria=[
             (naive, False),
             (greedy, True),
+            (heap_optimized_greedy, True)
         ],
     )
 
     test_linear_with_view: Callable[..., None] = maketest(
         LinearsWithDifferentSizeAndViewOps,
         criteria=[
             (greedy, True),
+            (heap_optimized_greedy, True)
         ],
     )
@@ -400,6 +404,7 @@ def verify_overlap_placeholders(
         criteria=[
             (naive, False),
             (greedy, True),
+            (heap_optimized_greedy, True)
         ],
         extra_check=ModuleListArg.extra_check,
     )
