1
+ import torch .fx as fx
2
+ import copy
3
+ import torch
4
+ import math
5
+
6
+ class ConcreteProp (torch .fx .Interpreter ):
7
+ def run_node (self , n ):
8
+ result = super ().run_node (n )
9
+
10
+ found_tensor = False
11
+
12
+ def extract_tensor_meta (obj ):
13
+ if isinstance (obj , torch .Tensor ):
14
+ nonlocal found_tensor
15
+ found_tensor = True
16
+ return obj
17
+ else :
18
+ return obj
19
+
20
+ from torch .fx .node import map_aggregate
21
+ concrete_value = map_aggregate (result , extract_tensor_meta )
22
+ if found_tensor :
23
+ n .meta ['concrete_value' ] = concrete_value
24
+ return result
25
+
26
+ def propagate (self , * args ):
27
+ return super ().run (* args )
28
+
29
+ def _get_placeholders (graph ):
30
+ return list (filter (lambda x : x .op == 'placeholder' , graph .nodes ))
31
+
32
+ # inplace modifies node/inps
33
+ def _convert_node_to_placeholder (node , inps ):
34
+ node .op = 'placeholder'
35
+ node .args = ()
36
+ node .target = node .name
37
+ concrete_val = node .meta ['concrete_value' ]
38
+ if isinstance (concrete_val , torch .Tensor ):
39
+ inps .append (concrete_val )
40
+ else :
41
+ inps .append (torch .zeros (()))
42
+ for tuple_user in list (node .users ):
43
+ _convert_node_to_placeholder (tuple_user , inps )
44
+
45
+ def minimizer (fail_f : fx .GraphModule , inps , pass_checker ):
46
+ """
47
+ Minimizes a FX graph with given inputs, such that the resulting FX graph still fails the pass_checker.
48
+
49
+ Does 2 main strategies:
50
+ 1. Truncates suffix: Removes some suffix from the graph and sets a new output.
51
+ 2. Delta Debugging: Tries replacing half of the graph with inputs. If fails, tries replacing quarter of the graph, etc.
52
+
53
+ >>> failing_function = fx.symbolic_trace(f)
54
+ >>> minimize(failing_function, [torch.randn(5)], lambda fx_g, inps: fx_g(*inps))
55
+ """
56
+ failing_graph = fail_f .graph
57
+ cur_size = len (failing_graph .nodes )
58
+
59
+ def graph_passes (graph , inps ):
60
+ graph .lint ()
61
+ mod = fx .GraphModule (fail_f , graph )
62
+ return pass_checker (mod , inps )
63
+
64
+ ConcreteProp (fail_f ).propagate (* inps )
65
+ if graph_passes (failing_graph , inps ):
66
+ raise RuntimeError ("Input graph did not fail the tester" )
67
+ print (f"Started off with { cur_size } nodes" )
68
+
69
+ def remove_suffix (cur_graph , cur_inps ):
70
+ print ("Strategy: Remove suffix" )
71
+ assert not graph_passes (cur_graph , cur_inps )
72
+ gap = 2 ** math .floor (math .log2 (len (cur_graph .nodes )))
73
+ tested = set ()
74
+ while gap >= 1 :
75
+ print (f"search gap: { gap } : " , end = '' )
76
+ new_graph = fx .Graph ()
77
+ env = {}
78
+ for idx , node in enumerate (cur_graph .nodes ):
79
+ new_node = new_graph .node_copy (node , lambda x : env [x ])
80
+ if node .op not in ['placeholder' , 'output' ]:
81
+ if idx % gap == 0 and idx not in tested :
82
+ print (f"{ idx } " , end = ',' )
83
+ output_node = new_graph .output (([new_node ],))
84
+ if not graph_passes (new_graph , cur_inps ) and len (new_graph .nodes ) < len (cur_graph .nodes ):
85
+ print ()
86
+ print (f"SUCCESS: Found failing case with first { idx } nodes" )
87
+ return (new_graph , cur_inps ), True
88
+ else :
89
+ tested .add (idx )
90
+ new_graph .erase_node (output_node )
91
+ env [node ] = new_node
92
+ gap //= 2
93
+ print ()
94
+ print ("FAIL: Could not remove suffix" )
95
+ return (cur_graph , cur_inps ), False
96
+
97
+
98
+ def remove_unused_inputs (cur_graph , cur_inps ):
99
+ print ("Strategy: Remove unused inputs" )
100
+ assert not graph_passes (cur_graph , cur_inps )
101
+ ph_nodes = _get_placeholders (cur_graph )
102
+ if len (ph_nodes ) != len (cur_inps ):
103
+ print (cur_graph )
104
+ print (len (cur_inps ))
105
+ assert len (ph_nodes ) == len (cur_inps )
106
+
107
+ new_inps = []
108
+ for idx in range (len (ph_nodes )):
109
+ if len (ph_nodes [idx ].users ) == 0 :
110
+ cur_graph .erase_node (ph_nodes [idx ])
111
+ else :
112
+ new_inps .append (cur_inps [idx ])
113
+
114
+ if len (new_inps ) < len (cur_inps ):
115
+ print (f"SUCCESS: Went from { len (cur_inps )} inputs to { len (new_inps )} inputs" )
116
+ return (cur_graph , new_inps ), True
117
+ else :
118
+ print ("FAIL: Could not remove inputs" )
119
+ return (cur_graph , new_inps ), False
120
+
121
+ def consolidate_placeholders (cur_graph ):
122
+ new_graph = fx .Graph ()
123
+ env = {}
124
+ for node in cur_graph .nodes :
125
+ if node .op == 'placeholder' :
126
+ new_node = new_graph .node_copy (node , lambda x : env [x ])
127
+ env [node ] = new_node
128
+
129
+ for node in cur_graph .nodes :
130
+ if node .op != 'placeholder' :
131
+ new_node = new_graph .node_copy (node , lambda x : env [x ])
132
+ env [node ] = new_node
133
+ return new_graph
134
+
135
+ def delta_debugging (cur_graph : fx .Graph , cur_inps ):
136
+ print ("Strategy: Delta Debugging" )
137
+ assert not graph_passes (cur_graph , cur_inps )
138
+ starting_placeholders = len (_get_placeholders (cur_graph ))
139
+ num_nodes = len (cur_graph .nodes )
140
+ gap = int (2 ** math .floor (math .log2 (num_nodes )))
141
+ while gap >= 1 :
142
+ print (f"Searching with gap of { gap } " )
143
+ for start_range in range (0 , num_nodes , gap ):
144
+ is_removing = False
145
+ new_graph = copy .deepcopy (cur_graph )
146
+ new_inps = cur_inps [:]
147
+ for idx in range (start_range , min (num_nodes , start_range + gap )):
148
+ new_node = list (new_graph .nodes )[idx ]
149
+ if new_node .op not in ['placeholder' , 'output' ]:
150
+ is_removing = True
151
+ _convert_node_to_placeholder (new_node , new_inps )
152
+ if not is_removing :
153
+ continue
154
+ new_graph = consolidate_placeholders (new_graph )
155
+ if not graph_passes (new_graph , new_inps ):
156
+ print (f"SUCCESS: Went from { starting_placeholders } placeholders to { len (_get_placeholders (new_graph ))} " )
157
+ return (new_graph , new_inps ), True
158
+ gap //= 2
159
+
160
+ print ("FAIL: Could not remove prefix" )
161
+ return (cur_graph , inps ), False
162
+
163
+
164
+ while True :
165
+ any_succeeded = False
166
+ for strategy in [remove_suffix , remove_unused_inputs , delta_debugging , remove_unused_inputs ]:
167
+ print (f"###################" )
168
+ print (f"Current size: { len (failing_graph .nodes )} " )
169
+ print (f"###################" )
170
+ out = strategy (copy .deepcopy (failing_graph ), inps [:])
171
+ (cur_graph , cur_inps ), succeeded = out
172
+ if succeeded :
173
+ failing_graph = cur_graph
174
+ failing_graph .eliminate_dead_code ()
175
+ inps = cur_inps
176
+ any_succeeded = True
177
+ print ()
178
+
179
+ if not any_succeeded :
180
+ break
181
+ failing_fx = fx .GraphModule (fail_f , failing_graph )
182
+ print (failing_fx .code )
183
+ print ([i .shape for i in inps ])
184
+ return failing_fx , inps
0 commit comments