Skip to content

Commit 69039f0

Browse files
committed
[ET][Memory planning] Improve greedy memory planning.
Pull Request resolved: #7926 This diff replaces the old greedy algorithm. Older algorithm resulted in 35% worse compared to theoretical optimum. THis matter for long context even more since additional overhead can be few hundred MB. For example the theorical optimial for llama3_2 8B, 4-bit quantized modelw ith context length of 2k needs about 1G of memory. This theoretcial max can be observed by looking at the peaks in memory profile. Current agorithm resulted in about 1.6GB of planned memory. New algorithm reduce that to about 1.1G. ghstack-source-id: 262945660 @exported-using-ghexport Differential Revision: [D68448332](https://our.internmc.facebook.com/intern/diff/D68448332/)
1 parent c7c4007 commit 69039f0

File tree

4 files changed

+226
-29
lines changed

4 files changed

+226
-29
lines changed

exir/memory_planning.py

Lines changed: 161 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import operator
1212
import typing
1313
from collections import defaultdict
14-
from dataclasses import dataclass
14+
from dataclasses import dataclass, field
1515
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
1616

1717
import torch
@@ -117,6 +117,17 @@ def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
117117

118118
return has_overlap
119119

120+
@classmethod
121+
def _debug_message_from_specs(
122+
cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec
123+
) -> str:
124+
message = (
125+
f"lhs life time: {lhs_spec.lifetime}, rhs lifetime: {rhs_spec.lifetime} "
126+
)
127+
message += f"lhs: mem_id {lhs_spec.mem_id} storage: {lhs_spec.mem_offset}, {lhs_spec.allocated_memory} "
128+
message += f"rhs: mem_id {rhs_spec.mem_id} storage: {rhs_spec.mem_offset}, {rhs_spec.allocated_memory}"
129+
return message
130+
120131
def verify_storage_reuse(
121132
self, allow_lifetime_and_storage_overlap: bool = False
122133
) -> int:
@@ -159,7 +170,7 @@ def verify_storage_reuse(
159170
lhs_spec, rhs_spec
160171
):
161172
raise InternalError(
162-
f"Unexpected storage overlap: lhs {lhs_spec}, rhs {rhs_spec}"
173+
f"Unexpected storage overlap: {Verifier._debug_message_from_specs(lhs_spec, rhs_spec)}"
163174
)
164175

165176
# Check that each mem_obj_id is consistent with whether the tensors have
@@ -454,6 +465,18 @@ def update_all_tensors_lifetime(
454465
return specs
455466

456467

468+
@dataclass
469+
class AllocationSpec:
470+
"""
471+
AllocationSpec is used to represent the allocation of a tensor.
472+
"""
473+
474+
# The offset of the tensor in the shared object/pool.
475+
offset: int
476+
# TensorSpec
477+
spec: TensorSpec
478+
479+
457480
@dataclass
458481
class SharedObject:
459482
r"""
@@ -470,8 +493,15 @@ class SharedObject:
470493
offset: int
471494
# size of this shared object in bytes
472495
size: int
496+
# When the object is first created
497+
first_used_index: int
473498
# the object will be available for index (last_used_index + 1)
474499
last_used_index: int
500+
# list of allocations belong to this shared object
501+
allocations: List[AllocationSpec] = field(default_factory=list)
502+
503+
def __repr__(self) -> str:
504+
return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
475505

476506

477507
def materialize_buffer(
@@ -489,35 +519,122 @@ def materialize_buffer(
489519
return total_size
490520

491521

492-
def _size_abs_dif(sobj: SharedObject, spec: TensorSpec) -> int:
522+
def _does_not_overlap(sobj: SharedObject, spec: TensorSpec) -> bool:
493523
r"""
494-
Calculate the absolute different between the size of a shared object and
495-
a tensor.
524+
Check if a shared object and a tensor do not overlap.
496525
"""
497-
return abs(sobj.size - spec.allocated_memory)
526+
for alloc in sobj.allocations:
527+
if not (
528+
spec.lifetime[1] < alloc.spec.lifetime[0]
529+
or spec.lifetime[0] > alloc.spec.lifetime[1]
530+
):
531+
return False
532+
return True
533+
534+
535+
def _find_max_overlapping_allocations_offset(
536+
sobj: SharedObject, spec: TensorSpec
537+
) -> int:
538+
max_offset = 0
539+
for alloc in sobj.allocations:
540+
if (
541+
spec.lifetime[1] < alloc.spec.lifetime[0]
542+
or spec.lifetime[0] > alloc.spec.lifetime[1]
543+
):
544+
continue
545+
max_offset = max(alloc.offset + alloc.spec.allocated_memory, max_offset)
546+
return max_offset
498547

499548

500549
def pick_shared_obj(
501550
shared_objects: List[SharedObject], spec: TensorSpec
502551
) -> SharedObject:
503552
r"""
504-
Pick the available shared object with closest size to the tensor.
505-
If there are no available shared object left, create a new one.
553+
Pick the available shared object to which to assign this spec,
554+
or create a new one
555+
Algorithm details
556+
Previous: Look at every spec in chronological order. Find if previously allocated object
557+
allows it to fit in. If not, allocate a new object.
558+
New:
559+
- Sort all the specs by allocation size
560+
- Process the specs in order
561+
- If the spec's size in smaller than previously allocated buckets:
562+
- Conditions under which previously allocated bucket can be used:
563+
- Lifetime of the spec does not overlap with lifetime of the bucket.
564+
- In this case allocate spec to that bucket and expand its lifetime.
565+
- Spec is allocated at offset = 0 in this bucket.
566+
- Add this spec to allocated object's list of specs.
567+
- Lifetime of the spec overlaps with lifetime of the bucket,
568+
partially or fully (e.g. spec's lifetime subset of bucket's lifetime)
569+
- If none of the specs in the bucket overlaps with spec's lifetime.
570+
- Allocate spec to the bucket at offset = 0.
571+
- Add this spec to the bucket's list of specs.
572+
- Expand bucket's lifetime accounting for added spec's lifetime.
573+
- If one or more specs in the bucket overlaps with spec's lifetime.
574+
- Collect offsets (at which the given overlapping spec is allocated in the bucket).
575+
of all the overlapping specs, and find the max offset.
576+
- Allocate spec to the bucket at offset = max_offset + max_offset_spec_size.
577+
- Add this spec to the bucket's list of specs.
578+
- Expand bucket's lifetime accounting for added spec's lifetime.
579+
- If none of these conditions are met, allocate a new bucket.
580+
- Add spec to this bucket.
581+
- Update bucket's lifetime to that of the spec.
582+
- If the spec's size is larger than previously allocated buckets, allocate a new bucket.
583+
- Size and lifetime of this bucket is that of the spec
584+
585+
Proof of correctness:
586+
- If allocating a new bucket, it is correct.
587+
- If allocating spec to an existing bucket, whose lifetime does not overlap with any
588+
of the previously allocated specs' lifetime, then the allocation is correct.
589+
Proof of correctness by induction when adding spec to an existing bucket:
590+
- If all previous allocations in the given bucket are correct:
591+
- Then the new one being added must be correct because when the requested allocation
592+
overlaps with one or more previous allocations, we find the largest offset among
593+
all the overlapping allocations, and allocate the new spec at that offset. Hence,
594+
the allocation at such an offset, will not overlap with any previous allocations.
595+
Base case: A newly added allocation within a bucket with single allocation is correct:
596+
because a) it must fit and b) its lifetime must not overlap with object's lifetime.
597+
This holds true because of the following invariants:
598+
- Once a bucket is created, it is never resized.
599+
- All the allocations within a bucket follow this:
600+
- Span, defined by allocation's offset + size, of two allocations can only overlap,
601+
if their timelines do not overlap.
506602
"""
507-
# TODO: do better than linear scan
508603
picked = None
509604
for sobj in shared_objects:
510-
if spec.lifetime[0] > sobj.last_used_index:
511-
if picked is None or _size_abs_dif(sobj, spec) < _size_abs_dif(
512-
picked, spec
513-
):
514-
picked = sobj
515-
sobj.last_used_index = spec.lifetime[1]
516-
sobj.size = max(sobj.size, spec.allocated_memory)
605+
if _does_not_overlap(sobj, spec):
606+
assert sobj.size >= spec.allocated_memory, "Allocation specs are not sorted"
607+
picked = sobj
608+
sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
609+
sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
610+
allocation_spec = AllocationSpec(0, spec)
611+
picked.allocations.append(allocation_spec)
612+
break
613+
614+
if picked is None:
615+
for sobj in shared_objects:
616+
max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
617+
if max_offset > 0:
618+
if max_offset + spec.allocated_memory <= sobj.size:
619+
picked = sobj
620+
sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
621+
sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
622+
allocation_spec = AllocationSpec(max_offset, spec)
623+
picked.allocations.append(allocation_spec)
624+
break
625+
517626
if picked is None:
518627
picked = SharedObject(
519-
len(shared_objects), -1, spec.allocated_memory, spec.lifetime[1]
628+
len(shared_objects),
629+
-1,
630+
spec.allocated_memory,
631+
spec.lifetime[0],
632+
spec.lifetime[1],
520633
)
634+
allocation_spec = AllocationSpec(0, spec)
635+
picked.allocations.append(allocation_spec)
636+
picked.first_used_index = spec.lifetime[0]
637+
picked.last_used_index = spec.lifetime[1]
521638
shared_objects.append(picked)
522639

523640
return picked
@@ -565,13 +682,20 @@ def greedy(
565682
# For each tensor, pick the available shared object with closest size to
566683
# the tensor. If there are no available shared object left, create a new
567684
# one.
685+
import bisect
686+
687+
sorted_specs = []
568688
for spec in collect_specs_from_nodes(
569689
graph_module.graph.nodes,
570690
graph_signature,
571691
do_assertion=do_assertion,
572692
ignore_graph_input=not alloc_graph_input,
573693
ignore_graph_output=not alloc_graph_output,
574694
):
695+
bisect.insort(sorted_specs, spec, key=lambda x: x.allocated_memory)
696+
sorted_specs.reverse()
697+
698+
for spec in sorted_specs:
575699
if spec.mem_id is None:
576700
spec.mem_id = 1
577701
spec.realign(alignment)
@@ -583,6 +707,7 @@ def greedy(
583707
total_sizes = [0, 0]
584708
else:
585709
total_sizes = [0] * (max(shared_objects.keys()) + 1)
710+
num_specs_processed = 0
586711
for mem_id in shared_objects:
587712
input_total_size = 0
588713
if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
@@ -594,13 +719,25 @@ def greedy(
594719
total_sizes[mem_id] = materialize_buffer(
595720
shared_objects[mem_id], input_total_size
596721
)
597-
598-
# Since we now know the number of shared objects we need and the size of
599-
# each shared object, we can assign offset in the memory buffer for each
600-
# shared object.
601-
for spec, sobj in spec2obj.items():
602-
spec.mem_obj_id = sobj.idx
603-
spec.mem_offset = sobj.offset
722+
# padding allocation with 64 bytes.
723+
# this requirement really for XNNPACK backend which can access tensors
724+
# for reading beyond the end of the tensor. This is done for performance
725+
# optimizations in XNNPACK.
726+
# While account for backend specific requirement is not the right choice
727+
# in backend agnostic memory planning, we do it here for now.
728+
total_sizes[mem_id] += 64
729+
# Since we now know the number of shared objects we need and the size of
730+
# each shared object, we can assign offset in the memory buffer for each
731+
# shared object.
732+
for sobj in shared_objects[mem_id]:
733+
for alloc in sobj.allocations:
734+
spec = alloc.spec
735+
alloc.spec.mem_obj_id = sobj.idx
736+
alloc.spec.mem_offset = sobj.offset + alloc.offset
737+
num_specs_processed += 1
738+
assert (
739+
len(spec2obj) == num_specs_processed
740+
), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
604741

605742
logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
606743
return total_sizes

exir/passes/memory_planning_pass.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
import logging
88
import warnings
9-
from typing import Callable, List, Optional
9+
from typing import Any, Callable, List, Optional
10+
from functools import partial
1011

1112
import torch
1213
from executorch.exir.error import internal_assert
@@ -24,6 +25,17 @@
2425
from torch.export.exported_program import ExportGraphSignature
2526

2627

28+
# copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function
29+
def _callable_name(any_callable: Callable[..., Any]) -> str:
30+
if isinstance(any_callable, partial):
31+
return any_callable.func.__name__
32+
33+
try:
34+
return any_callable.__name__
35+
except AttributeError:
36+
return str(any_callable)
37+
38+
2739
class MemoryPlanningPass(PassBase):
2840
def __init__(
2941
self,
@@ -127,4 +139,12 @@ def run(
127139
f"The {getattr(self.memory_planning_algo, '__name__', repr(self.memory_planning_algo))} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
128140
)
129141
verifier.verify_graph_input_output()
142+
if (
143+
callable(self.memory_planning_algo)
144+
and _callable_name(self.memory_planning_algo) == "greedy"
145+
):
146+
# Only verify storage reuse for greedy algorithm
147+
# At the moment cadence backends memory planning fails this
148+
# I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function
149+
verifier.verify_storage_reuse()
130150
return PassResult(graph_module, True)

exir/tests/test_joint_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,13 @@ def forward(self, x, y):
8484
et.executorch_program.execution_plan[0]
8585
.values[0]
8686
.val.allocation_info.memory_offset_low,
87-
0,
87+
96,
8888
)
8989
self.assertEqual(
9090
et.executorch_program.execution_plan[0]
9191
.values[1]
9292
.val.allocation_info.memory_offset_low,
93-
48,
93+
224,
9494
)
9595

9696
loss = m(*example_inputs)

exir/tests/test_memory_planning.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,28 @@ def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
106106
return (torch.randn(2),)
107107

108108

109+
class LinearsWithDifferentSizeAndViewOps(torch.nn.Module):
110+
def __init__(self) -> None:
111+
super(LinearsWithDifferentSizeAndViewOps, self).__init__()
112+
self.linears = torch.nn.ModuleList()
113+
for x in [8, 16, 32, 64]:
114+
self.linears.append(torch.nn.Linear(x, x * 2))
115+
116+
def forward(self, i: torch.Tensor) -> torch.Tensor:
117+
o1 = i
118+
for linear in self.linears:
119+
o1 = linear(o1)
120+
o1 = o1.view(-1, 64, 2)
121+
o1 = o1 + 1
122+
o2 = i
123+
for linear in self.linears:
124+
o2 = linear(o2)
125+
return o1.view(-1, 128) + o2
126+
127+
def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
128+
return (torch.randn(3, 8),)
129+
130+
109131
class ModuleReturnTwo(nn.Module):
110132
def __init__(self) -> None:
111133
super(ModuleReturnTwo, self).__init__()
@@ -360,6 +382,13 @@ def verify_overlap_placeholders(
360382
],
361383
)
362384

385+
test_linear_with_view: Callable[..., None] = maketest(
386+
LinearsWithDifferentSizeAndViewOps,
387+
criteria=[
388+
(greedy, True),
389+
],
390+
)
391+
363392
# greedy algorithm will reuse memory if we let the algorithm allocate
364393
# memory for both graph input and output.
365394
test_list_arg: Callable[..., None] = maketest(
@@ -508,15 +537,26 @@ def test_multiple_pools(
508537
verifier.verify_graph_input_output()
509538

510539
idx = 0
540+
reference_output = dict()
541+
actual_output = dict()
511542
for node in graph_module.graph.nodes:
512543
if node.op == "placeholder" or (
513544
node.op == "call_function"
514545
and node.target in (torch.ops.aten.add.out, torch.ops.aten.mul.out)
515546
):
516547
mem_id, mem_offset = expected_allocs[idx]
517-
self.assertEqual(node.meta["spec"].mem_id, mem_id)
518-
self.assertEqual(node.meta["spec"].mem_offset, mem_offset)
548+
actual_mem_id, actual_mem_offset = (
549+
node.meta["spec"].mem_id,
550+
node.meta["spec"].mem_offset,
551+
)
552+
if (mem_id, mem_offset) not in reference_output:
553+
reference_output[(mem_id, mem_offset)] = 1
554+
actual_output[(actual_mem_id, actual_mem_offset)] = 1
555+
else:
556+
reference_output[(mem_id, mem_offset)] += 1
557+
actual_output[(actual_mem_id, actual_mem_offset)] += 1
519558
idx += 1
559+
self.assertEqual(reference_output, actual_output)
520560
self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes)
521561

522562
def test_constants_not_memory_planned(self) -> None:

0 commit comments

Comments
 (0)