 
 # pyre-strict
 
+import collections
+import functools
 import itertools
 import logging
 import operator
@@ -522,6 +524,27 @@ class SharedObject:
     def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
+@dataclass
+class SpecAllocResult:
+    """These are the values that a memory planning algorithm assigns to a spec.
+    They are not written directly into the spec object; they are used to
+    track the allocation decisions and are assigned back to the spec object
+    at the end, based on which algorithm is picked as the best performing one.
+    """
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+@dataclass
+class MemoryAlgoResult:
+    """This is the result returned by a memory planning algorithm that is
+    invoked by memory_planning_algorithm_suite. It contains the allocation
+    decisions of that algorithm for all the specs, and the sizes of the
+    buffers used for the different memory hierarchies.
+    """
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+
 
 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
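For orientation, here is a minimal self-contained sketch of how these two containers fit together, assuming SpecAllocResult and MemoryAlgoResult from the hunk above are in scope (e.g. imported from this module). The spec object below is a plain placeholder rather than a real TensorSpec, and the numbers are invented; it only illustrates that spec_dict is keyed by spec and bufsizes is indexed by mem_id.

class _FakeSpec:
    """Stand-in for TensorSpec, only for this illustration."""

    def __init__(self) -> None:
        self.allocated_memory = 64


spec = _FakeSpec()

# One 64-byte arena under mem_id 1; the spec lands in shared object 0 at offset 0.
result = MemoryAlgoResult(spec_dict={}, bufsizes=[0, 64])
result.spec_dict[spec] = SpecAllocResult(mem_id=1, mem_obj_id=0, mem_offset=0)

# memory_planning_algorithm_suite (added further down in this diff) copies the
# winning algorithm's values back onto each spec: mem_id, mem_obj_id, mem_offset.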
@@ -711,7 +734,7 @@ def greedy(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -720,6 +743,7 @@ def greedy(
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -754,12 +778,16 @@ def greedy(
     sorted_specs.reverse()
 
     for spec in sorted_specs:
+        # Create an entry for this TensorSpec in the result object that we'll be
+        # returning from this algorithm.
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
-        )
+        spec2obj[spec] = pick_shared_obj(shared_objects[spec_alloc_result.mem_id], spec, allow_overlapping_allocations)
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
@@ -787,24 +815,73 @@ def greedy(
         for sobj in shared_objects[mem_id]:
             for alloc in sobj.allocations:
                 spec = alloc.spec
-                alloc.spec.mem_obj_id = sobj.idx
-                alloc.spec.mem_offset = sobj.offset + alloc.offset
+                # Get the spec_alloc_result for this spec and update it with the
+                # mem_obj_id and mem_offset generated by this algorithm.
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_obj_id = sobj.idx
+                spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                 num_specs_processed += 1
     assert (
         len(spec2obj) == num_specs_processed
     ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result
 
+def memory_planning_algorithm_suite(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+    algo_list: List[Callable[..., MemoryAlgoResult]] = [greedy],
+) -> List[int]:
+    r"""
+    Memory planning algorithm suite that runs a list of memory planning algorithms
+    and returns the result of the algorithm that minimizes the total memory usage.
+    """
+    mem_algo_results = {}
+    for algo in algo_list:
+        if isinstance(algo, functools.partial):
+            name = algo.func.__name__
+        else:
+            name = getattr(algo, "__name__", None)
+        # Run this memory planning algorithm and store the result in mem_algo_results
+        # with the name of the algorithm as the key.
+        mem_algo_results[name] = algo(
+            graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        )
+
+    # All the algorithms should have the same number of buffers allocated.
+    assert len({len(mem_algo_result.bufsizes) for mem_algo_result in mem_algo_results.values()}) == 1, "Different memory planning algorithms should have the same number of buffers allocated."
+
+    # Find the algorithm that minimizes the total memory usage.
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    # Update the mem_id and mem_offset for each spec in the graph module based on the
+    # values provided by the best memory planning algorithm.
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+    return bufsizes
 
 def naive(
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
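Caller-side, the suite keeps the old List[int] return shape, so it can stand in wherever a single planning algorithm's bufsizes were consumed before. Below is a hedged sketch of selecting algorithms; the module path executorch.exir.memory_planning is assumed, and functools.partial is used to pre-bind keyword options because the suite only forwards the first five arguments to each algorithm.

import functools

from executorch.exir.memory_planning import (  # module path assumed
    greedy,
    memory_planning_algorithm_suite,
    naive,
)

# Pre-binding options with functools.partial keeps the five-argument calling
# convention the suite uses; the suite unwraps partial.func.__name__ for logging.
greedy_no_overlap = functools.partial(greedy, allow_overlapping_allocations=False)

mp_algo = functools.partial(
    memory_planning_algorithm_suite, algo_list=[naive, greedy_no_overlap]
)

# bufsizes = mp_algo(graph_module, alignment, graph_signature)
# ...would run both planners and apply the cheaper plan's mem_id, mem_obj_id and
# mem_offset back onto the specs.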
@@ -826,16 +903,22 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(bufsizes, spec_alloc_result.mem_id, spec.allocated_memory)
 
     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result
 
 
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
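Any additional planner dropped into algo_list has to follow the same contract naive and greedy now do: accept the five arguments the suite forwards and return a MemoryAlgoResult instead of mutating specs in place. The toy sketch below illustrates that contract under stated assumptions: it is written as if it lived inside this module (so MemoryAlgoResult and SpecAllocResult are in scope), single_arena and _iter_specs are hypothetical names, and _iter_specs is a placeholder for whichever spec-collection helper this module actually provides, so this is illustrative rather than drop-in code.

from typing import Iterable, Optional

import torch


def _iter_specs(graph_module: torch.fx.GraphModule) -> Iterable["TensorSpec"]:
    """Placeholder: yield the TensorSpecs to plan via the module's real helpers."""
    raise NotImplementedError


def single_arena(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional["ExportGraphSignature"] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> MemoryAlgoResult:
    """Toy planner: bump-allocate every spec into a single arena with mem_id 1.

    For brevity this toy ignores the alloc_graph_input/alloc_graph_output flags.
    """
    result = MemoryAlgoResult({}, [0, 0])
    offset = 0
    for spec in _iter_specs(graph_module):
        spec.realign(alignment)
        # Record the decision rather than writing spec.mem_offset directly;
        # the suite applies the winning algorithm's decisions afterwards.
        result.spec_dict[spec] = SpecAllocResult(mem_id=1, mem_obj_id=0, mem_offset=offset)
        offset += spec.allocated_memory
    result.bufsizes = [0, offset]
    return result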
@@ -980,5 +1063,4 @@ def handle_submodule(
     )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes