29
29
core .VarDesc .VarType .BOOL : 1
30
30
}
31
31
32
+ sub_block_ops = ["while" , "while_grad" , "parallel_do" , "parallel_do_grad" ]
33
+
32
34
33
35
class ControlFlowGraph (object ):
34
36
def __init__ (self , Program , ops , forward_num , skip_opt ):
@@ -141,7 +143,7 @@ def check_var_validity(block_desc, x, is_forward):
141
143
self .pool = []
142
144
for i in range (self .op_size ):
143
145
op = self ._ops [i ]
144
- if op .type () == "while" or op . type () == "while_grad" :
146
+ if op .type () in sub_block_ops :
145
147
continue
146
148
block_desc = op .block ()
147
149
is_forward = i < self ._forward_num
@@ -198,67 +200,75 @@ def check_var_validity(block_desc, x, is_forward):
198
200
block_desc , var_name , is_forward ).shape ()))
199
201
200
202
201
- def get_cfgs ( input_program ):
203
+ def _process_sub_block_pair ( pdesc , sub_block_pair ):
202
204
ops_list = []
203
- pdesc = input_program .get_desc ()
204
205
block_desc = pdesc .block (0 )
205
206
op_size = block_desc .op_size ()
206
- # Get global block ops
207
- ops_list .append (
208
- ([block_desc .op (i ) for i in range (op_size )], op_size , set ()))
209
-
210
- while_sub_block_ids = []
211
- while_grad_sub_block_ids = []
212
- while_block_id_pair = []
213
- while_op_dict = {}
207
+ for fwd_op , bwd_op in sub_block_pair :
208
+ sub_block_ids = []
209
+ grad_sub_block_ids = []
210
+ sub_block_id_pair = []
211
+ sub_op_dict = {}
212
+ for i in range (op_size ):
213
+ op = block_desc .op (i )
214
+ if op .type () == fwd_op :
215
+ sub_block_ids .append (op .attr ("sub_block" ).id )
216
+ sub_op_dict [op .attr ("sub_block" ).id ] = op
217
+ elif op .type () == bwd_op :
218
+ grad_sub_block_ids .append (op .attr ("sub_block" ).id )
219
+ sub_op_dict [op .attr ("sub_block" ).id ] = op
214
220
215
- for i in range (op_size ):
216
- op = block_desc .op (i )
217
- if op .type () == "while" :
218
- while_sub_block_ids .append (op .attr ("sub_block" ).id )
219
- while_op_dict [op .attr ("sub_block" ).id ] = op
220
- elif op .type () == "while_grad" :
221
- while_grad_sub_block_ids .append (op .attr ("sub_block" ).id )
222
- while_op_dict [op .attr ("sub_block" ).id ] = op
221
+ # Find fwd_op/bwd_op block pair
222
+ for grad_id in grad_sub_block_ids :
223
+ fwd_id = pdesc .block (grad_id ).get_forward_block_idx ()
224
+ if fwd_id in sub_block_ids :
225
+ sub_block_id_pair .append ((fwd_id , grad_id ))
226
+ sub_block_ids .remove (fwd_id )
223
227
224
- # Find while/while_grad block pair
225
- for grad_id in while_grad_sub_block_ids :
226
- forward_id = pdesc .block (grad_id ).get_forward_block_idx ()
227
- if forward_id in while_sub_block_ids :
228
- while_block_id_pair .append ((forward_id , grad_id ))
229
- while_sub_block_ids .remove (forward_id )
228
+ # Get fwd_op/bwd_op block ops
229
+ for fwd_id , grad_id in sub_block_id_pair :
230
+ sub_block_ops = []
231
+ sub_block = pdesc .block (fwd_id )
232
+ block_op_size = sub_block .op_size ()
233
+ for i in range (block_op_size ):
234
+ sub_block_ops .append (sub_block .op (i ))
230
235
231
- # Get while/while_grad block ops
232
- for forward_id , grad_id in while_block_id_pair :
233
- while_block_ops = []
234
- while_block = pdesc .block (forward_id )
235
- while_block_op_size = while_block .op_size ()
236
- for i in range (while_block_op_size ):
237
- while_block_ops .append (while_block .op (i ))
236
+ grad_sub_block = pdesc .block (grad_id )
237
+ grad_sub_block_op_size = grad_sub_block .op_size ()
238
+ for i in range (grad_sub_block_op_size ):
239
+ sub_block_ops .append (grad_sub_block .op (i ))
238
240
239
- while_grad_block = pdesc . block ( grad_id )
240
- while_grad_block_op_size = while_grad_block . op_size ( )
241
- for i in range ( while_grad_block_op_size ):
242
- while_block_ops .append (while_grad_block . op ( i ))
241
+ sub_op_output = set ( )
242
+ sub_op_output . update ( sub_op_dict [ fwd_id ]. output_arg_names () )
243
+ sub_op_output . update ( sub_op_dict [ grad_id ]. output_arg_names ())
244
+ ops_list .append (( sub_block_ops , block_op_size , sub_op_output ))
243
245
244
- while_op_output = set ()
245
- while_op_output .update (while_op_dict [forward_id ].output_arg_names ())
246
- while_op_output .update (while_op_dict [grad_id ].output_arg_names ())
246
+ # Process rest fwd_op block ops
247
+ for fwd_id in sub_block_ids :
248
+ sub_block_ops = []
249
+ sub_block = pdesc .block (fwd_id )
250
+ sub_block_op_size = sub_block .op_size ()
251
+ for i in range (sub_block_op_size ):
252
+ sub_block_ops .append (sub_block .op (i ))
253
+ sub_op_output = set ()
254
+ sub_op_output .update (sub_op_dict [fwd_id ].output_arg_names ())
255
+ ops_list .append ((sub_block_ops , sub_block_op_size , sub_op_output ))
256
+ return ops_list
247
257
248
- ops_list .append ((while_block_ops , while_block_op_size , while_op_output ))
249
258
250
- # Process rest while block ops
251
- for forward_id in while_sub_block_ids :
252
- while_block_ops = []
253
- while_block = pdesc .block (forward_id )
254
- while_block_op_size = while_block .op_size ()
255
- for i in range (while_block_op_size ):
256
- while_block_ops .append (while_block .op (i ))
259
+ def _get_cfgs (input_program ):
260
+ ops_list = []
261
+ pdesc = input_program .get_desc ()
262
+ block_desc = pdesc .block (0 )
263
+ op_size = block_desc .op_size ()
264
+ # Get global block ops
265
+ ops_list .append (
266
+ ([block_desc .op (i ) for i in range (op_size )], op_size , set ()))
257
267
258
- while_op_output = set ()
259
- while_op_output . update ( while_op_dict [ forward_id ]. output_arg_names ())
268
+ sub_block_pair = [( "while" , "while_grad" ), ( "parallel_do" ,
269
+ "parallel_do_grad" )]
260
270
261
- ops_list .append (( while_block_ops , while_block_op_size , while_op_output ))
271
+ ops_list .extend ( _process_sub_block_pair ( pdesc , sub_block_pair ))
262
272
263
273
cfgs = [
264
274
ControlFlowGraph (input_program , ops , forward_num , skip_opt )
@@ -268,6 +278,6 @@ def get_cfgs(input_program):
268
278
269
279
270
280
def memory_optimize (input_program ):
271
- cfgs = get_cfgs (input_program )
281
+ cfgs = _get_cfgs (input_program )
272
282
for cfg in cfgs :
273
283
cfg .memory_optimize ()
0 commit comments