66
77# pyre-strict
88
9+ import functools
910import itertools
1011import logging
1112import operator
@@ -523,6 +524,31 @@ def __repr__(self) -> str:
523524 return f"SharedObject(idx={ self .idx } , offset={ self .offset } , size={ self .size } , lifetime=[{ self .first_used_index , self .last_used_index } ])"
524525
525526
@dataclass
class SpecAllocResult:
    """These are the values that a memory planning algorithm assigns to a spec.
    These are not directly written back into the spec object, but are used to
    track the allocation decisions and assigned back to the spec object in the
    end, based on which algorithm is picked as the best performing one.
    """

    # Id of the memory hierarchy (arena) this spec is placed in.
    mem_id: int
    # Id of the shared memory object, within mem_id, backing this spec.
    mem_obj_id: int
    # Byte offset of this spec inside the buffer for mem_id.
    mem_offset: int
538+
539+
@dataclass
class MemoryAlgoResult:
    """This is the result returned by a memory planning algorithm that is
    invoked by memory_planning_algorithm_suite. It contains the allocation
    decisions of that algorithm for all the specs, and the size of the buffer
    that was used for different memory hierarchies.
    """

    # Per-spec allocation decisions (mem_id / mem_obj_id / mem_offset) that the
    # suite writes back into the winning algorithm's TensorSpecs at the end.
    spec_dict: Dict[TensorSpec, SpecAllocResult]
    # Planned buffer size for each memory hierarchy id.
    bufsizes: List[int]
550+
551+
526552def materialize_buffer (
527553 shared_objects : List [SharedObject ], input_total_size : int = 0
528554) -> int :
@@ -711,7 +737,7 @@ def greedy(
711737 alloc_graph_input : bool = True ,
712738 alloc_graph_output : bool = True ,
713739 allow_overlapping_allocations : bool = True ,
714- ) -> List [ int ] :
740+ ) -> MemoryAlgoResult :
715741 r"""Greedy algorithm to allocate memory for tensors in the graph.
716742 alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
717743 alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -720,6 +746,7 @@ def greedy(
720746 This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
721747 allocations disabled
722748 """
749+ greedy_result = MemoryAlgoResult ({}, [])
723750 # padding allocation with 64 bytes.
724751 # this requirement is really for XNNPACK backend which can read tensors
725752 # beyond the end of the tensor. This is done for performance
@@ -754,11 +781,19 @@ def greedy(
754781 sorted_specs .reverse ()
755782
756783 for spec in sorted_specs :
784+ # Create an entry for this TensorSpec in the result object that we'll be
785+ # returning from this algorithm.
786+ spec_alloc_result = greedy_result .spec_dict .get (spec , SpecAllocResult (0 , 0 , 0 ))
757787 if spec .mem_id is None :
758- spec .mem_id = 1
788+ spec_alloc_result .mem_id = 1
789+ else :
790+ spec_alloc_result .mem_id = spec .mem_id
791+ greedy_result .spec_dict [spec ] = spec_alloc_result
759792 spec .realign (alignment )
760793 spec2obj [spec ] = pick_shared_obj (
761- shared_objects [spec .mem_id ], spec , allow_overlapping_allocations
794+ shared_objects [spec_alloc_result .mem_id ],
795+ spec ,
796+ allow_overlapping_allocations ,
762797 )
763798
764799 if len (shared_objects ) == 0 :
@@ -787,24 +822,89 @@ def greedy(
787822 for sobj in shared_objects [mem_id ]:
788823 for alloc in sobj .allocations :
789824 spec = alloc .spec
790- alloc .spec .mem_obj_id = sobj .idx
791- alloc .spec .mem_offset = sobj .offset + alloc .offset
825+ # Get the spec_alloc_result for this spec and update it with the
826+ # mem_obj_id and mem_offset generated by this algorithm.
827+ spec_alloc_result = greedy_result .spec_dict .get (spec , None )
828+ assert spec_alloc_result is not None , f"Spec { spec } not found."
829+ spec_alloc_result .mem_obj_id = sobj .idx
830+ spec_alloc_result .mem_offset = sobj .offset + alloc .offset
792831 num_specs_processed += 1
793832 assert (
794833 len (spec2obj ) == num_specs_processed
795834 ), f"All specs should be processed but there were { len (spec2obj )} specs and processed { num_specs_processed } specs"
796835
797836 logging .debug (f"greedy algorithm returns bufsizes: { total_sizes } " )
798- return total_sizes
837+ greedy_result .bufsizes = total_sizes
838+ return greedy_result
799839
800840
801- def naive (
def memory_planning_algorithm_suite(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
    allow_overlapping_allocations: bool = True,
    algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
) -> List[int]:
    r"""
    Memory planning algorithm suite that runs a list of memory planning algorithms
    and returns the result of the algorithm that minimizes the total memory usage.

    The winning algorithm's per-spec decisions (mem_id, mem_obj_id, mem_offset)
    are written back into the TensorSpecs of ``graph_module``; losing algorithms'
    decisions are discarded.

    Args:
        graph_module: module whose tensor specs are being planned.
        alignment: byte alignment forwarded to each algorithm.
        graph_signature: optional export signature forwarded to each algorithm.
        alloc_graph_input: whether algorithms should allocate graph inputs.
        alloc_graph_output: whether algorithms should allocate graph outputs.
        allow_overlapping_allocations: forwarded only to algorithms whose
            signature declares the parameter (e.g. ``greedy``); algorithms that
            do not take it (e.g. ``naive``) are called without it.
        algo_list: candidate algorithms; defaults to ``[greedy]``.

    Returns:
        The bufsizes of the best performing algorithm.
    """
    import inspect

    if algo_list is None:
        algo_list = [greedy]
    # Fail fast with a clear message instead of a ValueError from min() below.
    assert len(algo_list) > 0, "algo_list must contain at least one algorithm"

    mem_algo_results: Dict[str, MemoryAlgoResult] = {}
    for idx, algo in enumerate(algo_list):
        # functools.partial objects don't expose __name__ directly.
        if isinstance(algo, functools.partial):
            name = algo.func.__name__
        else:
            name = getattr(algo, "__name__", repr(algo))
        # Disambiguate duplicate names so no algorithm's result is silently
        # overwritten (e.g. two partials of the same function).
        key = f"{name}#{idx}" if name in mem_algo_results else name

        # BUGFIX: allow_overlapping_allocations used to be accepted here but
        # never forwarded, so callers' choice had no effect. Forward it only to
        # algorithms that declare it; inspect.signature also handles partials
        # (already-bound parameters are excluded).
        extra_kwargs = {}
        try:
            if "allow_overlapping_allocations" in inspect.signature(algo).parameters:
                extra_kwargs["allow_overlapping_allocations"] = (
                    allow_overlapping_allocations
                )
        except (TypeError, ValueError):
            # Callable without an introspectable signature; call it plainly.
            pass

        # Run this memory planning algorithm and store the result in
        # mem_algo_results with the (disambiguated) algorithm name as the key.
        mem_algo_results[key] = algo(
            graph_module,
            alignment,
            graph_signature,
            alloc_graph_input,
            alloc_graph_output,
            **extra_kwargs,
        )

    # All the algorithms should have the same number of buffers allocated.
    assert (
        len(
            {
                len(mem_algo_result.bufsizes)
                for mem_algo_result in mem_algo_results.values()
            }
        )
        == 1
    ), "Different memory planning algorithms should have the same number of buffers allocated."

    # Find the algorithm that minimizes the total memory usage.
    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
    logging.debug(f"Best memory planning algo for this model is {best_algo}")
    bufsizes = mem_algo_results[best_algo].bufsizes

    # Update the mem_id and mem_offset for each spec in the graph module based
    # on the values provided by the best memory planning algorithm.
    for spec, spec_alloc_result in mem_algo_results[best_algo].spec_dict.items():
        spec.mem_id = spec_alloc_result.mem_id
        spec.mem_offset = spec_alloc_result.mem_offset
        spec.mem_obj_id = spec_alloc_result.mem_obj_id

    return bufsizes
897+
898+
899+ def naive (
900+ graph_module : torch .fx .GraphModule ,
901+ alignment : int ,
902+ graph_signature : Optional [ExportGraphSignature ] = None ,
903+ alloc_graph_input : bool = True ,
904+ alloc_graph_output : bool = True ,
905+ ) -> MemoryAlgoResult :
906+
907+ naive_result = MemoryAlgoResult ({}, [])
808908
809909 # allocate 'allocated' bytes from buffer with id mem_id.
810910 # return the starting offset of the allocated buffer.
@@ -826,16 +926,24 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
826926 ignore_graph_input = not alloc_graph_input ,
827927 ignore_graph_output = not alloc_graph_output ,
828928 ):
929+ spec_alloc_result = naive_result .spec_dict .get (spec , SpecAllocResult (0 , 0 , 0 ))
829930 # assume a single memory layer which has mem_id 1
830931 if spec .mem_id is None :
831- spec .mem_id = 1
932+ spec_alloc_result .mem_id = 1
933+ else :
934+ spec_alloc_result .mem_id = spec .mem_id
935+ naive_result .spec_dict [spec ] = spec_alloc_result
936+
832937 # allocate spec.allocated_memory bytes in the buffer
833938 # with the corresponding mem_id
834939 spec .realign (alignment )
835- spec .mem_offset = _allocate_buf (bufsizes , spec .mem_id , spec .allocated_memory )
940+ spec_alloc_result .mem_offset = _allocate_buf (
941+ bufsizes , spec_alloc_result .mem_id , spec .allocated_memory
942+ )
836943
837944 logging .debug (f"naive algorithm returns bufsizes: { bufsizes } " )
838- return bufsizes
945+ naive_result .bufsizes = bufsizes
946+ return naive_result
839947
840948
841949def get_cond_nodes (graph_module : torch .fx .GraphModule ) -> Iterable [Node ]:
@@ -980,5 +1088,4 @@ def handle_submodule(
9801088 )
9811089
9821090 graph_module .meta .update ({"non_const_buffer_sizes" : bufsizes })
983-
9841091 return bufsizes
0 commit comments