    core.VarDesc.VarType.BOOL: 1
}

-sub_block_ops = [
+SUB_BLOCK_OPS = [
    "while", "while_grad", "parallel_do", "parallel_do_grad",
    "conditional_block", "conditional_block_grad"
]

+SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"),
+                  ("conditional_block", "conditional_block_grad")]
+
PRINT_LOG = False


class ControlFlowGraph(object):
-    def __init__(self, Program, ops, forward_num, skip_opt):
-        self._program = Program
+    def __init__(self, program, ops, forward_num, skip_opt):
+        self._program = program
        self._ops = ops
        self._forward_num = forward_num
        self._successors = defaultdict(set)
@@ -51,14 +54,19 @@ def __init__(self, Program, ops, forward_num, skip_opt):
        self._skip_opt = skip_opt

    def _add_connections(self, connections):
+        """Populates _successors and _presuccessors for two neighbor nodes."""
        for node1, node2 in connections:
            self._add(node1, node2)

    def _add(self, node1, node2):
        self._successors[node1].add(node2)
        self._presuccessors[node2].add(node1)

+    # TODO(panyx0718): We need to have a unified way of building intermediate
+    # representation.
    def _build_graph(self):
+        """Build a graph based on op sequence.
+        """
        self.op_size = len(self._ops)
        op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
        self._add_connections(op_node_connections)
@@ -82,22 +90,23 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
                self._live_out[i].add(new_name)

    def _reach_fixed_point(self, live_in, live_out):
+        """Check if the liveness sets have stabilized."""
        if len(live_in) != len(self._live_in):
            return False
        if len(live_out) != len(self._live_out):
            return False
        for i in range(self.op_size):
-            if live_in[i] != self._live_in[i]:
-                return False
-        for i in range(self.op_size):
-            if live_out[i] != self._live_out[i]:
+            if (live_in[i] != self._live_in[i] or
+                    live_out[i] != self._live_out[i]):
                return False
        return True

    def _dataflow_analyze(self):
        self._build_graph()
        live_in = defaultdict(set)
        live_out = defaultdict(set)
+        # Repeatedly apply liveness updates until the algorithm stabilizes
+        # on a complete set of live input vars and live output vars.
        while True:
            for i in range(self.op_size, 0, -1):
                live_in[i] = set(self._live_in[i])
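The `_reach_fixed_point` / `_dataflow_analyze` pair above is a textbook backward liveness analysis. Below is a minimal, self-contained sketch of that fixed-point iteration; the `uses`, `defs`, and `successors` maps are hypothetical stand-ins for the real op and graph structures used in the transpiler.

```python
from collections import defaultdict


def liveness(op_count, uses, defs, successors):
    """Backward liveness: iterate until live_in/live_out stop changing."""
    live_in = defaultdict(set)
    live_out = defaultdict(set)
    while True:
        changed = False
        for i in reversed(range(op_count)):  # walk ops back to front
            new_out = set()
            for s in successors[i]:
                new_out |= live_in[s]
            new_in = uses[i] | (new_out - defs[i])
            if new_in != live_in[i] or new_out != live_out[i]:
                live_in[i], live_out[i] = new_in, new_out
                changed = True
        if not changed:  # fixed point reached
            return live_in, live_out


# Tiny example: op0 defines "a"; op1 reads "a", defines "b"; op2 reads "b".
uses = {0: set(), 1: {"a"}, 2: {"b"}}
defs = {0: {"a"}, 1: {"b"}, 2: set()}
succ = {0: {1}, 1: {2}, 2: set()}
print(liveness(3, uses, defs, succ)[0])  # live_in: op1 needs "a", op2 needs "b"
```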
@@ -141,6 +150,8 @@ def _check_var_validity(self, block_desc, x, is_forward):
            return False
        return True

+    # TODO(panyx0718): This needs to be less hacky. It seems memory optimization
+    # doesn't consider vars copied between cpu and gpu.
    def _update_skip_opt_set(self):
        for i in range(self.op_size):
            op = self._ops[i]
@@ -154,7 +165,7 @@ def release_memory(self):
        bwd_id = 0
        for i in range(self.op_size):
            op = self._ops[i]
-            if op.type() in sub_block_ops:
+            if op.type() in SUB_BLOCK_OPS:
                continue
            block_desc = op.block()
            is_forward = i < self._forward_num
@@ -177,24 +188,25 @@ def memory_optimize(self, level=0):
        def compare_shape(x_shape, cache_shape, opt_level):
            if opt_level == 0:
                return x_shape == cache_shape
-            if opt_level == 1:
+            elif opt_level == 1:
                if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
                    return False
                x_size = abs(reduce(lambda x, y: x * y, x_shape))
                cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
                if x_size <= cache_size:
                    return True
+            else:
+                raise ValueError("only support opt_level 0 or 1.")
            return False

        self._dataflow_analyze()
        self._update_skip_opt_set()
        self.pool = []
        for i in range(self.op_size):
            op = self._ops[i]
-            if op.type() in sub_block_ops:
+            if op.type() in SUB_BLOCK_OPS:
                continue
            block_desc = op.block()
-            self.current_block_desc = block_desc
            is_forward = i < self._forward_num
            if self.pool:
                defs_can_optimize = filter(
@@ -211,37 +223,40 @@ def compare_shape(x_shape, cache_shape, opt_level):
                for index, cache_pair in enumerate(self.pool):
                    cache_var = cache_pair[0]
                    cache_shape = cache_pair[1]
-                    if compare_shape(x_shape, cache_shape, level):
-                        if self._has_var(block_desc, cache_var, is_forward):
-                            x_dtype = self._find_var(block_desc, x,
-                                                     is_forward).dtype()
-                            cache_dtype = self._find_var(
-                                block_desc, cache_var, is_forward).dtype()
-                            # TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
-                            # and dtype_to_size[cache_dtype]
-                            if x_dtype == cache_dtype:
-                                if PRINT_LOG:
-                                    print(
-                                        ("Hit Cache !!!! cache pool index "
-                                         "is %d, var name is %s, "
-                                         "cached var name is %s, "
-                                         "var shape is %s ") %
-                                        (index, x, cache_var,
-                                         str(cache_shape)))
-                                self.pool.pop(index)
-                                if x == cache_var:
-                                    break
-                                _rename_arg_(
-                                    self._ops, x, cache_var, begin_idx=i)
-                                self._program.block(block_desc.id).var(
-                                    str(x)).desc = self._find_var(
-                                        block_desc, cache_var, is_forward)
-                                self._update_graph(
-                                    x, cache_var, begin_idx=i)
-                                break
-
-            in_diff, out_diff = self._get_diff(self._live_in[i],
-                                               self._live_out[i])
+                    if not compare_shape(x_shape, cache_shape, level):
+                        continue
+
+                    if not self._has_var(block_desc, cache_var, is_forward):
+                        continue
+
+                    x_dtype = self._find_var(block_desc, x,
+                                             is_forward).dtype()
+                    cache_dtype = self._find_var(block_desc, cache_var,
+                                                 is_forward).dtype()
+                    # TODO(qijun): actually, we should compare
+                    # dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
+                    if x_dtype != cache_dtype:
+                        continue
+
+                    if PRINT_LOG:
+                        print(("Hit Cache !!!! cache pool index "
+                               "is %d, var name is %s, "
+                               "cached var name is %s, "
+                               "var shape is %s ") % (index, x, cache_var,
+                                                      str(cache_shape)))
+                    self.pool.pop(index)
+                    if x == cache_var:
+                        break
+                    # Rename the var to the cache var, which already has
+                    # memory allocated, in order to reuse that memory.
+                    _rename_arg_(self._ops, x, cache_var, begin_idx=i)
+                    self._program.block(block_desc.id).var(str(
+                        x)).desc = self._find_var(block_desc, cache_var,
+                                                  is_forward)
+                    self._update_graph(x, cache_var, begin_idx=i)
+                    break
+
+            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
            can_optimize = filter(
                lambda x: self._check_var_validity(block_desc, x, is_forward),
                in_diff)
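For readers skimming the hunk above: `compare_shape` (defined earlier in `memory_optimize`) is the only place the `level` argument matters. Here is a standalone restatement of that check with a few worked cases; the logic mirrors the diff, and the standalone function is only for illustration.

```python
from functools import reduce  # Python 3 moved reduce out of the builtins


def compare_shape(x_shape, cache_shape, opt_level):
    if opt_level == 0:
        # Strict mode: only an identical shape may reuse the cached var.
        return x_shape == cache_shape
    elif opt_level == 1:
        # Both must agree on whether dim 0 is the dynamic batch dim (-1).
        if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
            return False
        x_size = abs(reduce(lambda x, y: x * y, x_shape))
        cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
        return x_size <= cache_size  # reuse only if the var fits
    else:
        raise ValueError("only support opt_level 0 or 1.")


print(compare_shape([-1, 32], [-1, 64], 0))  # False: shapes differ
print(compare_shape([-1, 32], [-1, 64], 1))  # True: 32 elements fit in 64
print(compare_shape([16, 32], [-1, 64], 1))  # False: batch-dim flags differ
```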
@@ -252,6 +267,19 @@ def compare_shape(x_shape, cache_shape, opt_level):


def _process_sub_block_pair(pdesc, sub_block_pair):
+    """Creates a list of tuples, each of which tracks the info of a subblock.
+
+    Note: this function doesn't handle nested subblocks yet.
+    TODO(panyx0718): assert in case nested subblocks happen.
+
+    :param pdesc: ProgramDesc.
+    :param sub_block_pair: A list of op pairs. Each op pair is the forward
+        op and its backward op. The ops in the list are special in that
+        they contain a subblock of ops.
+    :return: A list of tuples, each tuple is (all ops in a subblock pair
+        including forward and backward, number of forward ops,
+        all output arg names of the ops in the subblock pairs).
+    """
    ops_list = []
    block_desc = pdesc.block(0)
    op_size = block_desc.op_size()
@@ -308,6 +336,11 @@ def _process_sub_block_pair(pdesc, sub_block_pair):


def _get_cfgs(input_program):
+    """Process each block and create a ControlFlowGraph for each of them.
+
+    :param input_program: Program object.
+    :return: A list of ControlFlowGraph, each corresponds to a block.
+    """
    ops_list = []
    pdesc = input_program.get_desc()
    block_desc = pdesc.block(0)
@@ -316,11 +349,8 @@ def _get_cfgs(input_program):
    ops_list.append(
        ([block_desc.op(i) for i in range(op_size)], op_size, set()))

-    sub_block_pair = [("while", "while_grad"), ("parallel_do",
-                                                "parallel_do_grad"),
-                      ("conditional_block", "conditional_block_grad")]
-
-    ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
+    # Only process one level of nested subblock.
+    ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR))

    cfgs = [
        ControlFlowGraph(input_program, ops, forward_num, skip_opt)
@@ -330,6 +360,17 @@ def _get_cfgs(input_program):


def memory_optimize(input_program, print_log=False, level=0):
+    """Optimize memory by reusing var memory.
+
+    Note: it doesn't support subblocks nested inside subblocks.
+
+    :param input_program: Input Program.
+    :param print_log: whether to print debug log.
+    :param level: If level=0, reuse a cached var only when shapes are exactly equal; if level=1, reuse it whenever the var's size fits the cached var's size.
+    :return: None
+    """
+    if level != 0 and level != 1:
+        raise ValueError("only support opt_level 0 or 1.")
    global PRINT_LOG
    PRINT_LOG = print_log
    cfgs = _get_cfgs(input_program)
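For context, `memory_optimize` is the public entry point of this transpiler. The sketch below shows a typical call; the `fluid.memory_optimize` export path and the surrounding program-building code are assumptions, only the signature comes from this diff.

```python
import paddle.fluid as fluid

# Build the training program as usual (layers, loss, optimizer) ...
program = fluid.default_main_program()

# Rewrite the program in place so vars that are no longer live reuse the
# memory of earlier vars. level=0 requires identical shapes; level=1 also
# allows a smaller var to reuse a larger cached var.
fluid.memory_optimize(program, print_log=False, level=0)
```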