14
14
15
15
from __future__ import print_function
16
16
17
- from collections import defaultdict , OrderedDict , Callable
17
+ from collections import defaultdict , MutableSet
18
18
from .. import core
19
19
from ... import compat as cpt
20
- from ..framework import Program , default_main_program , Parameter , Variable
20
+ from ..framework import Program , default_main_program , Parameter , Variable , core
21
21
from ..backward import _rename_arg_
22
22
from functools import reduce
23
23
from six .moves import range
44
44
PRINT_LOG = False
45
45
46
46
47
+ class OrderedSet (MutableSet ):
48
+ def __init__ (self , iterable = None ):
49
+ self .end = end = []
50
+ end += [None , end , end ] # sentinel node for doubly linked list
51
+ self .map = {} # key --> [key, prev, next]
52
+ if iterable is not None :
53
+ self |= iterable
54
+
55
+ def __len__ (self ):
56
+ return len (self .map )
57
+
58
+ def __contains__ (self , key ):
59
+ return key in self .map
60
+
61
+ def add (self , key ):
62
+ if key not in self .map :
63
+ end = self .end
64
+ curr = end [1 ]
65
+ curr [2 ] = end [1 ] = self .map [key ] = [key , curr , end ]
66
+
67
+ def update (self , other ):
68
+ for e in other :
69
+ self .add (e )
70
+
71
+ def discard (self , key ):
72
+ if key in self .map :
73
+ key , prev , next = self .map .pop (key )
74
+ prev [2 ] = next
75
+ next [1 ] = prev
76
+
77
+ def remove (self , key ):
78
+ self .discard (key )
79
+
80
+ def __iter__ (self ):
81
+ end = self .end
82
+ curr = end [2 ]
83
+ while curr is not end :
84
+ yield curr [0 ]
85
+ curr = curr [2 ]
86
+
87
+ def __reversed__ (self ):
88
+ end = self .end
89
+ curr = end [1 ]
90
+ while curr is not end :
91
+ yield curr [0 ]
92
+ curr = curr [1 ]
93
+
94
+ def pop (self , last = True ):
95
+ if not self :
96
+ raise KeyError ('set is empty' )
97
+ key = self .end [1 ][0 ] if last else self .end [2 ][0 ]
98
+ self .discard (key )
99
+ return key
100
+
101
+ def __repr__ (self ):
102
+ if not self :
103
+ return '%s()' % (self .__class__ .__name__ , )
104
+ return '%s(%r)' % (self .__class__ .__name__ , list (self ))
105
+
106
+ def __eq__ (self , other ):
107
+ if isinstance (other , OrderedSet ):
108
+ return len (self ) == len (other ) and list (self ) == list (other )
109
+ return set (self ) == set (other )
110
+
111
+
47
112
class ControlFlowGraph (object ):
48
113
def __init__ (self , program , ops , forward_num , skip_opt ):
49
114
self ._program = program
50
115
self ._ops = ops
51
116
self ._forward_num = forward_num
52
- self ._successors = defaultdict (set )
53
- self ._presuccessors = defaultdict (set )
54
- self ._uses = defaultdict (set )
55
- self ._defs = defaultdict (set )
56
- self ._live_in = defaultdict (set )
57
- self ._live_out = defaultdict (set )
117
+ self ._successors = defaultdict (OrderedSet )
118
+ self ._presuccessors = defaultdict (OrderedSet )
119
+ self ._uses = defaultdict (OrderedSet )
120
+ self ._defs = defaultdict (OrderedSet )
121
+ self ._live_in = defaultdict (OrderedSet )
122
+ self ._live_out = defaultdict (OrderedSet )
58
123
self ._skip_opt = skip_opt
59
124
self .pool = []
60
125
@@ -116,7 +181,7 @@ def _fill_pool(self, i, is_forward):
116
181
# NOTE: must sort the in_diff set for cases that get different cache var.
117
182
# FIXME(typhoonzero): maybe use a "sorted set" is better than this.
118
183
can_optimize = [
119
- x for x in sorted ( list ( in_diff ))
184
+ x for x in in_diff
120
185
if self ._check_var_validity (block_desc , x , is_forward )
121
186
]
122
187
if can_optimize :
@@ -224,7 +289,7 @@ def compare_shape(x_shape, cache_shape, opt_level):
224
289
if self .pool :
225
290
# NOTE: must sort the in_diff set for cases that get different cache var.
226
291
defs_can_optimize = [
227
- x for x in sorted ( list ( self ._defs [i ]))
292
+ x for x in self ._defs [i ]
228
293
if self ._check_var_validity (block_desc , x , is_forward )
229
294
]
230
295
out_pair = [
@@ -381,7 +446,19 @@ def _get_cfgs(input_program):
381
446
return cfgs
382
447
383
448
384
- def memory_optimize (input_program , skip_opt_set = None , print_log = False , level = 0 ):
449
+ def _is_opt_role_op (op ):
450
+ op_maker = core .op_proto_and_checker_maker
451
+ optimize_role = core .op_proto_and_checker_maker .OpRole .Optimize
452
+ if op_maker .kOpRoleAttrName () in op .attr_names and \
453
+ int (op .all_attrs ()[op_maker .kOpRoleAttrName ()]) == int (optimize_role ):
454
+ return True
455
+
456
+
457
+ def memory_optimize (input_program ,
458
+ skip_opt_set = None ,
459
+ print_log = False ,
460
+ level = 0 ,
461
+ skip_grads = False ):
385
462
"""Optimize memory by reusing var memory.
386
463
387
464
Note: it doesn't not support subblock nested in subblock.
@@ -398,6 +475,19 @@ def memory_optimize(input_program, skip_opt_set=None, print_log=False, level=0):
398
475
raise ValueError ("only support opt_level 0 or 1." )
399
476
global PRINT_LOG
400
477
PRINT_LOG = print_log
478
+ if skip_grads :
479
+ grad_set = set ()
480
+ OP_ROLE_VAR = core .op_proto_and_checker_maker .kOpRoleVarAttrName ()
481
+ for op in input_program .global_block ().ops :
482
+ if _is_opt_role_op (op ):
483
+ if op .attr (OP_ROLE_VAR ):
484
+ grad_name = op .attr (OP_ROLE_VAR )[1 ]
485
+ grad_set .add (grad_name )
486
+ if not skip_opt_set :
487
+ skip_opt_set = grad_set
488
+ else :
489
+ skip_opt_set .update (grad_set )
490
+
401
491
cfgs = _get_cfgs (input_program )
402
492
for cfg in cfgs :
403
493
cfg .memory_optimize (skip_opt_set = skip_opt_set , level = level )
0 commit comments