@@ -551,6 +551,7 @@ def greedy(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    input_buffer_sizes: Optional[List[int]] = None,
 ) -> List[int]:
     spec2obj = {}
     shared_objects = defaultdict(list)
@@ -574,18 +575,17 @@ def greedy(
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
-        # Return [0, 0] to be consistent with default behavior of naive.
-        total_sizes = [0, 0]
+        # Return the input sizes or [0, 0] to be consistent with default behavior of naive.
+        total_sizes = input_buffer_sizes or [0, 0]
     else:
-        total_sizes = [0] * (max(shared_objects.keys()) + 1)
+        num_buffers = max(shared_objects.keys()) + 1
+        if input_buffer_sizes is None:
+            total_sizes = [0] * num_buffers
+        else:
+            total_sizes = input_buffer_sizes + [0] * (num_buffers - len(input_buffer_sizes))
+
         for mem_id in shared_objects:
-            input_total_size = 0
-            if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
-                if len(bufsizes) > mem_id:
-                    input_total_size = bufsizes[mem_id]
-            total_sizes[mem_id] = materialize_buffer(
-                shared_objects[mem_id], input_total_size
-            )
+            total_sizes[mem_id] = materialize_buffer(shared_objects[mem_id], total_sizes[mem_id])
 
     # Since we now know the number of shared objects we need and the size of
     # each shared object, we can assign offset in the memory buffer for each
@@ -604,6 +604,7 @@ def naive(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    input_buffer_sizes: Optional[List[int]] = None,
 ) -> List[int]:
 
     # allocate 'allocated' bytes from buffer with id mem_id.
@@ -615,10 +616,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         bufsizes[mem_id] += allocated
         return ret
 
-    bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
-    if bufsizes is None:
-        bufsizes = [0, 0]
-
+    bufsizes = input_buffer_sizes or [0, 0]
     bufsizes = typing.cast(List[int], bufsizes)
     for spec in collect_specs_from_nodes(
         graph_module.graph.nodes,
@@ -727,6 +725,8 @@ def apply_algo(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    # sizes of buffers already allocated when recursing into submodules
+    input_buffer_sizes: Optional[List[int]] = None,
 ) -> List[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.
@@ -741,43 +741,46 @@ def apply_algo(
741741 """
742742 specs = update_all_tensors_lifetime (graph_module , graph_signature )
743743 bufsizes : List [int ] = algo (
744- graph_module , alignment , graph_signature , alloc_graph_input , alloc_graph_output
744+ graph_module ,
745+ alignment ,
746+ graph_signature ,
747+ alloc_graph_input ,
748+ alloc_graph_output ,
749+ input_buffer_sizes ,
745750 )
746751 insert_calls_to_free (graph_module , specs )
747752
748753 def handle_submodule (
749- submodule_nd : torch .fx .Node , alloc_graph_input : bool = False
754+ submodule_nd : torch .fx .Node , current_buffer_sizes , alloc_graph_input : bool = False
-    ) -> None:
+    ) -> List[int]:
-        nonlocal bufsizes
         assert submodule_nd.op == "get_attr"
         submodule = getattr(graph_module, submodule_nd.target)
-        # memory planning for submodule need to be aware of the amount of
-        # buffer already allocated.
-        submodule.input_mem_buffer_sizes = bufsizes
+        submodule.input_mem_buffer_sizes = current_buffer_sizes
         bufsizes = apply_algo(
             algo,
             submodule,
             alignment,
             graph_signature,
             alloc_graph_input=alloc_graph_input,
             alloc_graph_output=True,
+            input_buffer_sizes=current_buffer_sizes,
         )
         submodule.meta.update({"non_const_buffer_sizes": bufsizes})
+        return bufsizes
 
     for cond_node in get_cond_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))
+        bufsizes = handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]), bufsizes)
+        bufsizes = handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]), bufsizes)
 
     for while_node in get_while_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
+        bufsizes = handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]), bufsizes)
+        bufsizes = handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]), bufsizes)
     # TODO: Add test coverage for map operator once dynamo tracing is
     # fully supported for this. T142287208
     for map_node in get_map_nodes(graph_module):
-        handle_submodule(
-            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
+        bufsizes = handle_submodule(
+            typing.cast(torch.fx.Node, map_node.args[0]), bufsizes, alloc_graph_input=True
         )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes
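
For context on what the patch does: the planner for a control-flow submodule now receives the buffer sizes its parent has already committed and keeps growing those same buffers, instead of reading them off an input_mem_buffer_sizes attribute on the module. Below is a minimal sketch of that behavior; it uses a toy planner (toy_naive_plan is invented for illustration and is not part of the ExecuTorch API).

from typing import List, Optional


def toy_naive_plan(
    tensor_sizes: List[int], input_buffer_sizes: Optional[List[int]] = None
) -> List[int]:
    # Start from the totals the parent already reserved, or a fresh [0, 0].
    bufsizes = list(input_buffer_sizes or [0, 0])
    for size in tensor_sizes:
        bufsizes[1] += size  # every tensor lands in mem_id 1 in this toy model
    return bufsizes


parent_sizes = toy_naive_plan([32, 64])           # -> [0, 96]
child_sizes = toy_naive_plan([16], parent_sizes)  # -> [0, 112]; submodule data sits after the parent's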