@@ -77,6 +77,9 @@ def _build_graph(self):
77
77
for i in range (self .op_size ):
78
78
self ._uses [i ].update (self ._ops [i ].input_arg_names ())
79
79
self ._defs [i ].update (self ._ops [i ].output_arg_names ())
80
+ self ._live_in [i ] = self ._uses [i ]
81
+ # print(self._successors)
82
+ # print(self._presuccessors)
80
83
81
84
def _update_graph (self , old_name , new_name , begin_idx = 0 ):
82
85
for i in range (begin_idx , self .op_size ):
@@ -86,12 +89,18 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
86
89
if old_name in self ._defs [i ]:
87
90
self ._defs [i ].remove (old_name )
88
91
self ._defs [i ].add (new_name )
92
+ # for i in range(begin_idx, -1, -1):
89
93
if old_name in self ._live_in [i ]:
90
94
self ._live_in [i ].remove (old_name )
91
- self ._live_out [i ].add (new_name )
95
+ self ._live_in [i ].add (new_name )
96
+ # if old_name == "concat_3.tmp_0@GRAD":
97
+ # print("new_name", new_name)
98
+ # print("live_in ", i , self._live_in[i])
92
99
if old_name in self ._live_out [i ]:
93
100
self ._live_out [i ].remove (old_name )
94
101
self ._live_out [i ].add (new_name )
102
+ # if old_name == "concat_3.tmp_0@GRAD":
103
+ # print("live_out ", i , self._live_out[i])
95
104
96
105
def _reach_fixed_point (self , live_in , live_out ):
97
106
"""Check if the liveness set has stablized."""
@@ -105,22 +114,40 @@ def _reach_fixed_point(self, live_in, live_out):
105
114
return False
106
115
return True
107
116
117
+ # def _dataflow_analyze(self):
118
+ # self._build_graph()
119
+ # live_in = defaultdict(set)
120
+ # live_out = defaultdict(set)
121
+ # # Repeatedly apply liveness updates until the algorithm stabilizes
122
+ # # on a complete set of live input vars and live output vars.
123
+ # counter = 0
124
+ # print(self._successors)
125
+ # while True:
126
+ # counter += 1
127
+ # for i in reversed(list(range(self.op_size))):
128
+ # live_in[i] = set(self._live_in[i])
129
+ # live_out[i] = set(self._live_out[i])
130
+ # for s in self._successors[i]:
131
+ # self._live_out[i] |= self._live_in[s]
132
+ # self._live_in[i] = self._uses[i] | (
133
+ # self._live_out[i] - self._defs[i])
134
+ # if self._reach_fixed_point(live_in, live_out):
135
+ # break
136
+
108
137
def _dataflow_analyze(self):
    """Compute the per-op liveness sets with a backward worklist pass.

    Calls ``self._build_graph()`` first, then repeatedly applies the
    liveness equations to ``self._live_in`` / ``self._live_out`` until no
    op's live-in set changes any more:

        live_out[i] |= live_in[s]   for every s in self._successors[i]
        live_in[i]   = uses[i] | (live_out[i] - defs[i])

    Every op is visited at least once, starting from the last op. Whenever
    the live-in set of op ``i`` changes, the ops listed in
    ``self._presuccessors[i]`` are queued again so the change can
    propagate to them (presumably these are the predecessors of ``i`` in
    the control-flow graph -- confirm against ``_build_graph``).

    The result is stored in place on ``self``; nothing is returned.
    """
    # Function-scope import keeps this block self-contained; a deque gives
    # O(1) pops from the front, whereas list.pop(0) shifts every element.
    from collections import deque

    self._build_graph()

    # Seed the queue with every op index in reverse program order, since
    # liveness information flows backwards from successors.
    pending = deque(range(len(self._ops) - 1, -1, -1))
    while pending:
        i = pending.popleft()
        # Snapshot so we can detect whether this visit changed anything.
        previous_live_in = set(self._live_in[i])
        for s in self._successors[i]:
            self._live_out[i] |= self._live_in[s]
        self._live_in[i] = self._uses[i] | (
            self._live_out[i] - self._defs[i])
        if previous_live_in != self._live_in[i]:
            # Re-examine the ops whose live-out depends on this live-in.
            pending.extend(self._presuccessors[i])
124
151
125
152
def _get_diff (self , a , b ):
126
153
u = a & b
@@ -218,6 +245,17 @@ def compare_shape(x_shape, cache_shape, opt_level):
218
245
continue
219
246
block_desc = op .block ()
220
247
is_forward = i < self ._forward_num
248
+ in_diff , _ = self ._get_diff (self ._live_in [i ], self ._live_out [i ])
249
+ can_optimize = [
250
+ x for x in in_diff
251
+ if self ._check_var_validity (block_desc , x , is_forward )
252
+ ]
253
+ if can_optimize :
254
+ for var_name in can_optimize :
255
+ self .pool .append ((var_name , self ._find_var (
256
+ block_desc , var_name , is_forward ).shape ()))
257
+ # print(op.type(), i, self.pool)
258
+ # print(self._live_in[i])
221
259
if self .pool :
222
260
defs_can_optimize = [
223
261
x for x in self ._defs [i ]
@@ -249,21 +287,24 @@ def compare_shape(x_shape, cache_shape, opt_level):
249
287
if x_dtype != cache_dtype :
250
288
continue
251
289
290
+ self .pool .pop (index )
291
+ if x == cache_var :
292
+ break
293
+
252
294
if PRINT_LOG :
253
295
print (("Hit Cache !!!! cache pool index "
254
296
"is %d, var name is %s, "
255
297
"cached var name is %s, "
256
298
"var shape is %s " ) % (index , x , cache_var ,
257
299
str (cache_shape )))
258
- self .pool .pop (index )
259
- if x == cache_var :
260
- break
261
300
# Rename the var to the cache var already with
262
301
# memory allocated in order to reuse the memory.
263
302
_rename_arg_ (self ._ops , x , cache_var , begin_idx = i )
264
303
self ._program .block (block_desc .id ).var (cpt .to_text (
265
304
x )).desc = self ._find_var (block_desc , cache_var ,
266
305
is_forward )
306
+ if x == "concat_3.tmp_0@GRAD" :
307
+ print ("Update Graph" , i )
267
308
self ._update_graph (x , cache_var , begin_idx = i )
268
309
break
269
310
@@ -272,10 +313,13 @@ def compare_shape(x_shape, cache_shape, opt_level):
272
313
x for x in in_diff
273
314
if self ._check_var_validity (block_desc , x , is_forward )
274
315
]
316
+ keys = set ([key for key ,shape in self .pool ])
275
317
if can_optimize :
276
318
for var_name in can_optimize :
277
- self .pool .append ((var_name , self ._find_var (
278
- block_desc , var_name , is_forward ).shape ()))
319
+ if var_name not in keys :
320
+ self .pool .append ((var_name , self ._find_var (
321
+ block_desc , var_name , is_forward ).shape ()))
322
+ # print(op.type(), i, self.pool)
279
323
280
324
281
325
def _process_sub_block_pair (pdesc , sub_block_pair ):
0 commit comments