@@ -33,8 +33,10 @@
 class ControlFlowGraph(object):
-    def __init__(self, Program):
+    def __init__(self, Program, ops, forward_num):
         self._program = Program
-        self._succesors = defaultdict(set)
-        self._presucessors = defaultdict(set)
+        self._ops = ops
+        self._forward_num = forward_num
+        self._successors = defaultdict(set)
+        self._presuccessors = defaultdict(set)
         self._uses = defaultdict(set)
         self._defs = defaultdict(set)
         self._live_in = defaultdict(set)
@@ -45,25 +47,16 @@ def _add_connections(self, connections):
             self._add(node1, node2)
 
     def _add(self, node1, node2):
-        self._succesors[node1].add(node2)
-        self._presucessors[node2].add(node1)
+        self._successors[node1].add(node2)
+        self._presuccessors[node2].add(node1)
 
     def _build_graph(self):
-        program_desc = self._program.get_desc()
-        block_size = program_desc.num_blocks()
-
-        # TODO(qijun) handle Program with if/while operators
-        self.global_block_desc = program_desc.block(0)
-        self.op_size = self.global_block_desc.op_size()
-
+        self.op_size = len(self._ops)
         op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
         self._add_connections(op_node_connections)
-
-        self.ops = [self.global_block_desc.op(i) for i in range(self.op_size)]
-
         for i in range(self.op_size):
-            self._uses[i].update(self.ops[i].input_arg_names())
-            self._defs[i].update(self.ops[i].output_arg_names())
+            self._uses[i].update(self._ops[i].input_arg_names())
+            self._defs[i].update(self._ops[i].output_arg_names())
 
     def _update_graph(self, old_name, new_name, begin_idx=0):
         for i in range(begin_idx, self.op_size):
@@ -103,7 +96,7 @@ def _dataflow_analyze(self):
                 live_out[i] = set(self._live_out[i])
                 self._live_in[i] = self._uses[i] | (
                     self._live_out[i] - self._defs[i])
-                for s in self._succesors[i]:
+                for s in self._successors[i]:
                     self._live_out[i] |= self._live_in[s]
 
             if self._reach_fixed_point(live_in, live_out):
@@ -113,60 +106,147 @@ def _get_diff(self, a, b):
         u = a & b
         return a - u, b - u
 
+    def _has_var(self, block_desc, var_name, is_forward):
+        if is_forward:
+            return block_desc.has_var(str(var_name))
+        else:
+            return block_desc.has_var_recursive(str(var_name))
+
+    def _find_var(self, block_desc, var_name, is_forward):
+        if is_forward:
+            return block_desc.find_var(str(var_name))
+        else:
+            return block_desc.find_var_recursive(str(var_name))
+
     def memory_optimize(self):
+        def check_var_validity(block_desc, x, is_forward):
+            if str(x) == "@EMPTY@":
+                return False
+            if not self._has_var(block_desc, x, is_forward):
+                return False
+            if self._find_var(block_desc, x, is_forward).persistable():
+                return False
+            if self._find_var(
+                    block_desc, x,
+                    is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
+                return False
+            return True
+
         self._build_graph()
         self._dataflow_analyze()
         self.pool = []
         for i in range(self.op_size):
+            op = self._ops[i]
+            if op.type() == "while" or op.type() == "while_grad":
+                continue
+            block_desc = op.block()
+            is_forward = i < self._forward_num
             if self.pool:
-                out_pair = [(x, self.global_block_desc.var(str(x)).shape())
-                            for x in self._defs[i]]
+                defs_can_optimize = filter(
+                    lambda x: check_var_validity(block_desc, x, is_forward),
+                    self._defs[i])
+                out_pair = [
+                    (x, self._find_var(block_desc, x, is_forward).shape())
+                    for x in defs_can_optimize
+                ]
                 for x, x_shape in out_pair:
-                    if not self.global_block_desc.var(str(x)).persistable():
-                        for index, cache_pair in enumerate(self.pool):
-                            cache_var = cache_pair[0]
-                            cache_shape = cache_pair[1]
-                            if x_shape == cache_shape:
-                                x_dtype = self.global_block_desc.var(str(
-                                    x)).dtype()
-                                cache_dtype = self.global_block_desc.var(
-                                    str(cache_var)).dtype()
+                    for index, cache_pair in enumerate(self.pool):
+                        cache_var = cache_pair[0]
+                        cache_shape = cache_pair[1]
+                        if x_shape == cache_shape:
+                            if self._has_var(block_desc, cache_var, is_forward):
+                                x_dtype = self._find_var(block_desc, x,
+                                                         is_forward).dtype()
+                                cache_dtype = self._find_var(
+                                    block_desc, cache_var, is_forward).dtype()
                                 # TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
                                 # and dtype_to_size[cache_dtype]
                                 if x_dtype == cache_dtype:
-                                    print(
-                                        ("Hit Cache !!!! cache pool index "
-                                         "is %d, var name is %s, "
-                                         "cached var name is %s, "
-                                         "var shape is %s ") %
-                                        (index, x, cache_var, str(cache_shape)))
+                                    print(("Hit Cache !!!! cache pool index "
+                                           "is %d, var name is %s, "
+                                           "cached var name is %s, "
+                                           "var shape is %s ") %
+                                          (index, x, cache_var,
+                                           str(cache_shape)))
                                     self.pool.pop(index)
+                                    if x == cache_var:
+                                        break
                                     _rename_arg_(
-                                        self.ops, x, cache_var, begin_idx=i)
-                                    self._program.current_block().var(str(
-                                        x)).desc = self.global_block_desc.var(
-                                        str(cache_var))
+                                        self._ops, x, cache_var, begin_idx=i)
+                                    self._program.block(block_desc.id).var(
+                                        str(x)).desc = self._find_var(
+                                        block_desc, cache_var, is_forward)
                                     self._update_graph(
                                         x, cache_var, begin_idx=i)
                                     break
 
             in_diff, out_diff = self._get_diff(self._live_in[i],
                                                self._live_out[i])
             can_optimize = filter(
-                lambda x: not self.global_block_desc.var(str(x)).persistable(),
+                lambda x: check_var_validity(block_desc, x, is_forward),
                 in_diff)
             if can_optimize:
                 for var_name in can_optimize:
-                    self.pool.append(
-                        (var_name,
-                         self.global_block_desc.var(str(var_name)).shape()))
-
-    def get_program(self):
-        return self._program
+                    self.pool.append((var_name, self._find_var(
+                        block_desc, var_name, is_forward).shape()))
+
+
+def get_cfgs(input_program):
+    ops_list = []
+    pdesc = input_program.get_desc()
+    block_desc = pdesc.block(0)
+    op_size = block_desc.op_size()
+    # Get global block ops
+    ops_list.append(([block_desc.op(i) for i in range(op_size)], op_size))
+
+    while_sub_block_ids = []
+    while_grad_sub_block_ids = []
+    while_pair = []
+
+    for i in range(op_size):
+        op = block_desc.op(i)
+        if op.type() == "while":
+            while_sub_block_ids.append(op.attr("sub_block").id)
+        elif op.type() == "while_grad":
+            while_grad_sub_block_ids.append(op.attr("sub_block").id)
+
+    # Find while/while_grad block pair
+    for grad_id in while_grad_sub_block_ids:
+        parent_id = pdesc.block(grad_id).parent
+        if parent_id in while_sub_block_ids:
+            while_pair.append((parent_id, grad_id))
+            while_sub_block_ids.remove(parent_id)
+
+    # Get while/while_grad block ops
+    for parent_id, grad_id in while_pair:
+        while_block_ops = []
+        while_block = pdesc.block(parent_id)
+        while_block_op_size = while_block.op_size()
+        for i in range(while_block_op_size):
+            while_block_ops.append(while_block.op(i))
+
+        while_grad_block = pdesc.block(grad_id)
+        while_grad_block_op_size = while_grad_block.op_size()
+        for i in range(while_grad_block_op_size):
+            while_block_ops.append(while_grad_block.op(i))
+
+        ops_list.append((while_block_ops, while_block_op_size))
+
+    # Process rest while block ops
+    for parent_id in while_sub_block_ids:
+        while_block_ops = []
+        while_block = pdesc.block(parent_id)
+        while_block_op_size = while_block.op_size()
+        for i in range(while_block_op_size):
+            while_block_ops.append(while_block.op(i))
+
+        ops_list.append((while_block_ops, while_block_op_size))
+
+    cfgs = [ControlFlowGraph(input_program, i, j) for i, j in ops_list]
+    return cfgs
 
 
 def memory_optimize(input_program):
-    graph = ControlFlowGraph(input_program)
-    graph.memory_optimize()
-    result_program = graph.get_program()
-    return result_program
+    cfgs = get_cfgs(input_program)
+    for cfg in cfgs:
+        cfg.memory_optimize()
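
Note for reviewers: `_dataflow_analyze` is a standard backward liveness analysis (live_in[i] = uses[i] ∪ (live_out[i] − defs[i]), live_out[i] = ∪ live_in[s] over successors s), and `memory_optimize` then lets a new definition take over the storage of a variable whose liveness has just ended and whose shape matches. The sketch below reproduces that logic on a toy op list; the op dicts, variable names, and shapes are made up for illustration and are not Paddle API.

```python
from collections import defaultdict

# Toy ops: each op reads `uses` and writes `defs`; shapes are hypothetical.
shapes = {"x": (8, 4), "a": (4, 4), "b": (4, 4), "c": (4, 4), "y": (4, 4)}
ops = [
    {"uses": {"x"}, "defs": {"a"}},
    {"uses": {"a"}, "defs": {"b"}},
    {"uses": {"b"}, "defs": {"c"}},
    {"uses": {"c"}, "defs": {"y"}},
]
n = len(ops)
successors = {i: ({i + 1} if i + 1 < n else set()) for i in range(n)}

# Backward liveness, iterated to a fixed point (same scheme as _dataflow_analyze).
live_in, live_out = defaultdict(set), defaultdict(set)
while True:
    old_in = {i: set(live_in[i]) for i in range(n)}
    old_out = {i: set(live_out[i]) for i in range(n)}
    for i in range(n):
        live_in[i] = ops[i]["uses"] | (live_out[i] - ops[i]["defs"])
        for s in successors[i]:
            live_out[i] |= live_in[s]
    if old_in == dict(live_in) and old_out == dict(live_out):
        break

# Pool-based reuse (same idea as memory_optimize): a definition may take over
# the storage of a previously freed variable with a matching shape.
pool = []  # list of (var_name, shape)
for i in range(n):
    for x in ops[i]["defs"]:
        for index, (cache_var, cache_shape) in enumerate(pool):
            if shapes[x] == cache_shape:
                print("op %d: %s reuses the memory of %s" % (i, x, cache_var))
                pool.pop(index)
                break
    # Variables live before op i but not after it are freed here.
    for freed in live_in[i] - live_out[i]:
        pool.append((freed, shapes[freed]))
```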
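As a usage sketch (not part of this change): after the patch, `get_cfgs` splits a program that contains `while`/`while_grad` ops into one `ControlFlowGraph` per block group, and each graph is optimized independently. The import path and program construction below are assumptions about the surrounding fluid codebase, shown only to illustrate the intended call site.

```python
# Hypothetical call site; module path and program setup are assumptions.
import paddle.v2.fluid as fluid
from paddle.v2.fluid.memory_optimization_transpiler import memory_optimize

# ... build a fluid program (possibly containing while/while_grad ops) ...
memory_optimize(fluid.default_main_program())
# After the pass, non-persistable LOD_TENSOR outputs may have been renamed to
# reuse the storage of variables that are no longer live.
```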