@@ -44,12 +44,14 @@ def __init__(
44
44
graph_module : torch .fx .GraphModule ,
45
45
alloc_graph_input : bool ,
46
46
alloc_graph_output : bool ,
47
+ alloc_mutable_buffers : bool ,
47
48
graph_signature : Optional [ExportGraphSignature ] = None ,
48
49
) -> None :
49
50
self .graph_module = graph_module
50
51
self .graph_signature = graph_signature
51
52
self .alloc_graph_input = alloc_graph_input
52
53
self .alloc_graph_output = alloc_graph_output
54
+ self .alloc_mutable_buffers = alloc_mutable_buffers
53
55
54
56
@classmethod
55
57
def mem_obj_id_match (
@@ -149,6 +151,7 @@ def verify_storage_reuse(
149
151
ignore_const = True ,
150
152
ignore_graph_input = not self .alloc_graph_input ,
151
153
ignore_graph_output = not self .alloc_graph_output ,
154
+ ignore_mutable_buffers = not self .alloc_mutable_buffers ,
152
155
do_assertion = False ,
153
156
ignore_out_var_node = False ,
154
157
dedup = True ,
@@ -374,6 +377,7 @@ def collect_specs_from_nodes( # noqa: C901
374
377
graph_signature : Optional [ExportGraphSignature ] = None ,
375
378
ignore_graph_input : bool = False ,
376
379
ignore_graph_output : bool = False ,
380
+ ignore_mutable_buffers : bool = False ,
377
381
ignore_const : bool = True ,
378
382
ignore_out_var_node : bool = True ,
379
383
dedup : bool = True ,
@@ -414,6 +418,9 @@ def collect_specs_from_nodes( # noqa: C901
414
418
if _is_inplace_node (node ):
415
419
continue
416
420
421
+ if _is_mutable_buffer (node , graph_signature ) and ignore_mutable_buffers :
422
+ continue
423
+
417
424
if do_assertion :
418
425
internal_assert (
419
426
node .op in ("placeholder" , "output" )
@@ -469,6 +476,7 @@ def update_all_tensors_lifetime(
469
476
Set the lifetime for all the tensors encountered in the Fx graph.
470
477
"""
471
478
specs = set ()
479
+
472
480
for node_idx , node in enumerate (graph_module .graph .nodes ):
473
481
for spec in collect_specs_from_nodes (
474
482
filter_nodes (itertools .chain ([node ], node .args , node .kwargs .values ())),
@@ -1053,6 +1061,7 @@ def apply_algo(
1053
1061
graph_signature : Optional [ExportGraphSignature ] = None ,
1054
1062
alloc_graph_input : bool = True ,
1055
1063
alloc_graph_output : bool = True ,
1064
+ alloc_mutable_buffers : bool = True ,
1056
1065
) -> List [int ]:
1057
1066
"""
1058
1067
Recursively apply algo to graph_module and its submodules for control flow.
@@ -1065,19 +1074,18 @@ def apply_algo(
1065
1074
storage with tensors in the outer module.
1066
1075
TODO: make these optimizations once we have some baseline working.
1067
1076
"""
1068
-
1069
1077
# Extract the nodes and their lifespans from the graph_module
1070
1078
# Difficult to just filter the list of specs returned by this due to
1071
1079
# how we flag trainable weights.
1072
1080
_ = update_all_tensors_lifetime (graph_module , graph_signature )
1073
-
1074
1081
# Filter specs based on alloc_graph_input and alloc_graph_output
1075
1082
specs = collect_specs_from_nodes (
1076
1083
graph_module .graph .nodes ,
1077
1084
graph_signature ,
1078
1085
do_assertion = False ,
1079
1086
ignore_graph_input = not alloc_graph_input ,
1080
1087
ignore_graph_output = not alloc_graph_output ,
1088
+ ignore_mutable_buffers = not alloc_mutable_buffers ,
1081
1089
)
1082
1090
1083
1091
# Get extra padding for XNNPACK if needed
0 commit comments