1111import  operator 
1212import  typing 
1313from  collections  import  defaultdict 
14- from  dataclasses  import  dataclass 
14+ from  dataclasses  import  dataclass ,  field 
1515from  typing  import  Any , Callable , Dict , Iterable , List , Optional , Set , Tuple , Union 
1616
1717import  torch 
@@ -454,6 +454,18 @@ def update_all_tensors_lifetime(
454454    return  specs 
455455
456456
457+ @dataclass  
458+ class  AllocationSpec :
459+     """ 
460+     AllocationSpec is used to represent the allocation of a tensor. 
461+     """ 
462+ 
463+     # The offset of the tensor in the shared object/pool. 
464+     offset : int 
465+     # TensorSpec 
466+     spec : TensorSpec 
467+ 
468+ 
457469@dataclass  
458470class  SharedObject :
459471    r""" 
@@ -470,8 +482,15 @@ class SharedObject:
470482    offset : int 
471483    # size of this shared object in bytes 
472484    size : int 
485+     # When the object is first created 
486+     first_used_index : int 
473487    # the object will be available for index (last_used_index + 1) 
474488    last_used_index : int 
489+     # list of allocations belong to this shared object 
490+     allocations : List [AllocationSpec ] =  field (default_factory = list )
491+ 
492+     def  __repr__ (self ) ->  str :
493+         return  f"SharedObject(idx={ self .idx }  , offset={ self .offset }  , size={ self .size }  , lifetime=[{ self .first_used_index , self .last_used_index }  ])" 
475494
476495
477496def  materialize_buffer (
@@ -489,35 +508,122 @@ def materialize_buffer(
489508    return  total_size 
490509
491510
492- def  _size_abs_dif (sobj : SharedObject , spec : TensorSpec ) ->  int :
511+ def  _does_not_overlap (sobj : SharedObject , spec : TensorSpec ) ->  bool :
493512    r""" 
494-     Calculate the absolute different between the size of a shared object and 
495-     a tensor. 
513+     Check if a shared object and a tensor do not overlap. 
496514    """ 
497-     return  abs (sobj .size  -  spec .allocated_memory )
515+     for  alloc  in  sobj .allocations :
516+         if  not  (
517+             spec .lifetime [1 ] <  alloc .spec .lifetime [0 ]
518+             or  spec .lifetime [0 ] >  alloc .spec .lifetime [1 ]
519+         ):
520+             return  False 
521+     return  True 
522+ 
523+ 
524+ def  _find_max_overlapping_allocations_offset (
525+     sobj : SharedObject , spec : TensorSpec 
526+ ) ->  int :
527+     max_offset  =  0 
528+     for  alloc  in  sobj .allocations :
529+         if  (
530+             spec .lifetime [1 ] <  alloc .spec .lifetime [0 ]
531+             or  spec .lifetime [0 ] >  alloc .spec .lifetime [1 ]
532+         ):
533+             continue 
534+         max_offset  =  max (alloc .offset  +  alloc .spec .allocated_memory , max_offset )
535+     return  max_offset 
498536
499537
500538def  pick_shared_obj (
501539    shared_objects : List [SharedObject ], spec : TensorSpec 
502540) ->  SharedObject :
503541    r""" 
504-     Pick the available shared object with closest size to the tensor. 
505-     If there are no available shared object left, create a new one. 
542+     Pick the available shared object to which to assign this spec, 
543+     or create a new one 
544+     Algorithm details 
545+     Previous: Look at every spec in chronological order. Find if previously allocated object 
546+     allows it to fit in. If not, allocate a new object. 
547+     New: 
548+     - Sort all the specs by allocation size 
549+     - Process the specs in order 
550+     - If the spec's size in smaller than previously allocated buckets: 
551+         - Conditions under which previously allocated bucket can be used: 
552+           - Lifetime of the spec does not overlap with lifetime of the bucket. 
553+               - In this case allocate spec to that bucket and expand its lifetime. 
554+               - Spec is allocated at offset = 0 in this bucket. 
555+               - Add this spec to allocated object's list of specs. 
556+           - Lifetime of the spec overlaps with lifetime of the bucket, 
557+             partially or fully (e.g. spec's lifetime subset of bucket's lifetime) 
558+               - If none of the specs in the bucket overlaps with spec's lifetime. 
559+                 - Allocate spec to the bucket at offset = 0. 
560+                 - Add this spec to the bucket's list of specs. 
561+                 - Expand bucket's lifetime accounting for added spec's lifetime. 
562+               - If one or more specs in the bucket overlaps with spec's lifetime. 
563+                 - Collect offsets (at which the given overlapping spec is allocated in the bucket). 
564+                   of all the overlapping specs, and find the max offset. 
565+                 - Allocate spec to the bucket at offset = max_offset + max_offset_spec_size. 
566+                 - Add this spec to the bucket's list of specs. 
567+                 - Expand bucket's lifetime accounting for added spec's lifetime. 
568+         - If none of these conditions are met, allocate a new bucket. 
569+             - Add spec to this bucket. 
570+             - Update bucket's lifetime to that of the spec. 
571+     - If the spec's size is larger than previously allocated buckets, allocate a new bucket. 
572+         - Size and lifetime of this bucket is that of the spec 
573+ 
574+     Proof of correctness: 
575+     - If allocating a new bucket, it is correct. 
576+     - If allocating spec to an existing bucket, whose lifetime does not overlap with any 
577+       of the previously allocated specs' lifetime, then the allocation is correct. 
578+     Proof of correctness by induction when adding spec to an existing bucket: 
579+     - If all previous allocations in the given bucket are correct: 
580+         - Then the new one being added must be correct because when the requested allocation 
581+           overlaps with one or more previous allocations, we find the largest offset among 
582+           all the overlapping allocations, and allocate the new spec at that offset. Hence, 
583+           the allocation at such an offset, will not overlap with any previous allocations. 
584+     Base case: A newly added allocation within a bucket with single allocation is correct: 
585+     because a) it must fit and b) its lifetime must not overlap with object's lifetime. 
586+     This holds true because of the following invariants: 
587+     - Once a bucket is created, it is never resized. 
588+     - All the allocations within a bucket follow this: 
589+       - Span, defined by allocation's offset + size, of two allocations can only overlap, 
590+         if their timelines do not overlap. 
506591    """ 
507-     # TODO: do better than linear scan 
508592    picked  =  None 
509593    for  sobj  in  shared_objects :
510-         if  spec .lifetime [0 ] >  sobj .last_used_index :
511-             if  picked  is  None  or  _size_abs_dif (sobj , spec ) <  _size_abs_dif (
512-                 picked , spec 
513-             ):
514-                 picked  =  sobj 
515-                 sobj .last_used_index  =  spec .lifetime [1 ]
516-                 sobj .size  =  max (sobj .size , spec .allocated_memory )
594+         if  _does_not_overlap (sobj , spec ):
595+             assert  sobj .size  >=  spec .allocated_memory , "Allocation specs are not sorted" 
596+             picked  =  sobj 
597+             sobj .first_used_index  =  min (sobj .first_used_index , spec .lifetime [0 ])
598+             sobj .last_used_index  =  max (sobj .last_used_index , spec .lifetime [1 ])
599+             allocation_spec  =  AllocationSpec (0 , spec )
600+             picked .allocations .append (allocation_spec )
601+             break 
602+ 
603+     if  picked  is  None :
604+         for  sobj  in  shared_objects :
605+             max_offset  =  _find_max_overlapping_allocations_offset (sobj , spec )
606+             if  max_offset  >  0 :
607+                 if  max_offset  +  spec .allocated_memory  <=  sobj .size :
608+                     picked  =  sobj 
609+                     sobj .first_used_index  =  min (sobj .first_used_index , spec .lifetime [0 ])
610+                     sobj .last_used_index  =  max (sobj .last_used_index , spec .lifetime [1 ])
611+                     allocation_spec  =  AllocationSpec (max_offset , spec )
612+                     picked .allocations .append (allocation_spec )
613+                     break 
614+ 
517615    if  picked  is  None :
518616        picked  =  SharedObject (
519-             len (shared_objects ), - 1 , spec .allocated_memory , spec .lifetime [1 ]
617+             len (shared_objects ),
618+             - 1 ,
619+             spec .allocated_memory ,
620+             spec .lifetime [0 ],
621+             spec .lifetime [1 ],
520622        )
623+         allocation_spec  =  AllocationSpec (0 , spec )
624+         picked .allocations .append (allocation_spec )
625+         picked .first_used_index  =  spec .lifetime [0 ]
626+         picked .last_used_index  =  spec .lifetime [1 ]
521627        shared_objects .append (picked )
522628
523629    return  picked 
@@ -565,13 +671,20 @@ def greedy(
565671    # For each tensor, pick the available shared object with closest size to 
566672    # the tensor. If there are no available shared object left, create a new 
567673    # one. 
674+     import  bisect 
675+ 
676+     sorted_specs  =  []
568677    for  spec  in  collect_specs_from_nodes (
569678        graph_module .graph .nodes ,
570679        graph_signature ,
571680        do_assertion = do_assertion ,
572681        ignore_graph_input = not  alloc_graph_input ,
573682        ignore_graph_output = not  alloc_graph_output ,
574683    ):
684+         bisect .insort (sorted_specs , spec , key = lambda  x : x .allocated_memory )
685+     sorted_specs .reverse ()
686+ 
687+     for  spec  in  sorted_specs :
575688        if  spec .mem_id  is  None :
576689            spec .mem_id  =  1 
577690        spec .realign (alignment )
@@ -583,6 +696,7 @@ def greedy(
583696        total_sizes  =  [0 , 0 ]
584697    else :
585698        total_sizes  =  [0 ] *  (max (shared_objects .keys ()) +  1 )
699+         num_specs_processed  =  0 
586700        for  mem_id  in  shared_objects :
587701            input_total_size  =  0 
588702            if  bufsizes  :=  getattr (graph_module , "input_mem_buffer_sizes" , None ):
@@ -594,13 +708,18 @@ def greedy(
594708            total_sizes [mem_id ] =  materialize_buffer (
595709                shared_objects [mem_id ], input_total_size 
596710            )
597- 
598-         # Since we now know the number of shared objects we need and the size of 
599-         # each shared object, we can assign offset in the memory buffer for each 
600-         # shared object. 
601-         for  spec , sobj  in  spec2obj .items ():
602-             spec .mem_obj_id  =  sobj .idx 
603-             spec .mem_offset  =  sobj .offset 
711+             # Since we now know the number of shared objects we need and the size of 
712+             # each shared object, we can assign offset in the memory buffer for each 
713+             # shared object. 
714+             for  sobj  in  shared_objects [mem_id ]:
715+                 for  alloc  in  sobj .allocations :
716+                     spec  =  alloc .spec 
717+                     alloc .spec .mem_obj_id  =  sobj .idx 
718+                     alloc .spec .mem_offset  =  sobj .offset  +  alloc .offset 
719+                     num_specs_processed  +=  1 
720+         assert  (
721+             len (spec2obj ) ==  num_specs_processed 
722+         ), f"All specs should be processed but there were { len (spec2obj )}   specs and processed { num_specs_processed }   specs" 
604723
605724    logging .debug (f"greedy algorithm returns bufsizes: { total_sizes }  " )
606725    return  total_sizes 
0 commit comments