Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 97822c1

Browse files
committed
added some tests and stuff to minimizer
1 parent a53be5c commit 97822c1

File tree

2 files changed

+90
-28
lines changed

2 files changed

+90
-28
lines changed

functorch/_src/fx_minifier.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,62 +42,60 @@ def _convert_node_to_placeholder(node, inps):
4242
for tuple_user in list(node.users):
4343
_convert_node_to_placeholder(tuple_user, inps)
4444

45-
def minimizer(fail_f: fx.GraphModule, inps, pass_checker):
45+
def minimizer(fail_f: fx.GraphModule, inps, module_fails):
4646
"""
47-
Minimizes a FX graph with given inputs, such that the resulting FX graph still fails the pass_checker.
47+
Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails.
4848
4949
Does 2 main strategies:
5050
1. Truncates suffix: Removes some suffix from the graph and sets a new output.
5151
2. Delta Debugging: Tries replacing half of the graph with inputs. If fails, tries replacing quarter of the graph, etc.
5252
5353
>>> failing_function = fx.symbolic_trace(f)
5454
>>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))
55+
56+
note: module_fails returns True if it fails.
5557
"""
5658
failing_graph = fail_f.graph
5759
cur_size = len(failing_graph.nodes)
5860

59-
def graph_passes(graph, inps):
60-
graph.lint()
61+
def graph_fails(graph, inps):
6162
mod = fx.GraphModule(fail_f, graph)
62-
return pass_checker(mod, inps)
63+
mod.graph.lint()
64+
return module_fails(mod, inps)
6365

6466
ConcreteProp(fail_f).propagate(*inps)
65-
if graph_passes(failing_graph, inps):
67+
if not graph_fails(failing_graph, inps):
6668
raise RuntimeError("Input graph did not fail the tester")
6769
print(f"Started off with {cur_size} nodes")
6870

6971
def remove_suffix(cur_graph, cur_inps):
7072
print("Strategy: Remove suffix")
71-
assert not graph_passes(cur_graph, cur_inps)
73+
assert graph_fails(cur_graph, cur_inps)
7274
gap = 2**math.floor(math.log2(len(cur_graph.nodes)))
7375
tested = set()
7476
while gap >= 1:
75-
print(f"search gap: {gap}: ", end='')
7677
new_graph = fx.Graph()
7778
env = {}
7879
for idx, node in enumerate(cur_graph.nodes):
7980
new_node = new_graph.node_copy(node, lambda x: env[x])
8081
if node.op not in ['placeholder', 'output']:
8182
if idx % gap == 0 and idx not in tested:
82-
print(f"{idx}", end=',')
83-
output_node = new_graph.output(([new_node],))
84-
if not graph_passes(new_graph, cur_inps) and len(new_graph.nodes) < len(cur_graph.nodes):
83+
output_node = new_graph.output((new_node,))
84+
if graph_fails(new_graph, cur_inps) and len(new_graph.nodes) < len(cur_graph.nodes):
8585
print()
86-
print(f"SUCCESS: Found failing case with first {idx} nodes")
86+
print(f"SUCCESS: Removed [{idx}:{len(cur_graph.nodes)})")
8787
return (new_graph, cur_inps), True
8888
else:
8989
tested.add(idx)
9090
new_graph.erase_node(output_node)
9191
env[node] = new_node
9292
gap //= 2
93-
print()
9493
print("FAIL: Could not remove suffix")
9594
return (cur_graph, cur_inps), False
9695

9796

9897
def remove_unused_inputs(cur_graph, cur_inps):
99-
print("Strategy: Remove unused inputs")
100-
assert not graph_passes(cur_graph, cur_inps)
98+
assert graph_fails(cur_graph, cur_inps)
10199
ph_nodes = _get_placeholders(cur_graph)
102100
if len(ph_nodes) != len(cur_inps):
103101
print(cur_graph)
@@ -111,13 +109,22 @@ def remove_unused_inputs(cur_graph, cur_inps):
111109
else:
112110
new_inps.append(cur_inps[idx])
113111

114-
if len(new_inps) < len(cur_inps):
112+
if len(new_inps) < len(cur_inps) and graph_fails(cur_graph, new_inps):
113+
print("Strategy: Remove unused inputs")
115114
print(f"SUCCESS: Went from {len(cur_inps)} inputs to {len(new_inps)} inputs")
116115
return (cur_graph, new_inps), True
117116
else:
118-
print("FAIL: Could not remove inputs")
119117
return (cur_graph, new_inps), False
120118

119+
def eliminate_dead_code(cur_graph, cur_inps):
120+
orig_size = len(cur_graph.nodes)
121+
if cur_graph.eliminate_dead_code() and graph_fails(cur_graph, cur_inps):
122+
print("Strategy: Eliminate dead code")
123+
print(f"SUCCESS: Went from {orig_size} nodes to {len(cur_graph.nodes)} nodes")
124+
return (cur_graph, cur_inps), True
125+
else:
126+
return (cur_graph, cur_inps), False
127+
121128
def consolidate_placeholders(cur_graph):
122129
new_graph = fx.Graph()
123130
env = {}
@@ -134,47 +141,49 @@ def consolidate_placeholders(cur_graph):
134141

135142
def delta_debugging(cur_graph: fx.Graph, cur_inps):
136143
print("Strategy: Delta Debugging")
137-
assert not graph_passes(cur_graph, cur_inps)
144+
assert graph_fails(cur_graph, cur_inps)
138145
starting_placeholders = len(_get_placeholders(cur_graph))
139146
num_nodes = len(cur_graph.nodes)
140147
gap = int(2**math.floor(math.log2(num_nodes)))
141148
while gap >= 1:
142-
print(f"Searching with gap of {gap}")
143149
for start_range in range(0, num_nodes, gap):
144150
is_removing = False
145151
new_graph = copy.deepcopy(cur_graph)
146152
new_inps = cur_inps[:]
147-
for idx in range(start_range, min(num_nodes, start_range + gap)):
153+
end_range = min(num_nodes, start_range + gap)
154+
for idx in range(start_range, end_range):
148155
new_node = list(new_graph.nodes)[idx]
149156
if new_node.op not in ['placeholder', 'output']:
150157
is_removing = True
151158
_convert_node_to_placeholder(new_node, new_inps)
152159
if not is_removing:
153160
continue
154161
new_graph = consolidate_placeholders(new_graph)
155-
if not graph_passes(new_graph, new_inps):
156-
print(f"SUCCESS: Went from {starting_placeholders} placeholders to {len(_get_placeholders(new_graph))}")
162+
if graph_fails(new_graph, new_inps):
163+
print(f"SUCCESS: Removed ({start_range}:{end_range}] - Went from {starting_placeholders} placeholders to {len(_get_placeholders(new_graph))}")
157164
return (new_graph, new_inps), True
158165
gap //= 2
159166

160167
print("FAIL: Could not remove prefix")
161168
return (cur_graph, inps), False
162169

163170

171+
print(f"###################")
172+
print(f"Current size: {len(failing_graph.nodes)}")
173+
print(f"###################")
164174
while True:
165175
any_succeeded = False
166-
for strategy in [remove_suffix, remove_unused_inputs, delta_debugging, remove_unused_inputs]:
167-
print(f"###################")
168-
print(f"Current size: {len(failing_graph.nodes)}")
169-
print(f"###################")
176+
for strategy in [remove_suffix, eliminate_dead_code, remove_unused_inputs, delta_debugging, eliminate_dead_code, remove_unused_inputs]:
170177
out = strategy(copy.deepcopy(failing_graph), inps[:])
171178
(cur_graph, cur_inps), succeeded = out
172179
if succeeded:
180+
print()
181+
print(f"###################")
182+
print(f"Current size: {len(cur_graph.nodes)}")
183+
print(f"###################")
173184
failing_graph = cur_graph
174-
failing_graph.eliminate_dead_code()
175185
inps = cur_inps
176186
any_succeeded = True
177-
print()
178187

179188
if not any_succeeded:
180189
break

test/test_minifier.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import torch
2+
from torch.nn import functional as F
3+
from functorch.compile import minimizer
4+
from functorch import make_fx
5+
from torch.testing._internal.common_utils import TestCase, run_tests
6+
from typing import Callable
7+
8+
9+
class TestMinifier(TestCase):
10+
def test_has_mul_minifier(self):
11+
def failing_f(x, y):
12+
y = y / 3
13+
x = x + 3
14+
x = x * y
15+
return x + y
16+
inps = [torch.randn(3), torch.randn(3)]
17+
failing_f = make_fx(failing_f)(*inps)
18+
19+
def pass_checker(fx_g, inps):
20+
return (torch.ops.aten.mul in set([i.target for i in fx_g.graph.nodes]))
21+
22+
min_f, inps = minimizer(failing_f, inps, pass_checker)
23+
assert len(min_f.graph.nodes) == 4
24+
assert len(inps) == 2
25+
26+
def test_has_add_mul(self):
27+
def failing_f(x):
28+
x = x * 3
29+
x = x + 5
30+
x = x.cos()
31+
zero = x - x
32+
result = zero / zero
33+
result = result + 3
34+
return (result * 2,)
35+
36+
inps = [torch.randn(3)]
37+
failing_f = make_fx(failing_f)(*inps)
38+
39+
def pass_checker(fx_g, inps):
40+
# Basically, make sure none of the inputs are nans
41+
for i in inps:
42+
if torch.isnan(i).any():
43+
return False
44+
return torch.isnan(fx_g(*inps)[0]).any()
45+
46+
min_f, inps = minimizer(failing_f, inps, pass_checker)
47+
import pdb; pdb.set_trace()
48+
assert len(min_f.graph.nodes) == 3
49+
assert len(inps) == 1
50+
51+
52+
if __name__ == "__main__":
53+
run_tests()

0 commit comments

Comments
 (0)