1111import operator
1212import typing
1313from collections import defaultdict
14- from dataclasses import dataclass
14+ from dataclasses import dataclass , field
1515from typing import Any , Callable , Dict , Iterable , List , Optional , Set , Tuple , Union
1616
1717import torch
@@ -454,6 +454,18 @@ def update_all_tensors_lifetime(
454454 return specs
455455
456456
457+ @dataclass
458+ class AllocationSpec :
459+ """
460+ AllocationSpec is used to represent the allocation of a tensor.
461+ """
462+
463+ # The offset of the tensor in the shared object/pool.
464+ offset : int
465+ # TensorSpec
466+ spec : TensorSpec
467+
468+
457469@dataclass
458470class SharedObject :
459471 r"""
@@ -470,8 +482,15 @@ class SharedObject:
470482 offset : int
471483 # size of this shared object in bytes
472484 size : int
485+ # When the object is first created
486+ first_used_index : int
473487 # the object will be available for index (last_used_index + 1)
474488 last_used_index : int
489+ # list of allocations belong to this shared object
490+ allocations : List [AllocationSpec ] = field (default_factory = list )
491+
492+ def __repr__ (self ) -> str :
493+ return f"SharedObject(idx={ self .idx } , offset={ self .offset } , size={ self .size } , lifetime=[{ self .first_used_index , self .last_used_index } ])"
475494
476495
477496def materialize_buffer (
@@ -489,35 +508,122 @@ def materialize_buffer(
489508 return total_size
490509
491510
492- def _size_abs_dif (sobj : SharedObject , spec : TensorSpec ) -> int :
511+ def _does_not_overlap (sobj : SharedObject , spec : TensorSpec ) -> bool :
493512 r"""
494- Calculate the absolute different between the size of a shared object and
495- a tensor.
513+ Check if a shared object and a tensor do not overlap.
496514 """
497- return abs (sobj .size - spec .allocated_memory )
515+ for alloc in sobj .allocations :
516+ if not (
517+ spec .lifetime [1 ] < alloc .spec .lifetime [0 ]
518+ or spec .lifetime [0 ] > alloc .spec .lifetime [1 ]
519+ ):
520+ return False
521+ return True
522+
523+
524+ def _find_max_overlapping_allocations_offset (
525+ sobj : SharedObject , spec : TensorSpec
526+ ) -> int :
527+ max_offset = 0
528+ for alloc in sobj .allocations :
529+ if (
530+ spec .lifetime [1 ] < alloc .spec .lifetime [0 ]
531+ or spec .lifetime [0 ] > alloc .spec .lifetime [1 ]
532+ ):
533+ continue
534+ max_offset = max (alloc .offset + alloc .spec .allocated_memory , max_offset )
535+ return max_offset
498536
499537
500538def pick_shared_obj (
501539 shared_objects : List [SharedObject ], spec : TensorSpec
502540) -> SharedObject :
503541 r"""
504- Pick the available shared object with closest size to the tensor.
505- If there are no available shared object left, create a new one.
542+ Pick the available shared object to which to assign this spec,
543+ or create a new one
544+ Algorithm details
545+ Previous: Look at every spec in chronological order. Find if previously allocated object
546+ allows it to fit in. If not, allocate a new object.
547+ New:
548+ - Sort all the specs by allocation size
549+ - Process the specs in order
550+ - If the spec's size in smaller than previously allocated buckets:
551+ - Conditions under which previously allocated bucket can be used:
552+ - Lifetime of the spec does not overlap with lifetime of the bucket.
553+ - In this case allocate spec to that bucket and expand its lifetime.
554+ - Spec is allocated at offset = 0 in this bucket.
555+ - Add this spec to allocated object's list of specs.
556+ - Lifetime of the spec overlaps with lifetime of the bucket,
557+ partially or fully (e.g. spec's lifetime subset of bucket's lifetime)
558+ - If none of the specs in the bucket overlaps with spec's lifetime.
559+ - Allocate spec to the bucket at offset = 0.
560+ - Add this spec to the bucket's list of specs.
561+ - Expand bucket's lifetime accounting for added spec's lifetime.
562+ - If one or more specs in the bucket overlaps with spec's lifetime.
563+ - Collect offsets (at which the given overlapping spec is allocated in the bucket).
564+ of all the overlapping specs, and find the max offset.
565+ - Allocate spec to the bucket at offset = max_offset + max_offset_spec_size.
566+ - Add this spec to the bucket's list of specs.
567+ - Expand bucket's lifetime accounting for added spec's lifetime.
568+ - If none of these conditions are met, allocate a new bucket.
569+ - Add spec to this bucket.
570+ - Update bucket's lifetime to that of the spec.
571+ - If the spec's size is larger than previously allocated buckets, allocate a new bucket.
572+ - Size and lifetime of this bucket is that of the spec
573+
574+ Proof of correctness:
575+ - If allocating a new bucket, it is correct.
576+ - If allocating spec to an existing bucket, whose lifetime does not overlap with any
577+ of the previously allocated specs' lifetime, then the allocation is correct.
578+ Proof of correctness by induction when adding spec to an existing bucket:
579+ - If all previous allocations in the given bucket are correct:
580+ - Then the new one being added must be correct because when the requested allocation
581+ overlaps with one or more previous allocations, we find the largest offset among
582+ all the overlapping allocations, and allocate the new spec at that offset. Hence,
583+ the allocation at such an offset, will not overlap with any previous allocations.
584+ Base case: A newly added allocation within a bucket with single allocation is correct:
585+ because a) it must fit and b) its lifetime must not overlap with object's lifetime.
586+ This holds true because of the following invariants:
587+ - Once a bucket is created, it is never resized.
588+ - All the allocations within a bucket follow this:
589+ - Span, defined by allocation's offset + size, of two allocations can only overlap,
590+ if their timelines do not overlap.
506591 """
507- # TODO: do better than linear scan
508592 picked = None
509593 for sobj in shared_objects :
510- if spec .lifetime [0 ] > sobj .last_used_index :
511- if picked is None or _size_abs_dif (sobj , spec ) < _size_abs_dif (
512- picked , spec
513- ):
514- picked = sobj
515- sobj .last_used_index = spec .lifetime [1 ]
516- sobj .size = max (sobj .size , spec .allocated_memory )
594+ if _does_not_overlap (sobj , spec ):
595+ assert sobj .size >= spec .allocated_memory , "Allocation specs are not sorted"
596+ picked = sobj
597+ sobj .first_used_index = min (sobj .first_used_index , spec .lifetime [0 ])
598+ sobj .last_used_index = max (sobj .last_used_index , spec .lifetime [1 ])
599+ allocation_spec = AllocationSpec (0 , spec )
600+ picked .allocations .append (allocation_spec )
601+ break
602+
603+ if picked is None :
604+ for sobj in shared_objects :
605+ max_offset = _find_max_overlapping_allocations_offset (sobj , spec )
606+ if max_offset > 0 :
607+ if max_offset + spec .allocated_memory <= sobj .size :
608+ picked = sobj
609+ sobj .first_used_index = min (sobj .first_used_index , spec .lifetime [0 ])
610+ sobj .last_used_index = max (sobj .last_used_index , spec .lifetime [1 ])
611+ allocation_spec = AllocationSpec (max_offset , spec )
612+ picked .allocations .append (allocation_spec )
613+ break
614+
517615 if picked is None :
518616 picked = SharedObject (
519- len (shared_objects ), - 1 , spec .allocated_memory , spec .lifetime [1 ]
617+ len (shared_objects ),
618+ - 1 ,
619+ spec .allocated_memory ,
620+ spec .lifetime [0 ],
621+ spec .lifetime [1 ],
520622 )
623+ allocation_spec = AllocationSpec (0 , spec )
624+ picked .allocations .append (allocation_spec )
625+ picked .first_used_index = spec .lifetime [0 ]
626+ picked .last_used_index = spec .lifetime [1 ]
521627 shared_objects .append (picked )
522628
523629 return picked
@@ -565,13 +671,20 @@ def greedy(
565671 # For each tensor, pick the available shared object with closest size to
566672 # the tensor. If there are no available shared object left, create a new
567673 # one.
674+ import bisect
675+
676+ sorted_specs = []
568677 for spec in collect_specs_from_nodes (
569678 graph_module .graph .nodes ,
570679 graph_signature ,
571680 do_assertion = do_assertion ,
572681 ignore_graph_input = not alloc_graph_input ,
573682 ignore_graph_output = not alloc_graph_output ,
574683 ):
684+ bisect .insort (sorted_specs , spec , key = lambda x : x .allocated_memory )
685+ sorted_specs .reverse ()
686+
687+ for spec in sorted_specs :
575688 if spec .mem_id is None :
576689 spec .mem_id = 1
577690 spec .realign (alignment )
@@ -583,6 +696,7 @@ def greedy(
583696 total_sizes = [0 , 0 ]
584697 else :
585698 total_sizes = [0 ] * (max (shared_objects .keys ()) + 1 )
699+ num_specs_processed = 0
586700 for mem_id in shared_objects :
587701 input_total_size = 0
588702 if bufsizes := getattr (graph_module , "input_mem_buffer_sizes" , None ):
@@ -594,13 +708,18 @@ def greedy(
594708 total_sizes [mem_id ] = materialize_buffer (
595709 shared_objects [mem_id ], input_total_size
596710 )
597-
598- # Since we now know the number of shared objects we need and the size of
599- # each shared object, we can assign offset in the memory buffer for each
600- # shared object.
601- for spec , sobj in spec2obj .items ():
602- spec .mem_obj_id = sobj .idx
603- spec .mem_offset = sobj .offset
711+ # Since we now know the number of shared objects we need and the size of
712+ # each shared object, we can assign offset in the memory buffer for each
713+ # shared object.
714+ for sobj in shared_objects [mem_id ]:
715+ for alloc in sobj .allocations :
716+ spec = alloc .spec
717+ alloc .spec .mem_obj_id = sobj .idx
718+ alloc .spec .mem_offset = sobj .offset + alloc .offset
719+ num_specs_processed += 1
720+ assert (
721+ len (spec2obj ) == num_specs_processed
722+ ), f"All specs should be processed but there were { len (spec2obj )} specs and processed { num_specs_processed } specs"
604723
605724 logging .debug (f"greedy algorithm returns bufsizes: { total_sizes } " )
606725 return total_sizes
0 commit comments