@@ -301,7 +301,11 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
301
301
302
302
303
303
def update_tensor_lifetime (
304
- node : torch .fx .Node , spec : TensorSpec , node_idx : int
304
+ node : torch .fx .Node ,
305
+ spec : TensorSpec ,
306
+ node_idx : int ,
307
+ max_node_idx : int ,
308
+ gs : Optional [ExportGraphSignature ] = None ,
305
309
) -> None :
306
310
r"""
307
311
Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
@@ -317,7 +321,12 @@ def update_tensor_lifetime(
317
321
start = 0
318
322
else :
319
323
start = node_idx if start is None or start > node_idx else start
320
- end = node_idx if end is None or end < node_idx else end
324
+
325
+ if node .op == "placeholder" and _is_mutable_buffer (node , gs ):
326
+ # mutable buffers are never freed
327
+ end = max_node_idx
328
+ else :
329
+ end = node_idx if end is None or end < node_idx else end
321
330
spec .lifetime = [start , end ]
322
331
323
332
@@ -497,7 +506,7 @@ def update_all_tensors_lifetime(
497
506
Set the lifetime for all the tensors encountered in the Fx graph.
498
507
"""
499
508
specs = set ()
500
-
509
+ max_node_idx = len ( graph_module . graph . nodes ) - 1
501
510
for node_idx , node in enumerate (graph_module .graph .nodes ):
502
511
for spec in collect_specs_from_nodes (
503
512
filter_nodes (itertools .chain ([node ], node .args , node .kwargs .values ())),
@@ -509,7 +518,7 @@ def update_all_tensors_lifetime(
509
518
do_assertion = False ,
510
519
ignore_dynamic_unbound_tensor = False ,
511
520
):
512
- update_tensor_lifetime (node , spec , node_idx )
521
+ update_tensor_lifetime (node , spec , node_idx , max_node_idx , graph_signature )
513
522
specs .add (spec )
514
523
return specs
515
524
0 commit comments