
# pyre-strict

+import collections
+import functools
import itertools
import logging
import operator
@@ -503,6 +505,27 @@ class SharedObject:
    def __repr__(self) -> str:
        return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"

+@dataclass
+class SpecAllocResult:
510+ """ These are the values that a memory plannig algorithm assigns to a spec.
511+ These are not directly written back into the spec object, but are used to
512+ track the allocation decisions and assigned back to the spec object in the
513+ end, based on which algorithm is picked as the best performing one.
514+ """
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+@dataclass
+class MemoryAlgoResult:
521+ """ This is the result returned by a memory planning algorithm that is
522+ invoked by memory_planning_algorithm_suite. It contains the allocation
523+ decisions of that algorithm for all the specs, and the size of the buffer
524+ that was used for different memory hierarchies.
525+ """
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+


def materialize_buffer(
    shared_objects: List[SharedObject], input_total_size: int = 0
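
The two dataclasses above keep each algorithm's decisions separate from the specs themselves until a winning algorithm is chosen. A minimal sketch of how an algorithm might record a decision for one spec; the `spec` placeholder and the concrete numbers below are illustrative only:

result = MemoryAlgoResult(spec_dict={}, bufsizes=[])
spec = object()  # placeholder standing in for a real TensorSpec
# Record that this spec was placed in memory arena 1, shared object 0, offset 64.
result.spec_dict[spec] = SpecAllocResult(mem_id=1, mem_obj_id=0, mem_offset=64)
# One total size per memory hierarchy; the values here are made up.
result.bufsizes = [0, 1024]
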
@@ -692,7 +715,7 @@ def greedy(
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
    allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
    r"""Greedy algorithm to allocate memory for tensors in the graph.
    alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
    alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -701,6 +724,7 @@ def greedy(
    This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
    allocations disabled
    """
+    greedy_result = MemoryAlgoResult({}, [])
    # padding allocation with 64 bytes.
    # this requirement is really for XNNPACK backend which can read tensors
    # beyond the end of the tensor. This is done for performance
@@ -735,12 +759,16 @@ def greedy(
    sorted_specs.reverse()

    for spec in sorted_specs:
+        # Create an entry for this TensorSpec in the result object that we'll be
+        # returning from this algorithm.
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
        if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
        spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
-        )
+        spec2obj[spec] = pick_shared_obj(
+            shared_objects[spec_alloc_result.mem_id], spec, allow_overlapping_allocations
+        )

    if len(shared_objects) == 0:
        # Cannot find any tensor in the graph that needs to be allocated.
@@ -768,24 +796,73 @@ def greedy(
        for sobj in shared_objects[mem_id]:
            for alloc in sobj.allocations:
                spec = alloc.spec
-                alloc.spec.mem_obj_id = sobj.idx
-                alloc.spec.mem_offset = sobj.offset + alloc.offset
+                # Get the spec_alloc_result for this spec and update it with the
+                # mem_obj_id and mem_offset generated by this algorithm.
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_obj_id = sobj.idx
+                spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                num_specs_processed += 1
    assert (
        len(spec2obj) == num_specs_processed
    ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"

    logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result

+def memory_planning_algorithm_suite(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+    algo_list: List[Callable[..., MemoryAlgoResult]] = [greedy],
+) -> List[int]:
+    r"""
+    Memory planning algorithm suite that runs a list of memory planning algorithms
+    and returns the result of the algorithm that minimizes the total memory usage.
+    """
+    mem_algo_results = {}
+    for algo in algo_list:
+        if isinstance(algo, functools.partial):
+            name = algo.func.__name__
+        else:
+            name = getattr(algo, "__name__", None)
+        # Run this memory planning algorithm and store the result in mem_algo_results
+        # with the name of the algorithm as the key.
+        mem_algo_results[name] = algo(
+            graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        )
+
+    # All the algorithms should have the same number of buffers allocated.
+    assert (
+        len({len(mem_algo_result.bufsizes) for mem_algo_result in mem_algo_results.values()})
+        == 1
+    ), "Different memory planning algorithms should have the same number of buffers allocated."
+
+    # Find the algorithm that minimizes the total memory usage.
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    # Update the mem_id, mem_obj_id, and mem_offset for each spec in the graph
+    # module based on the values provided by the best memory planning algorithm.
+    for spec, spec_alloc_result in mem_algo_results[best_algo].spec_dict.items():
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+    return bufsizes

def naive(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])

    # allocate 'allocated' bytes from buffer with id mem_id.
    # return the starting offset of the allocated buffer.
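
For reference, a hedged usage sketch of the new suite. The `graph_module` and `graph_signature` values are assumed to come from an already-exported program (not shown), and the alignment value is illustrative. Per-algorithm options such as `allow_overlapping_allocations` travel through `functools.partial`, since the suite forwards only the common positional arguments to each algorithm:

import functools

bufsizes = memory_planning_algorithm_suite(
    graph_module,
    alignment=16,
    graph_signature=graph_signature,
    algo_list=[
        naive,
        functools.partial(greedy, allow_overlapping_allocations=False),
    ],
)
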
@@ -807,16 +884,22 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
        ignore_graph_input=not alloc_graph_input,
        ignore_graph_output=not alloc_graph_output,
    ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
        # assume a single memory layer which has mem_id 1
        if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
        # allocate spec.allocated_memory bytes in the buffer
        # with the corresponding mem_id
        spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(
+            bufsizes, spec_alloc_result.mem_id, spec.allocated_memory
+        )

    logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result


def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -961,5 +1044,4 @@ def handle_submodule(
        )

    graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
    return bufsizes