1111import operator
1212import typing
1313from collections import defaultdict
14- from dataclasses import dataclass
14+ from dataclasses import dataclass , field
1515from typing import Any , Callable , Dict , Iterable , List , Optional , Set , Tuple , Union
1616
1717import torch
@@ -117,6 +117,17 @@ def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
117117
118118 return has_overlap
119119
120+ @classmethod
121+ def _debug_message_from_specs (
122+ cls , lhs_spec : TensorSpec , rhs_spec : TensorSpec
123+ ) -> str :
124+ message = (
125+ f"lhs life time: { lhs_spec .lifetime } , rhs lifetime: { rhs_spec .lifetime } "
126+ )
127+ message += f"lhs: mem_id { lhs_spec .mem_id } storage: { lhs_spec .mem_offset } , { lhs_spec .allocated_memory } "
128+ message += f"rhs: mem_id { rhs_spec .mem_id } storage: { rhs_spec .mem_offset } , { rhs_spec .allocated_memory } "
129+ return message
130+
120131 def verify_storage_reuse (
121132 self , allow_lifetime_and_storage_overlap : bool = False
122133 ) -> int :
@@ -159,7 +170,7 @@ def verify_storage_reuse(
159170 lhs_spec , rhs_spec
160171 ):
161172 raise InternalError (
162- f"Unexpected storage overlap: lhs { lhs_spec } , rhs { rhs_spec } "
173+ f"Unexpected storage overlap: { Verifier . _debug_message_from_specs ( lhs_spec , rhs_spec ) } "
163174 )
164175
165176 # Check that each mem_obj_id is consistent with whether the tensors have
@@ -454,6 +465,18 @@ def update_all_tensors_lifetime(
454465 return specs
455466
456467
@dataclass
class AllocationSpec:
    """A single tensor allocation placed inside a shared object/pool."""

    # Byte offset of this allocation within its enclosing shared object.
    offset: int
    # The tensor spec that this allocation backs.
    spec: "TensorSpec"
479+
457480@dataclass
458481class SharedObject :
459482 r"""
@@ -470,8 +493,15 @@ class SharedObject:
470493 offset : int
471494 # size of this shared object in bytes
472495 size : int
496+ # When the object is first created
497+ first_used_index : int
473498 # the object will be available for index (last_used_index + 1)
474499 last_used_index : int
500+ # list of allocations belong to this shared object
501+ allocations : List [AllocationSpec ] = field (default_factory = list )
502+
503+ def __repr__ (self ) -> str :
504+ return f"SharedObject(idx={ self .idx } , offset={ self .offset } , size={ self .size } , lifetime=[{ self .first_used_index , self .last_used_index } ])"
475505
476506
477507def materialize_buffer (
@@ -489,35 +519,124 @@ def materialize_buffer(
489519 return total_size
490520
491521
492- def _size_abs_dif (sobj : SharedObject , spec : TensorSpec ) -> int :
522+ def _does_not_overlap (sobj : SharedObject , spec : TensorSpec ) -> bool :
493523 r"""
494- Calculate the absolute different between the size of a shared object and
495- a tensor.
524+ Check if a shared object and a tensor do not overlap.
496525 """
497- return abs (sobj .size - spec .allocated_memory )
526+ for alloc in sobj .allocations :
527+ if not (
528+ spec .lifetime [1 ] < alloc .spec .lifetime [0 ]
529+ or spec .lifetime [0 ] > alloc .spec .lifetime [1 ]
530+ ):
531+ return False
532+ return True
533+
534+
535+ def _find_max_overlapping_allocations_offset (
536+ sobj : SharedObject , spec : TensorSpec
537+ ) -> int :
538+ max_offset = 0
539+ for alloc in sobj .allocations :
540+ if (
541+ spec .lifetime [1 ] < alloc .spec .lifetime [0 ]
542+ or spec .lifetime [0 ] > alloc .spec .lifetime [1 ]
543+ ):
544+ continue
545+ max_offset = max (alloc .offset + alloc .spec .allocated_memory , max_offset )
546+ return max_offset
498547
499548
def pick_shared_obj(
    shared_objects: List[SharedObject],
    spec: TensorSpec,
    allow_overlapping_allocations: bool = True,
) -> SharedObject:
    r"""
    Pick the shared object to which ``spec`` should be assigned, or create a
    new one and append it to ``shared_objects``.

    Callers are expected to process specs in descending order of
    ``allocated_memory`` (enforced by the assert below), so an existing
    object is always at least as large as the spec being placed.

    Algorithm:
    - Prefer an existing object whose allocations' lifetimes are all
      disjoint from ``spec``'s lifetime; place the spec there at offset 0
      and expand the object's lifetime to cover the spec's.
    - Otherwise, when ``allow_overlapping_allocations`` is True, look for an
      object where the spec fits *above* the lifetime-overlapping
      allocations: place it at the maximum end offset (offset + size) of
      those allocations, provided it still fits within the object's size.
    - Otherwise create a fresh object sized and lifetimed to the spec.

    Correctness sketch (invariants):
    - Once an object is created, it is never resized.
    - Within one object, the [offset, offset + size) spans of two
      allocations may overlap only if their lifetimes are disjoint.
      A new allocation is either placed in an object whose allocations are
      all lifetime-disjoint from it (so offset 0 is safe), or above the max
      end offset of every lifetime-overlapping allocation (so it cannot
      collide with any of them). By induction, all placements are valid.
    """

    def _assign(sobj: SharedObject, offset: int) -> None:
        # Record the allocation and grow the object's lifetime to cover it.
        sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
        sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
        sobj.allocations.append(AllocationSpec(offset, spec))

    picked = None
    # 1) An object whose existing allocations never coexist with this spec.
    for sobj in shared_objects:
        if _does_not_overlap(sobj, spec):
            assert sobj.size >= spec.allocated_memory, "Allocation specs are not sorted"
            picked = sobj
            _assign(sobj, 0)
            break

    # 2) Stack the spec above the lifetime-overlapping allocations, if it fits.
    if picked is None and allow_overlapping_allocations:
        for sobj in shared_objects:
            max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
            if 0 < max_offset and max_offset + spec.allocated_memory <= sobj.size:
                picked = sobj
                _assign(sobj, max_offset)
                break

    # 3) No fit anywhere: open a new shared object sized for this spec.
    # (The constructor already receives the spec's lifetime, so the
    # redundant first/last_used_index re-assignments were removed.)
    if picked is None:
        picked = SharedObject(
            len(shared_objects),
            -1,  # placeholder; the real offset is assigned when buffers are materialized
            spec.allocated_memory,
            spec.lifetime[0],
            spec.lifetime[1],
        )
        picked.allocations.append(AllocationSpec(0, spec))
        shared_objects.append(picked)

    return picked
@@ -550,13 +669,50 @@ def get_node_tensor_specs(
550669 ]
551670
552671
# Somewhat hacky way to check whether the graph contains an XNNPACK
# delegate: peek into each delegate call node's lowered module and match
# on its backend_id string. Why do we care? Because the greedy planner
# pads its buffers with extra bytes when XNNPACK is present, since
# XNNPACK may read beyond the end of a tensor for performance reasons.
675+
676+
def _contains_xnnpack_delegate(graph_module: torch.fx.GraphModule) -> bool:
    """Return True if any delegate call node in the graph targets a backend
    whose id contains "xnnpack" (case-insensitive)."""
    owning_module = graph_module.graph.owning_module
    for node in graph_module.graph.nodes:
        if node.target != executorch_call_delegate:
            continue
        # The first argument of the delegate call names the lowered module
        # attribute on the owning module; inspect its backend id.
        lowered_module = getattr(owning_module, node.args[0].target)
        if "xnnpack" in lowered_module.backend_id.lower():
            return True
    return False
686+
687+
553688def greedy (
554689 graph_module : torch .fx .GraphModule ,
555690 alignment : int ,
556691 graph_signature : Optional [ExportGraphSignature ] = None ,
557692 alloc_graph_input : bool = True ,
558693 alloc_graph_output : bool = True ,
694+ allow_overlapping_allocations : bool = True ,
559695) -> List [int ]:
696+ r"""Greedy algorithm to allocate memory for tensors in the graph.
697+ alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
698+ alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
699+ allow_overlapping_allocations: If set to true, allows for allocations that overlap
700+ in their lifetime but are at different offsets in the storage. By default true.
701+ This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
702+ allocations disabled
703+ """
704+ # padding allocation with 64 bytes.
705+ # this requirement is really for XNNPACK backend which can read tensors
706+ # beyond the end of the tensor. This is done for performance
707+ # optimizations in XNNPACK.
708+ # While accounting for backend specific requirement is not the right choice
709+ # in backend agnostic memory planning, we do it here as it seems most appropriate.
710+ # Right now this applies to greedy only so any other
711+ # algorithm that plans memory for XNNPACK backend will
712+ # not have this.
713+ extra_padded_bytes = 0
714+ if _contains_xnnpack_delegate (graph_module ):
715+ extra_padded_bytes = 64
560716 spec2obj = {}
561717 shared_objects = defaultdict (list )
562718 # Don't do assertion in collect_specs_from_nodes if we have already encountered
@@ -565,24 +721,34 @@ def greedy(
565721 # For each tensor, pick the available shared object with closest size to
566722 # the tensor. If there are no available shared object left, create a new
567723 # one.
724+ import bisect
725+
726+ sorted_specs = []
568727 for spec in collect_specs_from_nodes (
569728 graph_module .graph .nodes ,
570729 graph_signature ,
571730 do_assertion = do_assertion ,
572731 ignore_graph_input = not alloc_graph_input ,
573732 ignore_graph_output = not alloc_graph_output ,
574733 ):
734+ bisect .insort (sorted_specs , spec , key = lambda x : x .allocated_memory )
735+ sorted_specs .reverse ()
736+
737+ for spec in sorted_specs :
575738 if spec .mem_id is None :
576739 spec .mem_id = 1
577740 spec .realign (alignment )
578- spec2obj [spec ] = pick_shared_obj (shared_objects [spec .mem_id ], spec )
741+ spec2obj [spec ] = pick_shared_obj (
742+ shared_objects [spec .mem_id ], spec , allow_overlapping_allocations
743+ )
579744
580745 if len (shared_objects ) == 0 :
581746 # Cannot find any tensor in the graph that needs to be allocated.
582747 # Return [0, 0] to be consistent with default behavior of naive.
583748 total_sizes = [0 , 0 ]
584749 else :
585750 total_sizes = [0 ] * (max (shared_objects .keys ()) + 1 )
751+ num_specs_processed = 0
586752 for mem_id in shared_objects :
587753 input_total_size = 0
588754 if bufsizes := getattr (graph_module , "input_mem_buffer_sizes" , None ):
@@ -594,13 +760,20 @@ def greedy(
594760 total_sizes [mem_id ] = materialize_buffer (
595761 shared_objects [mem_id ], input_total_size
596762 )
597-
598- # Since we now know the number of shared objects we need and the size of
599- # each shared object, we can assign offset in the memory buffer for each
600- # shared object.
601- for spec , sobj in spec2obj .items ():
602- spec .mem_obj_id = sobj .idx
603- spec .mem_offset = sobj .offset
763+ total_sizes [mem_id ] += extra_padded_bytes
764+
765+ # Since we now know the number of shared objects we need and the size of
766+ # each shared object, we can assign offset in the memory buffer for each
767+ # shared object.
768+ for sobj in shared_objects [mem_id ]:
769+ for alloc in sobj .allocations :
770+ spec = alloc .spec
771+ alloc .spec .mem_obj_id = sobj .idx
772+ alloc .spec .mem_offset = sobj .offset + alloc .offset
773+ num_specs_processed += 1
774+ assert (
775+ len (spec2obj ) == num_specs_processed
776+ ), f"All specs should be processed but there were { len (spec2obj )} specs and processed { num_specs_processed } specs"
604777
605778 logging .debug (f"greedy algorithm returns bufsizes: { total_sizes } " )
606779 return total_sizes
0 commit comments