@@ -40,6 +40,12 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
4040 return memory_config .memory_sizes [exir_id - 1 ]
4141
4242
def get_alignment(memory_config: MemoryConfig, exir_id: int) -> int:
    """Look up the alignment configured for the memory identified by ``exir_id``.

    Args:
        memory_config: Configuration object whose ``memory_alignments``
            sequence holds one alignment value per memory.
        exir_id: EXIR memory id (``spec.mem_id``), indexed from 1..N.

    Returns:
        The entry of ``memory_config.memory_alignments`` at position
        ``exir_id - 1``.

    Raises:
        AssertionError: If ``memory_config.memory_alignments`` is ``None``.
    """
    alignments = memory_config.memory_alignments
    assert alignments is not None
    # EXIR ids start at 1 while the alignment list is 0-based, so shift by one.
    zero_based_index = exir_id - 1
    return alignments[zero_based_index]
47+
48+
def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    """Round ``pre_aligned_offset`` up to the nearest multiple of ``alignment``.

    The computation uses integer arithmetic only. Going through float true
    division (``math.ceil(a / b)``) silently loses precision once the quotient
    no longer fits in a double's 53-bit mantissa, which can yield a result
    *smaller* than the requested offset.

    Args:
        pre_aligned_offset: Offset to align.
        alignment: Alignment granularity; the result is a multiple of it.

    Returns:
        The smallest multiple of ``alignment`` that is greater than or equal
        to ``pre_aligned_offset`` (for a positive ``alignment``).

    Raises:
        ZeroDivisionError: If ``alignment`` is 0.
    """
    # Exact ceiling division for integers: ceil(a / b) == -(-a // b).
    aligned_units = -(-pre_aligned_offset // alignment)
    return aligned_units * alignment
4551
@@ -84,6 +90,10 @@ def position_based_greedy_with_hierarchy(
8490 ]
8591 ] = None ,
8692) -> List [int ]:
93+ # We do not use the `alignment` parameter and instead use the per-memory alignment
94+ # constraints from `memory_config`.
95+ del alignment
96+
8797 num_memories = get_num_memories (memory_config )
8898 bufsizes = [0 ] * num_memories
8999 allocated_buffers : List [List [TensorSpec ]] = [[] for _ in range (num_memories )]
@@ -103,7 +113,8 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
103113
104114 def memory_available (spec : TensorSpec ) -> bool :
105115 return get_aligned_offset (
106- spec .mem_offset + spec .allocated_memory , alignment
116+ spec .mem_offset + spec .allocated_memory ,
117+ get_alignment (memory_config , spec .mem_id ),
107118 ) <= get_size (memory_config , spec .mem_id )
108119
109120 # Iterate over all the specs in sorted order
@@ -124,7 +135,8 @@ def memory_available(spec: TensorSpec) -> bool:
124135 spec .mem_offset = 0
125136 while memory_available (spec ) and (overlapped := overlap (spec )):
126137 spec .mem_offset = get_aligned_offset (
127- overlapped .mem_offset + overlapped .allocated_memory , alignment
138+ overlapped .mem_offset + overlapped .allocated_memory ,
139+ get_alignment (memory_config , spec .mem_id ),
128140 )
129141 if memory_available (spec ):
130142 allocated_buffers [spec .mem_id ].append (spec )
@@ -172,6 +184,10 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
172184 ]
173185 ] = None ,
174186) -> List [int ]:
187+ # We do not use the `alignment` parameter and instead use the per-memory alignment
188+ # constraints from `memory_config`.
189+ del alignment
190+
175191 num_memories = get_num_memories (memory_config )
176192 bufsizes = [0 ] * num_memories
177193 allocated_buffers = [[] for _ in range (num_memories )]
@@ -213,13 +229,14 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
213229 prev_offset = max (
214230 get_aligned_offset (
215231 allocated_spec .mem_offset + allocated_spec .allocated_memory ,
216- alignment ,
232+ get_alignment ( memory_config , spec . mem_id ) ,
217233 ),
218234 prev_offset ,
219235 )
220236 if spec .mem_offset is None :
221237 if get_aligned_offset (
222- prev_offset + spec .allocated_memory , alignment
238+ prev_offset + spec .allocated_memory ,
239+ get_alignment (memory_config , spec .mem_id ),
223240 ) > get_size (memory_config , spec .mem_id ):
224241 continue
225242 else :
@@ -439,7 +456,6 @@ def __init__(
439456 ]
440457 ]
441458 ] = None ,
442- mem_alignment : int = 1 ,
443459 ) -> None :
444460 self ._init_mem_algos ()
445461
@@ -450,9 +466,6 @@ def __init__(
450466 self .alloc_graph_output = alloc_graph_output
451467 self .additional_constraint_gen_passes = additional_constraint_gen_passes
452468
453- assert mem_alignment > 0 , "mem_alignment must be positive"
454- self .mem_alignment = mem_alignment
455-
456469 def _init_mem_algos (self ) -> None :
457470 self .available_mem_algos = [
458471 position_based_greedy_with_hierarchy ,
@@ -489,7 +502,6 @@ def run(
489502 allow_lifetime_and_storage_overlap = (self .opt_level >= 2 ),
490503 alloc_graph_input = self .alloc_graph_input ,
491504 alloc_graph_output = self .alloc_graph_output ,
492- alignment = self .mem_alignment ,
493505 )
494506 mem_planning .run (graph_module , graph_signature )
495507
0 commit comments