@@ -42,62 +42,60 @@ def _convert_node_to_placeholder(node, inps):
42
42
for tuple_user in list (node .users ):
43
43
_convert_node_to_placeholder (tuple_user , inps )
44
44
45
- def minimizer (fail_f : fx .GraphModule , inps , pass_checker ):
45
+ def minimizer (fail_f : fx .GraphModule , inps , module_fails ):
46
46
"""
47
- Minimizes a FX graph with given inputs, such that the resulting FX graph still fails the pass_checker .
47
+ Minimizes a FX graph with given inputs, such that the resulting FX graph still returns True for module_fails .
48
48
49
49
Does 2 main strategies:
50
50
1. Truncates suffix: Removes some suffix from the graph and sets a new output.
51
51
2. Delta Debugging: Tries replacing half of the graph with inputs. If fails, tries replacing quarter of the graph, etc.
52
52
53
53
>>> failing_function = fx.symbolic_trace(f)
54
54
>>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))
55
+
56
+ note: module_fails returns True if it fails.
55
57
"""
56
58
failing_graph = fail_f .graph
57
59
cur_size = len (failing_graph .nodes )
58
60
59
- def graph_passes (graph , inps ):
60
- graph .lint ()
61
+ def graph_fails (graph , inps ):
61
62
mod = fx .GraphModule (fail_f , graph )
62
- return pass_checker (mod , inps )
63
+ mod .graph .lint ()
64
+ return module_fails (mod , inps )
63
65
64
66
ConcreteProp (fail_f ).propagate (* inps )
65
- if graph_passes (failing_graph , inps ):
67
+ if not graph_fails (failing_graph , inps ):
66
68
raise RuntimeError ("Input graph did not fail the tester" )
67
69
print (f"Started off with { cur_size } nodes" )
68
70
69
71
def remove_suffix (cur_graph , cur_inps ):
70
72
print ("Strategy: Remove suffix" )
71
- assert not graph_passes (cur_graph , cur_inps )
73
+ assert graph_fails (cur_graph , cur_inps )
72
74
gap = 2 ** math .floor (math .log2 (len (cur_graph .nodes )))
73
75
tested = set ()
74
76
while gap >= 1 :
75
- print (f"search gap: { gap } : " , end = '' )
76
77
new_graph = fx .Graph ()
77
78
env = {}
78
79
for idx , node in enumerate (cur_graph .nodes ):
79
80
new_node = new_graph .node_copy (node , lambda x : env [x ])
80
81
if node .op not in ['placeholder' , 'output' ]:
81
82
if idx % gap == 0 and idx not in tested :
82
- print (f"{ idx } " , end = ',' )
83
- output_node = new_graph .output (([new_node ],))
84
- if not graph_passes (new_graph , cur_inps ) and len (new_graph .nodes ) < len (cur_graph .nodes ):
83
+ output_node = new_graph .output ((new_node ,))
84
+ if graph_fails (new_graph , cur_inps ) and len (new_graph .nodes ) < len (cur_graph .nodes ):
85
85
print ()
86
- print (f"SUCCESS: Found failing case with first { idx } nodes" )
86
+ print (f"SUCCESS: Removed [ { idx } : { len ( cur_graph . nodes ) } ) " )
87
87
return (new_graph , cur_inps ), True
88
88
else :
89
89
tested .add (idx )
90
90
new_graph .erase_node (output_node )
91
91
env [node ] = new_node
92
92
gap //= 2
93
- print ()
94
93
print ("FAIL: Could not remove suffix" )
95
94
return (cur_graph , cur_inps ), False
96
95
97
96
98
97
def remove_unused_inputs (cur_graph , cur_inps ):
99
- print ("Strategy: Remove unused inputs" )
100
- assert not graph_passes (cur_graph , cur_inps )
98
+ assert graph_fails (cur_graph , cur_inps )
101
99
ph_nodes = _get_placeholders (cur_graph )
102
100
if len (ph_nodes ) != len (cur_inps ):
103
101
print (cur_graph )
@@ -111,13 +109,22 @@ def remove_unused_inputs(cur_graph, cur_inps):
111
109
else :
112
110
new_inps .append (cur_inps [idx ])
113
111
114
- if len (new_inps ) < len (cur_inps ):
112
+ if len (new_inps ) < len (cur_inps ) and graph_fails (cur_graph , new_inps ):
113
+ print ("Strategy: Remove unused inputs" )
115
114
print (f"SUCCESS: Went from { len (cur_inps )} inputs to { len (new_inps )} inputs" )
116
115
return (cur_graph , new_inps ), True
117
116
else :
118
- print ("FAIL: Could not remove inputs" )
119
117
return (cur_graph , new_inps ), False
120
118
119
+ def eliminate_dead_code (cur_graph , cur_inps ):
120
+ orig_size = len (cur_graph .nodes )
121
+ if cur_graph .eliminate_dead_code () and graph_fails (cur_graph , cur_inps ):
122
+ print ("Strategy: Eliminate dead code" )
123
+ print (f"SUCCESS: Went from { orig_size } nodes to { len (cur_graph .nodes )} nodes" )
124
+ return (cur_graph , cur_inps ), True
125
+ else :
126
+ return (cur_graph , cur_inps ), False
127
+
121
128
def consolidate_placeholders (cur_graph ):
122
129
new_graph = fx .Graph ()
123
130
env = {}
@@ -134,47 +141,49 @@ def consolidate_placeholders(cur_graph):
134
141
135
142
def delta_debugging (cur_graph : fx .Graph , cur_inps ):
136
143
print ("Strategy: Delta Debugging" )
137
- assert not graph_passes (cur_graph , cur_inps )
144
+ assert graph_fails (cur_graph , cur_inps )
138
145
starting_placeholders = len (_get_placeholders (cur_graph ))
139
146
num_nodes = len (cur_graph .nodes )
140
147
gap = int (2 ** math .floor (math .log2 (num_nodes )))
141
148
while gap >= 1 :
142
- print (f"Searching with gap of { gap } " )
143
149
for start_range in range (0 , num_nodes , gap ):
144
150
is_removing = False
145
151
new_graph = copy .deepcopy (cur_graph )
146
152
new_inps = cur_inps [:]
147
- for idx in range (start_range , min (num_nodes , start_range + gap )):
153
+ end_range = min (num_nodes , start_range + gap )
154
+ for idx in range (start_range , end_range ):
148
155
new_node = list (new_graph .nodes )[idx ]
149
156
if new_node .op not in ['placeholder' , 'output' ]:
150
157
is_removing = True
151
158
_convert_node_to_placeholder (new_node , new_inps )
152
159
if not is_removing :
153
160
continue
154
161
new_graph = consolidate_placeholders (new_graph )
155
- if not graph_passes (new_graph , new_inps ):
156
- print (f"SUCCESS: Went from { starting_placeholders } placeholders to { len (_get_placeholders (new_graph ))} " )
162
+ if graph_fails (new_graph , new_inps ):
163
+ print (f"SUCCESS: Removed ( { start_range } : { end_range } ] - Went from { starting_placeholders } placeholders to { len (_get_placeholders (new_graph ))} " )
157
164
return (new_graph , new_inps ), True
158
165
gap //= 2
159
166
160
167
print ("FAIL: Could not remove prefix" )
161
168
return (cur_graph , inps ), False
162
169
163
170
171
+ print (f"###################" )
172
+ print (f"Current size: { len (failing_graph .nodes )} " )
173
+ print (f"###################" )
164
174
while True :
165
175
any_succeeded = False
166
- for strategy in [remove_suffix , remove_unused_inputs , delta_debugging , remove_unused_inputs ]:
167
- print (f"###################" )
168
- print (f"Current size: { len (failing_graph .nodes )} " )
169
- print (f"###################" )
176
+ for strategy in [remove_suffix , eliminate_dead_code , remove_unused_inputs , delta_debugging , eliminate_dead_code , remove_unused_inputs ]:
170
177
out = strategy (copy .deepcopy (failing_graph ), inps [:])
171
178
(cur_graph , cur_inps ), succeeded = out
172
179
if succeeded :
180
+ print ()
181
+ print (f"###################" )
182
+ print (f"Current size: { len (cur_graph .nodes )} " )
183
+ print (f"###################" )
173
184
failing_graph = cur_graph
174
- failing_graph .eliminate_dead_code ()
175
185
inps = cur_inps
176
186
any_succeeded = True
177
- print ()
178
187
179
188
if not any_succeeded :
180
189
break
0 commit comments