class ControlFlowGraph(object):
    def __init__(self, program, ops, forward_num, skip_opt):
        self._program = program
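+        # untouched clone of the original program, kept so the optimized
+        # program can be compared against it when debugging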
+        self._dup_program = program.clone()
        self._ops = ops
        self._forward_num = forward_num
        self._successors = defaultdict(set)
@@ -56,6 +57,7 @@ def __init__(self, program, ops, forward_num, skip_opt):
        self._live_in = defaultdict(set)
        self._live_out = defaultdict(set)
        self._skip_opt = skip_opt
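+        # (var_name, shape) entries for variables whose memory is free to be reused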
+        self.pool = []

    def _add_connections(self, connections):
        """Populates _successors and _presuccessors for two neighbor nodes."""
@@ -78,8 +80,6 @@ def _build_graph(self):
            self._uses[i].update(self._ops[i].input_arg_names())
            self._defs[i].update(self._ops[i].output_arg_names())
            self._live_in[i] = self._uses[i]
-            # print(self._successors)
-            # print(self._presuccessors)

    def _update_graph(self, old_name, new_name, begin_idx=0):
        for i in range(begin_idx, self.op_size):
@@ -89,50 +89,13 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
            if old_name in self._defs[i]:
                self._defs[i].remove(old_name)
                self._defs[i].add(new_name)
-            # for i in range(begin_idx, -1, -1):
            if old_name in self._live_in[i]:
                self._live_in[i].remove(old_name)
                self._live_in[i].add(new_name)
-                # if old_name == "concat_3.tmp_0@GRAD":
-                #     print("new_name", new_name)
-                #     print("live_in ", i, self._live_in[i])
            if old_name in self._live_out[i]:
                self._live_out[i].remove(old_name)
                self._live_out[i].add(new_name)
-                # if old_name == "concat_3.tmp_0@GRAD":
-                #     print("live_out ", i, self._live_out[i])
-
-    def _reach_fixed_point(self, live_in, live_out):
-        """Check if the liveness set has stablized."""
-        if len(live_in) != len(self._live_in):
-            return False
-        if len(live_out) != len(self._live_out):
-            return False
-        for i in range(self.op_size):
-            if (live_in[i] != self._live_in[i] or
-                    live_out[i] != self._live_out[i]):
-                return False
-        return True

-    # def _dataflow_analyze(self):
-    #     self._build_graph()
-    #     live_in = defaultdict(set)
-    #     live_out = defaultdict(set)
-    #     # Repeatedly apply liveness updates until the algorithm stablize
-    #     # on a complete set live input vars and live output vars.
-    #     counter = 0
-    #     print(self._successors)
-    #     while True:
-    #         counter += 1
-    #         for i in reversed(list(range(self.op_size))):
-    #             live_in[i] = set(self._live_in[i])
-    #             live_out[i] = set(self._live_out[i])
-    #             for s in self._successors[i]:
-    #                 self._live_out[i] |= self._live_in[s]
-    #             self._live_in[i] = self._uses[i] | (
-    #                 self._live_out[i] - self._defs[i])
-    #         if self._reach_fixed_point(live_in, live_out):
-    #             break

    def _dataflow_analyze(self):
        self._build_graph()
@@ -149,6 +112,20 @@ def _dataflow_analyze(self):
                for d in self._presuccessors[i]:
                    worklist.append(d)

+    def _fill_pool(self, i, is_forward):
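+        # Variables that are live into op i but not live out of it die here;
+        # record their (name, shape) so that later ops can reuse the memory.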
+        block_desc = self._ops[i].block()
+        in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
+        can_optimize = [
+            x for x in in_diff
+            if self._check_var_validity(block_desc, x, is_forward)
+        ]
+        if can_optimize:
+            for var_name in can_optimize:
+                cache = (var_name, self._find_var(
+                    block_desc, var_name, is_forward).shape())
+                if cache not in self.pool:
+                    self.pool.append(cache)
+
    def _get_diff(self, a, b):
        u = a & b
        return a - u, b - u
@@ -238,24 +215,15 @@ def compare_shape(x_shape, cache_shape, opt_level):
        # update skip set to meet users' demand
        if skip_opt_set:
            self._skip_opt.update(skip_opt_set)
-        self.pool = []
+        # self.pool = []
        for i in range(self.op_size):
            op = self._ops[i]
            if op.type() in SUB_BLOCK_OPS:
                continue
            block_desc = op.block()
            is_forward = i < self._forward_num
-            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
-            can_optimize = [
-                x for x in in_diff
-                if self._check_var_validity(block_desc, x, is_forward)
-            ]
-            if can_optimize:
-                for var_name in can_optimize:
-                    self.pool.append((var_name, self._find_var(
-                        block_desc, var_name, is_forward).shape()))
+            self._fill_pool(i, is_forward)
            # print(op.type(), i, self.pool)
-            # print(self._live_in[i])
            if self.pool:
                defs_can_optimize = [
                    x for x in self._defs[i]
@@ -266,60 +234,57 @@ def compare_shape(x_shape, cache_shape, opt_level):
                    for x in defs_can_optimize
                ]
                for x, x_shape in out_pair:
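+                    # sanity check: a variable being defined here should never
+                    # already be sitting in the reuse pool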
+                    if (x, x_shape) in self.pool:
+                        raise ValueError("x in pool")
                    # If x is both in uses and defs, it can not be optimized!
                    if x in self._uses[i]:
+                        # print(self.pool, op.type(), cpt.to_text(x))
+                        # raise ValueError("x in use!", cpt.to_text(x))
                        continue
                    for index, cache_pair in enumerate(self.pool):
                        cache_var = cache_pair[0]
                        cache_shape = cache_pair[1]
-                        if not compare_shape(x_shape, cache_shape, level):
-                            continue
-
                        if not self._has_var(block_desc, cache_var, is_forward):
-                            continue
+                            raise ValueError("cache", cpt.to_text(cache_var), " Not exists!")
+                        if x == cache_var:
+                            raise ValueError("x : ", cpt.to_text(x), " cache : ", cpt.to_text(cache_var), " is same var!")

                        x_dtype = self._find_var(block_desc, x,
                                                 is_forward).dtype()
                        cache_dtype = self._find_var(block_desc, cache_var,
                                                     is_forward).dtype()
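+                        # skip cache entries whose shape is not compatible with
+                        # x under the current optimization level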
+
+                        if not compare_shape(x_shape, cache_shape, level):
+                            continue
                        # TODO(qijun): actually, we should compare
                        # dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
                        if x_dtype != cache_dtype:
                            continue

-                        self.pool.pop(index)
-                        if x == cache_var:
-                            break
-
                        if PRINT_LOG:
                            print(("Hit Cache !!!! cache pool index "
                                   "is %d, var name is %s, "
                                   "cached var name is %s, "
                                   "var shape is %s ") % (index, x, cache_var,
                                                          str(cache_shape)))
+                        self.pool.pop(index)
                        # Rename the var to the cache var already with
                        # memory allocated in order to reuse the memory.
                        _rename_arg_(self._ops, x, cache_var, begin_idx=i)
-                        self._program.block(block_desc.id).var(cpt.to_text(
-                            x)).desc = self._find_var(block_desc, cache_var,
-                                                      is_forward)
-                        if x == "concat_3.tmp_0@GRAD":
-                            print("Update Graph", i)
+                        self._program.block(block_desc.id)._remove_var(cpt.to_text(
+                            x))
+                        # if str(self._program) != str(self._dup_program):
+                        #     with open("./program_middle", "w") as f:
+                        #         f.write(str(self._program))
+                        #         f.flush()
+                        #     exit(0)
+                        # self._program.block(block_desc.id).var(cpt.to_text(
+                        #     x)).desc = self._find_var(block_desc, cache_var,
+                        #                               is_forward)
                        self._update_graph(x, cache_var, begin_idx=i)
                        break
+            # self._fill_pool(i, is_forward)

-            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
-            can_optimize = [
-                x for x in in_diff
-                if self._check_var_validity(block_desc, x, is_forward)
-            ]
-            keys = set([key for key, shape in self.pool])
-            if can_optimize:
-                for var_name in can_optimize:
-                    if var_name not in keys:
-                        self.pool.append((var_name, self._find_var(
-                            block_desc, var_name, is_forward).shape()))
-            # print(op.type(), i, self.pool)


def _process_sub_block_pair(pdesc, sub_block_pair):