@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+import collections
 import itertools
 import logging
 import operator
@@ -503,6 +504,17 @@ class SharedObject:
     def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
+@dataclass
+class SpecAllocResult:
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+@dataclass
+class MemoryAlgoResult:
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+
 
 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
@@ -692,7 +704,7 @@ def greedy(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -701,6 +713,7 @@ def greedy(
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -735,12 +748,14 @@ def greedy(
     sorted_specs.reverse()
 
     for spec in sorted_specs:
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
-        )
+        spec2obj[spec] = pick_shared_obj(shared_objects[spec_alloc_result.mem_id], spec, allow_overlapping_allocations)
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
@@ -768,15 +783,18 @@ def greedy(
             for sobj in shared_objects[mem_id]:
                 for alloc in sobj.allocations:
                     spec = alloc.spec
-                    alloc.spec.mem_obj_id = sobj.idx
-                    alloc.spec.mem_offset = sobj.offset + alloc.offset
+                    spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                    assert spec_alloc_result is not None, f"Spec {spec} not found."
+                    spec_alloc_result.mem_obj_id = sobj.idx
+                    spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                     num_specs_processed += 1
         assert (
             len(spec2obj) == num_specs_processed
         ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result
 
 
 def naive(
@@ -785,7 +803,9 @@ def naive(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
@@ -807,16 +827,22 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(bufsizes, spec_alloc_result.mem_id, spec.allocated_memory)
 
     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result
 
 
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -899,10 +925,10 @@ def insert_calls_to_free(
 
 
 def apply_algo(
-    algo: Callable[
+    algo_list: List[Callable[
         [torch.fx.GraphModule, int, Optional[ExportGraphSignature], bool, bool],
-        List[int],
-    ],
+        MemoryAlgoResult,
+    ]],
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
@@ -922,9 +948,24 @@ def apply_algo(
     """
 
     specs = update_all_tensors_lifetime(graph_module, graph_signature)
-    bufsizes: List[int] = algo(
-        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
-    )
+    mem_algo_results = {}
+    for algo in algo_list:
+        mem_algo_results[getattr(algo, "__name__")] = algo(
+            graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        )
+
+    assert len({len(mem_algo_result.bufsizes) for mem_algo_result in mem_algo_results.values()}) == 1, "Different memory planning algorithms should have the same number of buffers allocated."
+
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
     insert_calls_to_free(graph_module, specs)
 
     def handle_submodule(
@@ -937,7 +978,7 @@ def handle_submodule(
         # buffer already allocated.
         submodule.input_mem_buffer_sizes = bufsizes
         bufsizes = apply_algo(
-            algo,
+            algo_list,
             submodule,
             alignment,
             graph_signature,
@@ -961,5 +1002,5 @@ def handle_submodule(
         )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
+    print(f"bufsizes = {bufsizes}")
     return bufsizes
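
For context, here is a minimal, self-contained sketch of the selection step this change introduces in `apply_algo`: each planning algorithm now returns a `MemoryAlgoResult`, the result with the smallest total buffer size is picked, and only the winner's per-spec `SpecAllocResult`s are written back onto the tensor specs. The `FakeSpec` class and the example sizes below are hypothetical stand-ins for `TensorSpec` and real planner output; only `SpecAllocResult`, `MemoryAlgoResult`, and the selection rule come from the diff above.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class SpecAllocResult:
    mem_id: int
    mem_obj_id: int
    mem_offset: int


@dataclass
class MemoryAlgoResult:
    spec_dict: Dict[object, SpecAllocResult]
    bufsizes: List[int]


@dataclass(eq=False)  # eq=False keeps identity hashing so specs can be dict keys
class FakeSpec:  # hypothetical stand-in for TensorSpec
    name: str
    mem_id: int = 0
    mem_obj_id: int = 0
    mem_offset: int = 0


spec_a, spec_b = FakeSpec("a"), FakeSpec("b")

# Pretend two planners produced these results for the same graph.
mem_algo_results = {
    "greedy": MemoryAlgoResult(
        {spec_a: SpecAllocResult(1, 0, 0), spec_b: SpecAllocResult(1, 0, 64)},
        bufsizes=[0, 128],
    ),
    "naive": MemoryAlgoResult(
        {spec_a: SpecAllocResult(1, 0, 0), spec_b: SpecAllocResult(1, 1, 128)},
        bufsizes=[0, 256],
    ),
}

# Same rule as the diff: the algorithm with the smallest total bufsizes wins,
# then the winner's allocations are copied back onto the specs.
best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
for spec, alloc in mem_algo_results[best_algo].spec_dict.items():
    spec.mem_id = alloc.mem_id
    spec.mem_offset = alloc.mem_offset
    spec.mem_obj_id = alloc.mem_obj_id

print(best_algo)          # greedy
print(spec_b.mem_offset)  # 64
```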