Commit 3bf1cc2

JacobSzwejbka authored and facebook-github-bot committed
allow not memory planning mutable buffers (#10071)
Summary: Pull Request resolved: #10071. Adds a config option to skip memory planning for mutable buffers. Paired with a future runtime PR, this will let users retrieve buffers by name in the runtime and set their data pointers. Differential Revision: D72749868
1 parent 060cda3 commit 3bf1cc2

File tree

5 files changed: +80, -16 lines

exir/emit/_emitter.py
exir/emit/test/test_emit.py
exir/memory_planning.py
exir/passes/memory_planning_pass.py
exir/tests/test_memory_planning.py
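Before the per-file diffs, a minimal sketch of how the two knobs touched in this commit (emit_mutable_buffer_names together with MemoryPlanningPass(alloc_mutable_buffers=False)) are meant to be combined, mirroring the updated test in exir/emit/test/test_emit.py below. The module, shapes, and exact import paths are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn
from torch.export import export

# Import paths below are assumptions based on the files touched in this commit.
from executorch.exir import ExecutorchBackendConfig, to_edge
from executorch.exir.passes import MemoryPlanningPass


class Net(nn.Module):  # hypothetical module holding a mutable buffer
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("buffer", torch.zeros(1, 2))

    def forward(self, x):
        self.buffer.add_(1)  # in-place update marks "buffer" as mutable
        return x + self.buffer


ep = to_edge(export(Net(), (torch.randn(1, 2),), strict=True))

# Leave mutable buffers out of memory planning, but emit their fully
# qualified names so a (future) runtime API can look them up and set
# their data pointers.
et = ep.to_executorch(
    config=ExecutorchBackendConfig(
        emit_mutable_buffer_names=True,
        memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
    )
)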

exir/emit/_emitter.py

Lines changed: 18 additions & 6 deletions
@@ -1640,13 +1640,25 @@ def placeholder( # noqa: C901
             else:
                 spec.extra_tensor_info.fully_qualified_name = fqn
                 spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL
-        if self.emitter_state.emit_mutable_buffer_names and is_mutable_buffer:
-            if spec.extra_tensor_info is None:
-                spec.extra_tensor_info = ExtraTensorInfo(
-                    fully_qualified_name=fqn, location=TensorDataLocation.SEGMENT
+
+        if is_mutable_buffer:
+            # Emit names if we are supposed to.
+            if self.emitter_state.emit_mutable_buffer_names:
+                if spec.extra_tensor_info is None:
+                    spec.extra_tensor_info = ExtraTensorInfo(
+                        fully_qualified_name=fqn, location=TensorDataLocation.SEGMENT
+                    )
+                else:
+                    spec.extra_tensor_info.fully_qualified_name = fqn
+            # if We aren't emitting the name then it needs to be memory planned.
+            elif spec.mem_id is None or spec.mem_offset is None:
+                raise InternalError(
+                    self._emit_node_specific_error(
+                        self.node,
+                        # [2:] to remove the b_ prefix buffers get
+                        f"Mutable buffer \"{target[2:]}\" must have a memory id and offset if we are emitting it without a name. Please either memory plan your mutable buffers or call to_executorch with config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)",
+                    )
                 )
-        else:
-            spec.extra_tensor_info.fully_qualified_name = fqn
 
         # From the fqn find the corresponding tensor
         real_tensor = None

exir/emit/test/test_emit.py

Lines changed: 31 additions & 1 deletion
@@ -1838,8 +1838,38 @@ def forward(self, x):
         ep = to_edge(ep)
         # Lower the graph to executorch.
         ep = ep.to_executorch(
-            config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)
+            config=ExecutorchBackendConfig(
+                emit_mutable_buffer_names=True,
+                memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
+            )
         )
         for val in ep.executorch_program.execution_plan[0].values:
             if isinstance(val, Tensor) and val.extra_tensor_info:
                 self.assertEqual(val.extra_tensor_info.fully_qualified_name, "buffer")
+                self.assertEqual(val.allocation_info, None)
+
+    def test_emit_mutable_buffer_names_fails(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+                self.register_buffer("buffer", torch.zeros(1, 2))
+
+            def forward(self, x):
+                self.buffer.add_(1)
+                return self.linear(x) + self.buffer
+
+        net = Net()
+
+        ep = export(net, (torch.randn(1, 2),), strict=True)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        # Must emit mutable buffer names if we don't allocate mutable buffers
+        with self.assertRaises(InternalError):
+            ep.to_executorch(
+                config=ExecutorchBackendConfig(
+                    emit_mutable_buffer_names=False,
+                    memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
+                )
+            )

exir/memory_planning.py

Lines changed: 13 additions & 2 deletions
@@ -44,12 +44,14 @@ def __init__(
         graph_module: torch.fx.GraphModule,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffers: bool,
         graph_signature: Optional[ExportGraphSignature] = None,
     ) -> None:
         self.graph_module = graph_module
         self.graph_signature = graph_signature
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
+        self.alloc_mutable_buffers = alloc_mutable_buffers
 
     @classmethod
     def mem_obj_id_match(
@@ -149,6 +151,7 @@ def verify_storage_reuse(
             ignore_const=True,
             ignore_graph_input=not self.alloc_graph_input,
             ignore_graph_output=not self.alloc_graph_output,
+            ignore_mutable_buffers=not self.alloc_mutable_buffers,
             do_assertion=False,
             ignore_out_var_node=False,
             dedup=True,
@@ -374,6 +377,7 @@ def collect_specs_from_nodes( # noqa: C901
     graph_signature: Optional[ExportGraphSignature] = None,
     ignore_graph_input: bool = False,
     ignore_graph_output: bool = False,
+    ignore_mutable_buffers: bool = False,
     ignore_const: bool = True,
     ignore_out_var_node: bool = True,
     dedup: bool = True,
@@ -414,6 +418,12 @@ def collect_specs_from_nodes( # noqa: C901
         if _is_inplace_node(node):
             continue
 
+        if (
+            _is_mutable_buffer(node, graph_signature)
+            and ignore_mutable_buffers
+        ):
+            continue
+
         if do_assertion:
             internal_assert(
                 node.op in ("placeholder", "output")
@@ -469,6 +479,7 @@ def update_all_tensors_lifetime(
     Set the lifetime for all the tensors encountered in the Fx graph.
     """
     specs = set()
+
     for node_idx, node in enumerate(graph_module.graph.nodes):
         for spec in collect_specs_from_nodes(
             filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
@@ -1053,6 +1064,7 @@ def apply_algo(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    alloc_mutable_buffers: bool = True,
 ) -> List[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.
@@ -1065,19 +1077,18 @@ def apply_algo(
     storage with tensors in the outer module.
     TODO: make these optimizations once we have some baseline working.
     """
-
     # Extract the nodes and their lifespans from the graph_module
     # Difficult to just filter the list of specs returned by this due to
     # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
-
     # Filter specs based on alloc_graph_input and alloc_graph_output
     specs = collect_specs_from_nodes(
         graph_module.graph.nodes,
         graph_signature,
         do_assertion=False,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
+        ignore_mutable_buffers=not alloc_mutable_buffers,
     )
 
     # Get extra padding for XNNPACK if needed
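For context, a hedged sketch of how the new Verifier argument can be exercised directly, in the spirit of the updated tests in exir/tests/test_memory_planning.py further down. The wrapper function name and the import paths are assumptions, not part of this commit:

from typing import Optional

import torch
from torch.export import ExportGraphSignature

# Module path is an assumption based on the file shown above.
from executorch.exir.memory_planning import Verifier


def check_planned_program(
    graph_module: torch.fx.GraphModule,
    graph_signature: Optional[ExportGraphSignature] = None,
) -> None:
    # Hypothetical helper: run the planner's verifier while leaving mutable
    # buffers unplanned, so they are skipped by the lifetime/storage-reuse
    # checks (ignore_mutable_buffers=True internally).
    verifier = Verifier(
        graph_module,
        alloc_graph_input=True,
        alloc_graph_output=True,
        alloc_mutable_buffers=False,
        graph_signature=graph_signature,
    )
    verifier.verify_storage_reuse()
    verifier.verify_graph_input_output()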

exir/passes/memory_planning_pass.py

Lines changed: 7 additions & 2 deletions
@@ -44,6 +44,7 @@ def __init__(
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
+        alloc_mutable_buffers: bool = True,
         alignment: int = ALIGNMENT,
     ) -> None:
         r"""
@@ -54,10 +55,11 @@ def __init__(
         """
         if memory_planning_algo is None:
             memory_planning_algo = MemoryPlanningAlgorithmSuite()
-        self.memory_planning_algo = memory_planning_algo
+        self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
         self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
+        self.alloc_mutable_buffers = alloc_mutable_buffers
         self.alignment = alignment
 
     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
@@ -124,13 +126,15 @@ def run(
         # customized fields. Using the graph_module object to convey information across
         # passes/stages is quite natural and avoid yet another 'context' data structure
         # to do the job.
+
         _ = apply_algo(
-            self.memory_planning_algo,  # pyre-ignore[6]
+            self.memory_planning_algo,
            graph_module,
            self.alignment,
            graph_signature,
            self.alloc_graph_input,
            self.alloc_graph_output,
+           self.alloc_mutable_buffers
        )
 
        # TODO: make the verifier do the work recursively to handle
@@ -139,6 +143,7 @@ def run(
            graph_module,
            self.alloc_graph_input,
            self.alloc_graph_output,
+           self.alloc_mutable_buffers,
            graph_signature,
        )
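Since the new keyword defaults to True, existing MemoryPlanningPass callers keep the old behavior. A small sketch of the expanded constructor surface, with every value shown being the default from the diff above (import path assumed); the flag is then forwarded to apply_algo and the Verifier:

from executorch.exir.passes import MemoryPlanningPass

# All keywords shown are the defaults; only alloc_mutable_buffers is new
# in this commit. Setting it to False skips planning for mutable buffers.
planner = MemoryPlanningPass(
    allow_lifetime_and_storage_overlap=False,
    alloc_graph_input=True,
    alloc_graph_output=True,
    alloc_mutable_buffers=True,
)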

exir/tests/test_memory_planning.py

Lines changed: 11 additions & 5 deletions
@@ -241,6 +241,7 @@ def maketest(
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    alloc_mutable_buffer: bool = True,
     has_unused_graph_input: bool = False,
 ) -> Callable[..., None]:
     # parameterized.expand is not compatible with maketest. I'll just loop thru
@@ -282,10 +283,10 @@ def wrapper(self: "TestMemoryPlanning") -> None:
         )(graph_module).graph_module
 
         self.verify_reuse(
-            graph_module, expect_reuse, alloc_graph_input, alloc_graph_output
+            graph_module, expect_reuse, alloc_graph_input, alloc_graph_output, alloc_mutable_buffer
         )
         self.verify_graph_input_output(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, alloc_graph_input, alloc_graph_output, alloc_mutable_buffer
         )
 
         self.verify_overlap_placeholders(has_unused_graph_input, graph_module)
@@ -306,6 +307,7 @@ def verify_reuse(
         expect_reuse: bool,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffer: bool,
     ) -> None:
         r"""
         Do sanity check and verify tensor storage reuse.
@@ -321,6 +323,7 @@ def verify_reuse(
             graph_module,
             alloc_graph_input=alloc_graph_input,
             alloc_graph_output=alloc_graph_output,
+            alloc_mutable_buffers=alloc_mutable_buffer,
         ).verify_storage_reuse()
 
         print(f"num_reuse_pairs is {num_reuse_pairs}")
@@ -334,9 +337,10 @@ def verify_graph_input_output(
         graph_module: torch.fx.GraphModule,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffers: bool,
     ) -> None:
         Verifier(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, alloc_graph_input, alloc_graph_output, alloc_mutable_buffers
         ).verify_graph_input_output()
 
     def verify_overlap_placeholders(
@@ -404,13 +408,14 @@ def verify_overlap_placeholders(
         )
 
     def test_graph_input_output(self) -> None:
-        for alloc_graph_input, alloc_graph_output in itertools.product(
-            [True, False], [True, False]
+        for alloc_graph_input, alloc_graph_output, alloc_mutable_buffers in itertools.product(
+            [True, False], [True, False], [True, False]
         ):
             case = maketest(
                 ModelWithDifferentTensorSizes,
                 alloc_graph_input=alloc_graph_input,
                 alloc_graph_output=alloc_graph_output,
+                alloc_mutable_buffer=alloc_mutable_buffers,
             )
             case(self)
 
@@ -535,6 +540,7 @@ def test_multiple_pools(
             graph_module,
             alloc_graph_input=True,
             alloc_graph_output=True,
+            alloc_mutable_buffers=True,
         )
         verifier.verify_storage_reuse()
         verifier.verify_graph_input_output()
