Commit 14b7ba2

JacobSzwejbka authored and keyprocedure committed
allow not memory planning mutable buffers
Differential Revision: D72749868
Pull Request resolved: pytorch#10071
1 parent 02234e8 commit 14b7ba2

File tree: 5 files changed (+90, -17 lines)

exir/emit/_emitter.py
exir/emit/test/test_emit.py
exir/memory_planning.py
exir/passes/memory_planning_pass.py
exir/tests/test_memory_planning.py

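This commit adds an alloc_mutable_buffers option to MemoryPlanningPass so mutable buffers can be left out of memory planning, and updates the emitter to require that any unplanned mutable buffer has its fully qualified name emitted (otherwise emission raises an InternalError). A rough end-to-end sketch of the flow this enables, adapted from the updated test in exir/emit/test/test_emit.py below; the import paths are assumed and are not part of this diff:

    import torch
    import torch.nn as nn
    from executorch.exir import ExecutorchBackendConfig, to_edge
    from executorch.exir.passes import MemoryPlanningPass
    from torch.export import export

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2)
            self.register_buffer("buffer", torch.zeros(1, 2))

        def forward(self, x):
            self.buffer.add_(1)  # in-place mutation of the registered buffer
            return self.linear(x) + self.buffer

    ep = export(Net(), (torch.randn(1, 2),), strict=True)
    # Leave the mutable buffer unplanned, but emit its name so it can be identified later.
    et = to_edge(ep).to_executorch(
        config=ExecutorchBackendConfig(
            emit_mutable_buffer_names=True,
            memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
        )
    )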

exir/emit/_emitter.py

Lines changed: 19 additions & 6 deletions
@@ -1640,13 +1640,26 @@ def placeholder(  # noqa: C901
             else:
                 spec.extra_tensor_info.fully_qualified_name = fqn
                 spec.extra_tensor_info.location = TensorDataLocation.EXTERNAL
-        if self.emitter_state.emit_mutable_buffer_names and is_mutable_buffer:
-            if spec.extra_tensor_info is None:
-                spec.extra_tensor_info = ExtraTensorInfo(
-                    fully_qualified_name=fqn, location=TensorDataLocation.SEGMENT
+
+        if is_mutable_buffer:
+            # Emit names if we are supposed to.
+            if self.emitter_state.emit_mutable_buffer_names:
+                if spec.extra_tensor_info is None:
+                    spec.extra_tensor_info = ExtraTensorInfo(
+                        fully_qualified_name=fqn,
+                        location=TensorDataLocation.SEGMENT,
+                    )
+                else:
+                    spec.extra_tensor_info.fully_qualified_name = fqn
+            # if We aren't emitting the name then it needs to be memory planned.
+            elif spec.mem_id is None or spec.mem_offset is None:
+                raise InternalError(
+                    self._emit_node_specific_error(
+                        self.node,
+                        # [2:] to remove the b_ prefix buffers get
+                        f'Mutable buffer "{target[2:]}" must have a memory id and offset if we are emitting it without a name. Please either memory plan your mutable buffers or call to_executorch with config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)',
+                    )
                 )
-        else:
-            spec.extra_tensor_info.fully_qualified_name = fqn

         # From the fqn find the corresponding tensor
         real_tensor = None
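The net effect of this hunk: a mutable buffer is accepted only if it was memory planned or its name is emitted. Restated as a standalone sketch (a hypothetical helper, not code from the emitter; the spec fields mirror the ones used above):

    def mutable_buffer_is_emittable(spec, emit_mutable_buffer_names: bool) -> bool:
        # With emitted names, an unplanned buffer can still be identified at runtime.
        if emit_mutable_buffer_names:
            return True
        # Otherwise the memory planner must have assigned an id and offset.
        return spec.mem_id is not None and spec.mem_offset is not None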

exir/emit/test/test_emit.py

Lines changed: 33 additions & 1 deletion
@@ -1838,8 +1838,40 @@ def forward(self, x):
         ep = to_edge(ep)
         # Lower the graph to executorch.
         ep = ep.to_executorch(
-            config=ExecutorchBackendConfig(emit_mutable_buffer_names=True)
+            config=ExecutorchBackendConfig(
+                emit_mutable_buffer_names=True,
+                memory_planning_pass=MemoryPlanningPass(alloc_mutable_buffers=False),
+            )
         )
         for val in ep.executorch_program.execution_plan[0].values:
             if isinstance(val, Tensor) and val.extra_tensor_info:
                 self.assertEqual(val.extra_tensor_info.fully_qualified_name, "buffer")
+                self.assertEqual(val.allocation_info, None)
+
+    def test_emit_mutable_buffer_names_fails(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(2, 2)
+                self.register_buffer("buffer", torch.zeros(1, 2))
+
+            def forward(self, x):
+                self.buffer.add_(1)
+                return self.linear(x) + self.buffer
+
+        net = Net()
+
+        ep = export(net, (torch.randn(1, 2),), strict=True)
+        # Lower the graph to edge dialect.
+        ep = to_edge(ep)
+        # Lower the graph to executorch.
+        # Must emit mutable buffer names if we don't allocate mutable buffers
+        with self.assertRaises(InternalError):
+            ep.to_executorch(
+                config=ExecutorchBackendConfig(
+                    emit_mutable_buffer_names=False,
+                    memory_planning_pass=MemoryPlanningPass(
+                        alloc_mutable_buffers=False
+                    ),
+                )
+            )
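The same checks work outside of unittest; a small sketch, assuming et is a program lowered as in the sketch near the top and Tensor is the schema type the test above imports:

    from executorch.exir.schema import Tensor  # assumed import path

    for val in et.executorch_program.execution_plan[0].values:
        if isinstance(val, Tensor) and val.extra_tensor_info:
            # Unplanned mutable buffers keep their name and carry no allocation_info.
            print(val.extra_tensor_info.fully_qualified_name, val.allocation_info)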

exir/memory_planning.py

Lines changed: 10 additions & 2 deletions
@@ -44,12 +44,14 @@ def __init__(
         graph_module: torch.fx.GraphModule,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffers: bool,
         graph_signature: Optional[ExportGraphSignature] = None,
     ) -> None:
         self.graph_module = graph_module
         self.graph_signature = graph_signature
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
+        self.alloc_mutable_buffers = alloc_mutable_buffers

     @classmethod
     def mem_obj_id_match(

@@ -149,6 +151,7 @@ def verify_storage_reuse(
             ignore_const=True,
             ignore_graph_input=not self.alloc_graph_input,
             ignore_graph_output=not self.alloc_graph_output,
+            ignore_mutable_buffers=not self.alloc_mutable_buffers,
             do_assertion=False,
             ignore_out_var_node=False,
             dedup=True,

@@ -374,6 +377,7 @@ def collect_specs_from_nodes(  # noqa: C901
     graph_signature: Optional[ExportGraphSignature] = None,
     ignore_graph_input: bool = False,
     ignore_graph_output: bool = False,
+    ignore_mutable_buffers: bool = False,
     ignore_const: bool = True,
     ignore_out_var_node: bool = True,
     dedup: bool = True,

@@ -414,6 +418,9 @@ def collect_specs_from_nodes(  # noqa: C901
         if _is_inplace_node(node):
             continue

+        if _is_mutable_buffer(node, graph_signature) and ignore_mutable_buffers:
+            continue
+
         if do_assertion:
             internal_assert(
                 node.op in ("placeholder", "output")

@@ -469,6 +476,7 @@ def update_all_tensors_lifetime(
     Set the lifetime for all the tensors encountered in the Fx graph.
     """
     specs = set()
+
     for node_idx, node in enumerate(graph_module.graph.nodes):
         for spec in collect_specs_from_nodes(
             filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),

@@ -1053,6 +1061,7 @@ def apply_algo(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    alloc_mutable_buffers: bool = True,
 ) -> List[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.

@@ -1065,19 +1074,18 @@ def apply_algo(
     storage with tensors in the outer module.
     TODO: make these optimizations once we have some baseline working.
     """
-
     # Extract the nodes and their lifespans from the graph_module
     # Difficult to just filter the list of specs returned by this due to
     # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
-
     # Filter specs based on alloc_graph_input and alloc_graph_output
     specs = collect_specs_from_nodes(
         graph_module.graph.nodes,
         graph_signature,
         do_assertion=False,
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
+        ignore_mutable_buffers=not alloc_mutable_buffers,
     )

     # Get extra padding for XNNPACK if needed
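Code that constructs the Verifier directly now has to pass the extra flag; a minimal sketch, assuming a planned graph_module and its export graph_signature are already in hand (it mirrors the updated call in exir/tests/test_memory_planning.py):

    verifier = Verifier(
        graph_module,
        alloc_graph_input=True,
        alloc_graph_output=True,
        alloc_mutable_buffers=False,  # keep in sync with the MemoryPlanningPass setting
        graph_signature=graph_signature,
    )
    verifier.verify_storage_reuse()
    verifier.verify_graph_input_output()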

exir/passes/memory_planning_pass.py

Lines changed: 7 additions & 2 deletions
@@ -44,6 +44,7 @@ def __init__(
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
+        alloc_mutable_buffers: bool = True,
         alignment: int = ALIGNMENT,
     ) -> None:
         r"""

@@ -54,10 +55,11 @@ def __init__(
         """
         if memory_planning_algo is None:
             memory_planning_algo = MemoryPlanningAlgorithmSuite()
-        self.memory_planning_algo = memory_planning_algo
+        self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo
         self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap
         self.alloc_graph_input = alloc_graph_input
         self.alloc_graph_output = alloc_graph_output
+        self.alloc_mutable_buffers = alloc_mutable_buffers
         self.alignment = alignment

     def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:

@@ -124,13 +126,15 @@ def run(
         # customized fields. Using the graph_module object to convey information across
         # passes/stages is quite natural and avoid yet another 'context' data structure
         # to do the job.
+
         _ = apply_algo(
-            self.memory_planning_algo,  # pyre-ignore[6]
+            self.memory_planning_algo,
            graph_module,
            self.alignment,
            graph_signature,
            self.alloc_graph_input,
            self.alloc_graph_output,
+           self.alloc_mutable_buffers,
        )

        # TODO: make the verifier do the work recursively to handle

@@ -139,6 +143,7 @@ def run(
            graph_module,
            self.alloc_graph_input,
            self.alloc_graph_output,
+           self.alloc_mutable_buffers,
            graph_signature,
        )
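The new flag defaults to True, so existing callers keep planning mutable buffers; only when it is set to False does the option flow from the pass constructor into apply_algo and the Verifier. A minimal sketch of constructing the pass with the option turned off (defaults shown explicitly; import path assumed):

    from executorch.exir.passes import MemoryPlanningPass  # assumed import path

    # Plan graph inputs and outputs as before, but leave mutable buffers unplanned.
    mem_pass = MemoryPlanningPass(
        alloc_graph_input=True,
        alloc_graph_output=True,
        alloc_mutable_buffers=False,
    )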

exir/tests/test_memory_planning.py

Lines changed: 21 additions & 6 deletions
@@ -241,6 +241,7 @@ def maketest(
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    alloc_mutable_buffer: bool = True,
     has_unused_graph_input: bool = False,
 ) -> Callable[..., None]:
     # parameterized.expand is not compatible with maketest. I'll just loop thru

@@ -282,10 +283,17 @@ def wrapper(self: "TestMemoryPlanning") -> None:
         )(graph_module).graph_module

         self.verify_reuse(
-            graph_module, expect_reuse, alloc_graph_input, alloc_graph_output
+            graph_module,
+            expect_reuse,
+            alloc_graph_input,
+            alloc_graph_output,
+            alloc_mutable_buffer,
         )
         self.verify_graph_input_output(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module,
+            alloc_graph_input,
+            alloc_graph_output,
+            alloc_mutable_buffer,
         )

         self.verify_overlap_placeholders(has_unused_graph_input, graph_module)

@@ -306,6 +314,7 @@ def verify_reuse(
         expect_reuse: bool,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffer: bool,
     ) -> None:
         r"""
         Do sanity check and verify tensor storage reuse.

@@ -321,6 +330,7 @@ def verify_reuse(
             graph_module,
             alloc_graph_input=alloc_graph_input,
             alloc_graph_output=alloc_graph_output,
+            alloc_mutable_buffers=alloc_mutable_buffer,
         ).verify_storage_reuse()

         print(f"num_reuse_pairs is {num_reuse_pairs}")

@@ -334,9 +344,10 @@ def verify_graph_input_output(
         graph_module: torch.fx.GraphModule,
         alloc_graph_input: bool,
         alloc_graph_output: bool,
+        alloc_mutable_buffers: bool,
     ) -> None:
         Verifier(
-            graph_module, alloc_graph_input, alloc_graph_output
+            graph_module, alloc_graph_input, alloc_graph_output, alloc_mutable_buffers
         ).verify_graph_input_output()

     def verify_overlap_placeholders(

@@ -404,13 +415,16 @@ def verify_overlap_placeholders(
         )

     def test_graph_input_output(self) -> None:
-        for alloc_graph_input, alloc_graph_output in itertools.product(
-            [True, False], [True, False]
-        ):
+        for (
+            alloc_graph_input,
+            alloc_graph_output,
+            alloc_mutable_buffers,
+        ) in itertools.product([True, False], [True, False], [True, False]):
             case = maketest(
                 ModelWithDifferentTensorSizes,
                 alloc_graph_input=alloc_graph_input,
                 alloc_graph_output=alloc_graph_output,
+                alloc_mutable_buffer=alloc_mutable_buffers,
             )
             case(self)

@@ -535,6 +549,7 @@ def test_multiple_pools(
             graph_module,
             alloc_graph_input=True,
             alloc_graph_output=True,
+            alloc_mutable_buffers=True,
         )
         verifier.verify_storage_reuse()
         verifier.verify_graph_input_output()
