1111import operator
1212import typing
1313from collections import defaultdict
14- from dataclasses import dataclass
14+ from dataclasses import dataclass , field
1515from typing import Any , Callable , Dict , Iterable , List , Optional , Set , Tuple , Union
1616
1717import torch
@@ -117,6 +117,17 @@ def storage_overlap(cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec) -> bool:
117117
118118 return has_overlap
119119
@classmethod
def _debug_message_from_specs(
    cls, lhs_spec: TensorSpec, rhs_spec: TensorSpec
) -> str:
    """
    Build a one-line summary of two specs for storage-overlap error messages.

    Args:
        lhs_spec: first spec of the pair being compared.
        rhs_spec: second spec of the pair being compared.

    Returns:
        A string holding both lifetimes followed by, for each spec, its
        mem_id, mem_offset and allocated_memory.
    """
    lifetimes = (
        f"lhs lifetime: {lhs_spec.lifetime}, rhs lifetime: {rhs_spec.lifetime}"
    )
    lhs_storage = (
        f"lhs: mem_id {lhs_spec.mem_id} "
        f"storage: {lhs_spec.mem_offset}, {lhs_spec.allocated_memory}"
    )
    rhs_storage = (
        f"rhs: mem_id {rhs_spec.mem_id} "
        f"storage: {rhs_spec.mem_offset}, {rhs_spec.allocated_memory}"
    )
    # Join with an explicit separator so the three segments never run
    # together; do not rely on trailing whitespace inside each literal.
    return "; ".join((lifetimes, lhs_storage, rhs_storage))
130+
120131 def verify_storage_reuse (
121132 self , allow_lifetime_and_storage_overlap : bool = False
122133 ) -> int :
@@ -159,7 +170,7 @@ def verify_storage_reuse(
159170 lhs_spec , rhs_spec
160171 ):
161172 raise InternalError (
162- f"Unexpected storage overlap: lhs { lhs_spec } , rhs { rhs_spec } "
173+ f"Unexpected storage overlap: { Verifier . _debug_message_from_specs ( lhs_spec , rhs_spec ) } "
163174 )
164175
165176 # Check that each mem_obj_id is consistent with whether the tensors have
@@ -454,6 +465,18 @@ def update_all_tensors_lifetime(
454465 return specs
455466
456467
@dataclass
class AllocationSpec:
    """
    Records where a single tensor has been placed inside a shared
    object/pool.

    Attributes:
        offset: offset of the tensor within the shared object/pool.
        spec: the TensorSpec of the tensor that was placed.
    """

    offset: int
    spec: TensorSpec
478+
479+
457480@dataclass
458481class SharedObject :
459482 r"""
@@ -470,8 +493,15 @@ class SharedObject:
470493 offset : int
471494 # size of this shared object in bytes
472495 size : int
496+ # When the object is first created
497+ first_used_index : int
473498 # the object will be available for index (last_used_index + 1)
474499 last_used_index : int
500+ # list of allocations belong to this shared object
501+ allocations : List [AllocationSpec ] = field (default_factory = list )
502+
def __repr__(self) -> str:
    """
    Return a compact description of this object: its idx, offset, size and
    its lifetime as [first_used_index, last_used_index].
    """
    # Format the two indices in separate replacement fields. Putting both in
    # one field formats a tuple and renders as "lifetime=[(first, last)]".
    return (
        f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, "
        f"lifetime=[{self.first_used_index}, {self.last_used_index}])"
    )
475505
476506
477507def materialize_buffer (
@@ -489,35 +519,122 @@ def materialize_buffer(
489519 return total_size
490520
491521
def _does_not_overlap(sobj: SharedObject, spec: TensorSpec) -> bool:
    r"""
    Tell whether the lifetime of ``spec`` is disjoint from the lifetime of
    every allocation already recorded in ``sobj``.

    Lifetimes are compared as closed ``[first, last]`` intervals, so sharing
    a single index counts as an overlap. A shared object with no allocations
    never overlaps.
    """
    return all(
        spec.lifetime[1] < placed.spec.lifetime[0]
        or spec.lifetime[0] > placed.spec.lifetime[1]
        for placed in sobj.allocations
    )
533+
534+
def _find_max_overlapping_allocations_offset(
    sobj: SharedObject, spec: TensorSpec
) -> int:
    r"""
    Return the highest end offset (``offset + allocated_memory``) among the
    allocations in ``sobj`` whose lifetime overlaps the lifetime of ``spec``.

    Returns 0 when no allocation in ``sobj`` overlaps ``spec`` in time.
    """
    overlapping_ends = [
        placed.offset + placed.spec.allocated_memory
        for placed in sobj.allocations
        if not (
            spec.lifetime[1] < placed.spec.lifetime[0]
            or spec.lifetime[0] > placed.spec.lifetime[1]
        )
    ]
    # Seed with 0 so the result matches the "nothing overlaps" case.
    return max([0, *overlapping_ends])
498547
499548
def _record_allocation(sobj: SharedObject, spec: TensorSpec, offset: int) -> None:
    r"""
    Record ``spec`` at ``offset`` inside ``sobj`` and widen the lifetime of
    ``sobj`` so that it also covers the lifetime of ``spec``.
    """
    sobj.first_used_index = min(sobj.first_used_index, spec.lifetime[0])
    sobj.last_used_index = max(sobj.last_used_index, spec.lifetime[1])
    sobj.allocations.append(AllocationSpec(offset, spec))


def pick_shared_obj(
    shared_objects: List[SharedObject], spec: TensorSpec
) -> SharedObject:
    r"""
    Pick the shared object (bucket) to which ``spec`` is assigned, creating a
    new one when no existing bucket can host it.

    Callers are expected to present specs in non-increasing order of
    ``allocated_memory``; the first pass asserts this.

    Algorithm, tried in order:
      1. Take the first bucket in which no recorded allocation overlaps
         ``spec`` in lifetime. ``spec`` is placed at offset 0.
      2. Otherwise take the first bucket in which ``spec`` fits above all the
         allocations it overlaps in lifetime. ``spec`` is placed at the
         largest end offset (``offset + allocated_memory``) among those
         overlapping allocations, provided it still ends within the
         bucket's size.
      3. Otherwise create a new bucket whose size and lifetime are those of
         ``spec``, place ``spec`` at offset 0 and append the bucket to
         ``shared_objects``.
    In cases 1 and 2 the bucket's lifetime is widened to cover ``spec`` and
    ``spec`` is added to the bucket's list of allocations.

    Why the placements do not collide:
      - This function never changes the size of an existing bucket.
      - Invariant: two allocations in a bucket may cover the same byte range
        (``[offset, offset + size)``) only if their lifetimes do not overlap.
      - A new bucket with one allocation satisfies the invariant trivially.
      - Case 1 adds an allocation whose lifetime overlaps no other allocation
        in the bucket, so the invariant is preserved.
      - Case 2 places ``spec`` at or beyond the end of every allocation it
        overlaps in lifetime, so it shares no bytes with any of them.

    Args:
        shared_objects: buckets created so far; a newly created bucket is
            appended to this list.
        spec: the tensor spec to place.

    Returns:
        The bucket that now hosts ``spec``.
    """
    # Pass 1: a bucket that is free for the whole lifetime of the spec.
    for sobj in shared_objects:
        if _does_not_overlap(sobj, spec):
            assert sobj.size >= spec.allocated_memory, "Allocation specs are not sorted"
            _record_allocation(sobj, spec, 0)
            return sobj

    # Pass 2: stack the spec on top of the allocations it overlaps in time.
    # max_offset is 0 only when nothing overlaps, which pass 1 already handled.
    for sobj in shared_objects:
        max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
        if max_offset > 0 and max_offset + spec.allocated_memory <= sobj.size:
            _record_allocation(sobj, spec, max_offset)
            return sobj

    # No existing bucket can host the spec: open a new one. The constructor
    # receives the spec's lifetime as first/last used index, so no further
    # lifetime update is needed. Offset -1 is a placeholder value.
    picked = SharedObject(
        len(shared_objects),
        -1,
        spec.allocated_memory,
        spec.lifetime[0],
        spec.lifetime[1],
    )
    picked.allocations.append(AllocationSpec(0, spec))
    shared_objects.append(picked)
    return picked
@@ -565,13 +682,20 @@ def greedy(
565682 # For each tensor, pick the available shared object with closest size to
566683 # the tensor. If there are no available shared object left, create a new
567684 # one.
685+ import bisect
686+
687+ sorted_specs = []
568688 for spec in collect_specs_from_nodes (
569689 graph_module .graph .nodes ,
570690 graph_signature ,
571691 do_assertion = do_assertion ,
572692 ignore_graph_input = not alloc_graph_input ,
573693 ignore_graph_output = not alloc_graph_output ,
574694 ):
695+ bisect .insort (sorted_specs , spec , key = lambda x : x .allocated_memory )
696+ sorted_specs .reverse ()
697+
698+ for spec in sorted_specs :
575699 if spec .mem_id is None :
576700 spec .mem_id = 1
577701 spec .realign (alignment )
@@ -583,6 +707,7 @@ def greedy(
583707 total_sizes = [0 , 0 ]
584708 else :
585709 total_sizes = [0 ] * (max (shared_objects .keys ()) + 1 )
710+ num_specs_processed = 0
586711 for mem_id in shared_objects :
587712 input_total_size = 0
588713 if bufsizes := getattr (graph_module , "input_mem_buffer_sizes" , None ):
@@ -594,13 +719,25 @@ def greedy(
594719 total_sizes [mem_id ] = materialize_buffer (
595720 shared_objects [mem_id ], input_total_size
596721 )
597-
598- # Since we now know the number of shared objects we need and the size of
599- # each shared object, we can assign offset in the memory buffer for each
600- # shared object.
601- for spec , sobj in spec2obj .items ():
602- spec .mem_obj_id = sobj .idx
603- spec .mem_offset = sobj .offset
722+ # padding allocation with 64 bytes.
723+ # this requirement really for XNNPACK backend which can access tensors
724+ # for reading beyond the end of the tensor. This is done for performance
725+ # optimizations in XNNPACK.
726+ # While account for backend specific requirement is not the right choice
727+ # in backend agnostic memory planning, we do it here for now.
728+ total_sizes [mem_id ] += 64
729+ # Since we now know the number of shared objects we need and the size of
730+ # each shared object, we can assign offset in the memory buffer for each
731+ # shared object.
732+ for sobj in shared_objects [mem_id ]:
733+ for alloc in sobj .allocations :
734+ spec = alloc .spec
735+ alloc .spec .mem_obj_id = sobj .idx
736+ alloc .spec .mem_offset = sobj .offset + alloc .offset
737+ num_specs_processed += 1
738+ assert (
739+ len (spec2obj ) == num_specs_processed
740+ ), f"All specs should be processed but there were { len (spec2obj )} specs and processed { num_specs_processed } specs"
604741
605742 logging .debug (f"greedy algorithm returns bufsizes: { total_sizes } " )
606743 return total_sizes
0 commit comments