This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit c53e45c

Added autominifier for FX graphs (#330)
* Added autominifier
* fixed some stuff
1 parent d3477c1 commit c53e45c

File tree

functorch/_src/fx_minifier.py
functorch/compile/__init__.py

2 files changed: +185 −0 lines changed


functorch/_src/fx_minifier.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
import torch.fx as fx
import copy
import torch
import math

class ConcreteProp(torch.fx.Interpreter):
    """Runs the module and stashes each node's concrete output (when it contains a tensor) in node.meta['concrete_value']."""

    def run_node(self, n):
        result = super().run_node(n)

        found_tensor = False

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return obj
            else:
                return obj

        from torch.fx.node import map_aggregate
        concrete_value = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['concrete_value'] = concrete_value
        return result

    def propagate(self, *args):
        return super().run(*args)

def _get_placeholders(graph):
    return list(filter(lambda x: x.op == 'placeholder', graph.nodes))

# Modifies node and inps in place.
def _convert_node_to_placeholder(node, inps):
    node.op = 'placeholder'
    node.args = ()
    node.target = node.name
    concrete_val = node.meta['concrete_value']
    if isinstance(concrete_val, torch.Tensor):
        inps.append(concrete_val)
    else:
        inps.append(torch.zeros(()))
    for tuple_user in list(node.users):
        _convert_node_to_placeholder(tuple_user, inps)

def minimizer(fail_f: fx.GraphModule, inps, pass_checker):
    """
    Minimizes an FX graph with the given inputs, such that the resulting FX graph still fails pass_checker.

    Uses two main strategies:
    1. Truncate suffix: removes some suffix from the graph and sets a new output.
    2. Delta debugging: tries replacing half of the graph with inputs; if that fails, tries a quarter of the graph, and so on.

    >>> failing_function = fx.symbolic_trace(f)
    >>> minimizer(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))
    """
    failing_graph = fail_f.graph
    cur_size = len(failing_graph.nodes)

    def graph_passes(graph, inps):
        graph.lint()
        mod = fx.GraphModule(fail_f, graph)
        return pass_checker(mod, inps)

    ConcreteProp(fail_f).propagate(*inps)
    if graph_passes(failing_graph, inps):
        raise RuntimeError("Input graph did not fail the tester")
    print(f"Started off with {cur_size} nodes")

    def remove_suffix(cur_graph, cur_inps):
        print("Strategy: Remove suffix")
        assert not graph_passes(cur_graph, cur_inps)
        gap = 2**math.floor(math.log2(len(cur_graph.nodes)))
        tested = set()
        while gap >= 1:
            print(f"search gap: {gap}: ", end='')
            new_graph = fx.Graph()
            env = {}
            for idx, node in enumerate(cur_graph.nodes):
                new_node = new_graph.node_copy(node, lambda x: env[x])
                if node.op not in ['placeholder', 'output']:
                    if idx % gap == 0 and idx not in tested:
                        print(f"{idx}", end=',')
                        output_node = new_graph.output(([new_node],))
                        if not graph_passes(new_graph, cur_inps) and len(new_graph.nodes) < len(cur_graph.nodes):
                            print()
                            print(f"SUCCESS: Found failing case with first {idx} nodes")
                            return (new_graph, cur_inps), True
                        else:
                            tested.add(idx)
                            new_graph.erase_node(output_node)
                env[node] = new_node
            gap //= 2
        print()
        print("FAIL: Could not remove suffix")
        return (cur_graph, cur_inps), False

    def remove_unused_inputs(cur_graph, cur_inps):
        print("Strategy: Remove unused inputs")
        assert not graph_passes(cur_graph, cur_inps)
        ph_nodes = _get_placeholders(cur_graph)
        if len(ph_nodes) != len(cur_inps):
            print(cur_graph)
            print(len(cur_inps))
        assert len(ph_nodes) == len(cur_inps)

        new_inps = []
        for idx in range(len(ph_nodes)):
            if len(ph_nodes[idx].users) == 0:
                cur_graph.erase_node(ph_nodes[idx])
            else:
                new_inps.append(cur_inps[idx])

        if len(new_inps) < len(cur_inps):
            print(f"SUCCESS: Went from {len(cur_inps)} inputs to {len(new_inps)} inputs")
            return (cur_graph, new_inps), True
        else:
            print("FAIL: Could not remove inputs")
            return (cur_graph, new_inps), False

    def consolidate_placeholders(cur_graph):
        new_graph = fx.Graph()
        env = {}
        for node in cur_graph.nodes:
            if node.op == 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node

        for node in cur_graph.nodes:
            if node.op != 'placeholder':
                new_node = new_graph.node_copy(node, lambda x: env[x])
                env[node] = new_node
        return new_graph

    def delta_debugging(cur_graph: fx.Graph, cur_inps):
        print("Strategy: Delta Debugging")
        assert not graph_passes(cur_graph, cur_inps)
        starting_placeholders = len(_get_placeholders(cur_graph))
        num_nodes = len(cur_graph.nodes)
        gap = int(2**math.floor(math.log2(num_nodes)))
        while gap >= 1:
            print(f"Searching with gap of {gap}")
            for start_range in range(0, num_nodes, gap):
                is_removing = False
                new_graph = copy.deepcopy(cur_graph)
                new_inps = cur_inps[:]
                for idx in range(start_range, min(num_nodes, start_range + gap)):
                    new_node = list(new_graph.nodes)[idx]
                    if new_node.op not in ['placeholder', 'output']:
                        is_removing = True
                        _convert_node_to_placeholder(new_node, new_inps)
                if not is_removing:
                    continue
                new_graph = consolidate_placeholders(new_graph)
                if not graph_passes(new_graph, new_inps):
                    print(f"SUCCESS: Went from {starting_placeholders} placeholders to {len(_get_placeholders(new_graph))}")
                    return (new_graph, new_inps), True
            gap //= 2

        print("FAIL: Could not remove prefix")
        return (cur_graph, cur_inps), False

    while True:
        any_succeeded = False
        for strategy in [remove_suffix, remove_unused_inputs, delta_debugging, remove_unused_inputs]:
            print("###################")
            print(f"Current size: {len(failing_graph.nodes)}")
            print("###################")
            out = strategy(copy.deepcopy(failing_graph), inps[:])
            (cur_graph, cur_inps), succeeded = out
            if succeeded:
                failing_graph = cur_graph
                failing_graph.eliminate_dead_code()
                inps = cur_inps
                any_succeeded = True
            print()

        if not any_succeeded:
            break
    failing_fx = fx.GraphModule(fail_f, failing_graph)
    print(failing_fx.code)
    print([i.shape for i in inps])
    return failing_fx, inps
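
For context, and not part of the committed file: a minimal usage sketch of the new minifier. The function f, the tanh-based "failure", and the pass_checker below are invented for illustration; the only assumption is the one the code above encodes, namely that pass_checker returns True when a candidate graph passes and False while it still reproduces the failure, so the minifier keeps shrinking the graph as long as the checker returns False.

import torch
import torch.fx as fx
from functorch.compile import minimizer  # exported by the __init__.py change below

def f(x):
    y = x * 2
    z = torch.tanh(y)   # stand-in for whatever op breaks the backend under test
    return (z + 1).sum()

# True = graph "passes" (no tanh left); False = graph still fails.
def pass_checker(fx_g, inps):
    return not any(node.target == torch.tanh for node in fx_g.graph.nodes)

minified_fx, minified_inps = minimizer(fx.symbolic_trace(f), [torch.randn(5)], pass_checker)
# minified_fx.code should be reduced to roughly just the tanh call,
# and minified_inps to the inputs that smaller graph still needs.

Because graph_passes rebuilds an fx.GraphModule for every candidate, the checker can inspect the graph structurally (as here) or actually compile and run it.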

functorch/compile/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from .._src.python_key import nnc_jit, make_nnc, pythonkey_decompose
 from .._src.decompositions import register_decomposition, decomposition_table
 from .._src.nnc_compile import nnc_compile, get_ops
+from .._src.fx_minifier import minimizer
 from .._src.aot_autograd import (
     aot_function,
     aot_module,
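
A quick sanity check of the re-export, as a sketch assuming an environment built from this revision: the public name now resolves to the private module added above.

from functorch.compile import minimizer
print(minimizer.__module__)  # -> functorch._src.fx_minifier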

0 commit comments
