
Commit e5e7008

sxu authored and facebook-github-bot committed
Explicitly pass buffer sizes during memory planning when control flow submodules are around
Summary: It's less error-prone to pass the buffer sizes as a parameter and return value than to update them implicitly via `nonlocal` or a reference stored on the submodule. This also fixes a bug where a buffer newly introduced within a submodule was ignored by the top-level `apply_algo` call.

Differential Revision: D65915559
1 parent 7b03a8b commit e5e7008
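
As a hedged sketch of the refactor (illustrative names only; `plan` stands in for a planning algorithm such as `greedy` or `naive`, and this is not the ExecuTorch code itself):

# Before: sizes flow through `nonlocal` and an attribute stashed on the
# submodule, so updates happen out of band.
def apply_algo_before(graph, submodules, plan):
    bufsizes = plan(graph)

    def handle_submodule(sub):
        nonlocal bufsizes
        sub.input_mem_buffer_sizes = bufsizes  # side channel in
        bufsizes = plan(sub)                   # side channel out

    for sub in submodules:
        handle_submodule(sub)
    return bufsizes

# After: sizes are an explicit parameter and return value, so every call
# site shows where the size list is consumed and where it is produced.
def apply_algo_after(graph, submodules, plan, input_buffer_sizes=None):
    bufsizes = plan(graph, input_buffer_sizes)
    for sub in submodules:
        bufsizes = plan(sub, bufsizes)
    return bufsizes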

2 files changed: +92 −27 lines

exir/memory_planning.py

Lines changed: 30 additions & 27 deletions
@@ -551,6 +551,7 @@ def greedy(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    input_buffer_sizes: Optional[List[int]] = None,
 ) -> List[int]:
     spec2obj = {}
     shared_objects = defaultdict(list)
@@ -574,18 +575,17 @@ def greedy(

     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
-        # Return [0, 0] to be consistent with default behavior of naive.
-        total_sizes = [0, 0]
+        # Return the input sizes or [0, 0] to be consistent with default behavior of naive.
+        total_sizes = input_buffer_sizes or [0, 0]
     else:
-        total_sizes = [0] * (max(shared_objects.keys()) + 1)
+        num_buffers = max(shared_objects.keys()) + 1
+        if input_buffer_sizes is None:
+            total_sizes = [0] * num_buffers
+        else:
+            total_sizes = input_buffer_sizes + [0] * (num_buffers - len(input_buffer_sizes))
+
         for mem_id in shared_objects:
-            input_total_size = 0
-            if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
-                if len(bufsizes) > mem_id:
-                    input_total_size = bufsizes[mem_id]
-            total_sizes[mem_id] = materialize_buffer(
-                shared_objects[mem_id], input_total_size
-            )
+            total_sizes[mem_id] = materialize_buffer(shared_objects[mem_id], total_sizes[mem_id])

     # Since we now know the number of shared objects we need and the size of
     # each shared object, we can assign offset in the memory buffer for each
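
To make the new seeding concrete (illustrative numbers, not taken from the diff): inherited sizes are kept as-is and padded with zeros for any mem_id this graph introduces.

input_buffer_sizes = [0, 16]  # sizes inherited from the parent graph
num_buffers = 4               # max(shared_objects.keys()) + 1
total_sizes = input_buffer_sizes + [0] * (num_buffers - len(input_buffer_sizes))
assert total_sizes == [0, 16, 0, 0]  # mem_ids 2 and 3 start from zero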
@@ -604,6 +604,7 @@ def naive(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    input_buffer_sizes: Optional[List[int]] = None,
 ) -> List[int]:

     # allocate 'allocated' bytes from buffer with id mem_id.
@@ -615,10 +616,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         bufsizes[mem_id] += allocated
         return ret

-    bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
-    if bufsizes is None:
-        bufsizes = [0, 0]
-
+    bufsizes = input_buffer_sizes or [0, 0]
     bufsizes = typing.cast(List[int], bufsizes)
     for spec in collect_specs_from_nodes(
         graph_module.graph.nodes,
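
One subtlety of the simplified fallback: because it uses `or` rather than an explicit `None` check, an empty size list falls back to the default as well.

bufsizes = [] or [0, 0]  # -> [0, 0]; an empty list behaves like None here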
@@ -727,6 +725,8 @@ def apply_algo(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    # the sizes of buffers already allocated when recursively applied on submodules
+    input_buffer_sizes: Optional[List[int]] = None,
 ) -> List[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.
@@ -741,43 +741,46 @@ def apply_algo(
     """
     specs = update_all_tensors_lifetime(graph_module, graph_signature)
     bufsizes: List[int] = algo(
-        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        graph_module,
+        alignment,
+        graph_signature,
+        alloc_graph_input,
+        alloc_graph_output,
+        input_buffer_sizes,
     )
     insert_calls_to_free(graph_module, specs)

     def handle_submodule(
-        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
+        submodule_nd: torch.fx.Node, current_buffer_sizes, alloc_graph_input: bool = False
     ) -> None:
-        nonlocal bufsizes
         assert submodule_nd.op == "get_attr"
         submodule = getattr(graph_module, submodule_nd.target)
-        # memory planning for submodule need to be aware of the amount of
-        # buffer already allocated.
-        submodule.input_mem_buffer_sizes = bufsizes
+        submodule.input_mem_buffer_sizes = current_buffer_sizes
         bufsizes = apply_algo(
             algo,
             submodule,
             alignment,
             graph_signature,
             alloc_graph_input=alloc_graph_input,
             alloc_graph_output=True,
+            input_buffer_sizes=current_buffer_sizes,
         )
         submodule.meta.update({"non_const_buffer_sizes": bufsizes})
+        return bufsizes

     for cond_node in get_cond_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))
+        bufsizes = handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]), bufsizes)
+        bufsizes = handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]), bufsizes)

     for while_node in get_while_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
+        bufsizes = handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]), bufsizes)
+        bufsizes = handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]), bufsizes)
     # TODO: Add test coverage for map operator once dynamo tracing is
     # fully supported for this. T142287208
     for map_node in get_map_nodes(graph_module):
-        handle_submodule(
-            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
+        bufsizes = handle_submodule(
+            typing.cast(torch.fx.Node, map_node.args[0]), bufsizes, alloc_graph_input=True
         )

     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes
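
To see the bug fix end to end, a toy stand-in for the planning algorithm (hedged: `toy_plan` and the per-graph demand lists below are made up, not the real greedy/naive logic) shows a buffer that only one branch introduces surviving into the top-level result instead of being dropped:

from typing import List, Optional

def toy_plan(demand: List[int], input_buffer_sizes: Optional[List[int]] = None) -> List[int]:
    # Stand-in for greedy/naive: allocate each graph's demand on top of the
    # caller's sizes, padding with zeros for mem_ids either side lacks.
    sizes = list(input_buffer_sizes or [0, 0])
    n = max(len(demand), len(sizes))
    sizes += [0] * (n - len(sizes))
    demand = demand + [0] * (n - len(demand))
    return [s + d for s, d in zip(sizes, demand)]

bufsizes = toy_plan([0, 8])                # top-level graph
bufsizes = toy_plan([0, 4], bufsizes)      # true branch of a cond
bufsizes = toy_plan([0, 0, 16], bufsizes)  # false branch introduces mem_id 2
assert bufsizes == [0, 12, 16]             # the new buffer is no longer ignored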

exir/tests/test_memory_planning.py

Lines changed: 62 additions & 0 deletions
@@ -524,6 +524,68 @@ def test_multiple_pools(
             idx += 1
         self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes)

+    def test_multiple_pools_with_cond(self) -> None:
+        class MultiplePoolsWithCondToyModel(torch.nn.Module):
+            def forward(self, b, x):
+                def true_fn(x):
+                    return x + x
+
+                def false_fn(x):
+                    return x * x
+
+                return torch.cond(b, true_fn, false_fn, (x,))
+
+        edge_program = to_edge(
+            export(
+                MultiplePoolsWithCondToyModel(),
+                (torch.tensor([True], dtype=torch.bool), torch.ones(1)),
+            )
+        )
+
+        edge_program.to_executorch(
+            exir.ExecutorchBackendConfig(
+                memory_planning_pass=CustomPoolMemoryPlanningPass(
+                    memory_planning_algo=greedy,
+                    alignment=1,
+                ),
+            )
+        )
+        graph_module = edge_program.exported_program().graph_module
+
+        verifier = Verifier(
+            graph_module,
+            alloc_graph_input=True,
+            alloc_graph_output=True,
+        )
+        verifier.verify_storage_reuse()
+        verifier.verify_graph_input_output()
+
+        true_gm = None
+        false_gm = None
+        for node in graph_module.graph.nodes:
+            if node.target == torch.ops.higher_order.cond:
+                true_gm = getattr(graph_module, node.args[1].target)
+                false_gm = getattr(graph_module, node.args[2].target)
+
+        self.assertTrue(true_gm is not None and false_gm is not None)
+        for node in true_gm.graph.nodes:
+            if node.op == "call_function" and node.target == torch.ops.aten.add.out:
+                # true_fn calls add, for which the custom planning assigns mem_id 3
+                spec = node.meta.get("spec")
+                self.assertTrue(spec is not None)
+                self.assertEqual(spec.mem_id, 3)
+                self.assertEqual(spec.mem_offset, 0)
+
+        for node in false_gm.graph.nodes:
+            if node.op == "call_function" and node.target == torch.ops.aten.mul.out:
+                # false_fn calls mul, for which the custom planning assigns mem_id 1
+                spec = node.meta.get("spec")
+                self.assertTrue(spec is not None)
+                self.assertEqual(spec.mem_id, 1)
+                self.assertEqual(spec.mem_offset, 9)
+
+        self.assertEqual(graph_module.meta["non_const_buffer_sizes"], [0, 13, 0, 4])
+
     def test_constants_not_memory_planned(self) -> None:
         class Simple(torch.nn.Module):
             def __init__(self) -> None: