99import collections
1010import itertools
1111import logging
12+ import math
1213import typing
1314from functools import partial
1415from typing import Iterable , List , Optional , Tuple
@@ -39,6 +40,12 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
3940 return memory_config .memory_sizes [exir_id - 1 ]
4041
4142
def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    """Round ``pre_aligned_offset`` up to the next multiple of ``alignment``.

    Args:
        pre_aligned_offset: Byte offset to be aligned; assumed non-negative.
        alignment: Required alignment in bytes. An alignment of 0 means
            "no alignment constraint" and returns the offset unchanged.

    Returns:
        The smallest multiple of ``alignment`` that is >= ``pre_aligned_offset``.

    Uses integer ceiling division instead of ``math.ceil`` on a float
    quotient: routing the offset through a float silently loses precision
    once offsets exceed 2**53, producing under-aligned (i.e. wrong) results.
    """
    if alignment == 0:
        return pre_aligned_offset
    # Classic overflow-free integer round-up; exact for arbitrarily large ints.
    return ((pre_aligned_offset + alignment - 1) // alignment) * alignment
47+
48+
4249def collect_specs_from_graph_module (
4350 graph_module : torch .fx .GraphModule ,
4451 alloc_graph_input : bool ,
@@ -95,7 +102,7 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
95102 return None
96103
def memory_available(spec: TensorSpec) -> bool:
    # True iff the spec, with its end padded up to the configured
    # alignment, still fits inside the bank it was assigned to.
    aligned_end = get_aligned_offset(
        spec.mem_offset + spec.allocated_memory, alignment
    )
    return aligned_end <= get_size(memory_config, spec.mem_id)
101108
@@ -116,7 +123,7 @@ def memory_available(spec: TensorSpec) -> bool:
116123 continue
117124 spec .mem_offset = 0
118125 while memory_available (spec ) and (overlapped := overlap (spec )):
119- spec .mem_offset = overlapped .mem_offset + overlapped .allocated_memory
126+ spec .mem_offset = get_aligned_offset ( overlapped .mem_offset + overlapped .allocated_memory , alignment )
120127 if memory_available (spec ):
121128 allocated_buffers [spec .mem_id ].append (spec )
122129 bufsizes [spec .mem_id ] = max (
@@ -202,11 +209,11 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
202209 # calculation of gap incorrect. Moving it out will make the algorithm degenerate
203210 # to the naive one, reusing 0 tensor. The paper may have a typo here.
204211 prev_offset = max (
205- allocated_spec .mem_offset + allocated_spec .allocated_memory ,
212+ get_aligned_offset ( allocated_spec .mem_offset + allocated_spec .allocated_memory , alignment ) ,
206213 prev_offset ,
207214 )
208215 if spec .mem_offset is None :
209- if prev_offset + spec .allocated_memory > get_size (
216+ if get_aligned_offset ( prev_offset + spec .allocated_memory , alignment ) > get_size (
210217 memory_config , spec .mem_id
211218 ):
212219 continue
@@ -423,6 +430,7 @@ def __init__(
423430 ]
424431 ]
425432 ] = None ,
433+ mem_alignment : int = 0 ,
426434 ) -> None :
427435 self ._init_mem_algos ()
428436
@@ -432,6 +440,7 @@ def __init__(
432440 self .alloc_graph_input = alloc_graph_input
433441 self .alloc_graph_output = alloc_graph_output
434442 self .additional_constraint_gen_passes = additional_constraint_gen_passes
443+ self .mem_alignment = mem_alignment
435444
436445 def _init_mem_algos (self ) -> None :
437446 self .available_mem_algos = [
@@ -459,6 +468,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
459468 allow_lifetime_and_storage_overlap = (self .opt_level >= 2 ),
460469 alloc_graph_input = self .alloc_graph_input ,
461470 alloc_graph_output = self .alloc_graph_output ,
471+ alignment = self .mem_alignment ,
462472 )
463473 mem_planning (graph_module )
464474
0 commit comments