
Commit d76fcb6

Memory optimization on Dynamic RNN (#7599)
* limit variable type to LoD tensor in the memory optimization transpiler
* refine policy
* support while operator
* fix random seed and training data order
* refine get_cfgs method to support multiple while operators
* refine codes
1 parent f6a4c3e commit d76fcb6
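
Usage context (not part of this commit's diff): the rewritten transpiler is meant to be invoked on a program before the executor runs it, so that same-shape, same-dtype LoD tensors with non-overlapping live ranges can share memory, including inside the while/while_grad blocks generated by a dynamic RNN. A minimal sketch follows; the network itself is elided, and the exact fluid API surface may differ slightly at this commit:

import paddle.v2.fluid as fluid
from paddle.v2.fluid.memory_optimization_transpiler import memory_optimize

# ... build a network here, e.g. with fluid.layers.DynamicRNN(), and call
# optimizer.minimize(loss) so the while/while_grad ops exist in the program ...

# Rewrite default_main_program() in place so dead LoD tensors are reused.
memory_optimize(fluid.default_main_program())

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
# exe.run(fluid.default_main_program(), feed={...}, fetch_list=[...])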

File tree: 10 files changed, +332 -71 lines

paddle/framework/block_desc.cc

Lines changed: 4 additions & 4 deletions
@@ -75,7 +75,7 @@ std::vector<VarDesc *> BlockDesc::AllVars() const {

 OpDesc *BlockDesc::AppendOp() {
   need_update_ = true;
-  ops_.emplace_back(new OpDesc());
+  ops_.emplace_back(new OpDesc(this));
   return ops_.back().get();
 }

@@ -86,7 +86,7 @@ void BlockDesc::AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc) {

 OpDesc *BlockDesc::PrependOp() {
   need_update_ = true;
-  ops_.emplace_front(new OpDesc());
+  ops_.emplace_front(new OpDesc(this));
   return ops_.front().get();
 }

@@ -153,7 +153,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
     vars_[var_desc.name()].reset(new VarDesc(var_desc));
   }
   for (const proto::OpDesc &op_desc : desc_->ops()) {
-    ops_.emplace_back(new OpDesc(op_desc, prog));
+    ops_.emplace_back(new OpDesc(op_desc, prog, this));
   }
 }

@@ -162,7 +162,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
     : prog_(prog), desc_(desc) {
   need_update_ = true;
   for (auto &op : other.ops_) {
-    ops_.emplace_back(new OpDesc(*op));
+    ops_.emplace_back(new OpDesc(*op, this));
   }

   for (auto &it : other.vars_) {
paddle/framework/op_desc.cc

Lines changed: 2 additions & 1 deletion
@@ -97,7 +97,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
   need_update_ = true;
 }

-OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
+OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block)
     : desc_(desc), need_update_(false) {
   // restore inputs_
   int input_size = desc_.inputs_size();
@@ -131,6 +131,7 @@ OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
       attrs_[attr_name] = prog->MutableBlock(bid);
     }
   }
+  this->block_ = block;
 }

 proto::OpDesc *OpDesc::Proto() {

paddle/framework/op_desc.h

Lines changed: 13 additions & 2 deletions
@@ -25,15 +25,21 @@ namespace framework {

 class BlockDesc;
 class ProgramDesc;
-
 class OpDesc {
  public:
   OpDesc() {}

   OpDesc(const std::string &type, const VariableNameMap &inputs,
          const VariableNameMap &outputs, const AttributeMap &attrs);

-  OpDesc(const proto::OpDesc &desc, ProgramDesc *prog);
+  OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block);
+
+  explicit OpDesc(BlockDesc *block) : block_(block) {}
+
+  OpDesc(const OpDesc &other, BlockDesc *block) {
+    *this = other;
+    block_ = block;
+  }

   void CopyFrom(const OpDesc &op_desc);

@@ -117,6 +123,10 @@ class OpDesc {

   void Flush();

+  BlockDesc *Block() { return this->block_; }
+
+  void SetBlock(BlockDesc *block) { this->block_ = block; }
+
  private:
   template <typename MapType>
   static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
@@ -129,6 +139,7 @@ class OpDesc {
   }

   proto::OpDesc desc_;
+  BlockDesc *block_;  // not_own
   // input arg name => input variable names
   VariableNameMap inputs_;
   // output arg name => output variable names
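
The non-owning block_ pointer and the Block() accessor exist so that, given only an op, its owner block can be recovered; the memory-optimization transpiler further down relies on this to look variables up in the block that actually contains the op (a while sub-block rather than block 0). A rough sketch using only bindings that appear in this commit; input_program is a placeholder for a fluid Program whose block 0 contains a while op:

def show_owning_block(input_program):
    # Walk block 0, descend into the first while op's sub-block, and confirm
    # that ops inside it point back to that sub-block (the new back-pointer).
    pdesc = input_program.get_desc()
    block0 = pdesc.block(0)
    for i in range(block0.op_size()):
        op = block0.op(i)
        if op.type() == "while":
            sub_block = pdesc.block(op.attr("sub_block").id)
            inner_op = sub_block.op(0)
            assert inner_op.block().id == sub_block.id
            # Vars created inside the loop are found via this sub-block; the
            # *_recursive lookups used by the transpiler below additionally
            # search ancestor blocks, which backward (while_grad) ops need.
            return sub_block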

paddle/framework/var_desc.h

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@ class VarDesc {

   std::string Name() const { return desc_.name(); }

+  void SetName(std::string name) { desc_.set_name(name); }
+
   void SetShape(const std::vector<int64_t> &dims);

   void SetDataType(proto::DataType data_type);

paddle/pybind/protobuf.cc

Lines changed: 3 additions & 1 deletion
@@ -212,6 +212,7 @@ void BindVarDsec(py::module &m) {
              return name;
            },
            py::return_value_policy::reference)
+      .def("set_name", &VarDesc::SetName)
       .def("set_shape", &VarDesc::SetShape)
       .def("set_dtype", &VarDesc::SetDataType)
       .def("shape", &VarDesc::Shape, py::return_value_policy::reference)
@@ -280,7 +281,8 @@ void BindOpDesc(py::module &m) {
       .def("check_attrs", &OpDesc::CheckAttrs)
       .def("infer_shape", &OpDesc::InferShape)
       .def("infer_var_type", &OpDesc::InferVarType)
-      .def("serialize_to_string", SerializeMessage<OpDesc>);
+      .def("serialize_to_string", SerializeMessage<OpDesc>)
+      .def("block", &OpDesc::Block, py::return_value_policy::reference);
 }

 }  // namespace pybind
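
From Python, the two additions are visible as VarDesc.set_name and OpDesc.block (the latter is illustrated after the op_desc.h hunk above). A minimal sketch of set_name, assuming a BlockDesc and an existing variable name (both are placeholders here), using only bindings shown or referenced in this commit:

def rename_var(block0, old_name, new_name):
    # block0: a BlockDesc, e.g. program.get_desc().block(0); old_name must exist in it.
    var_desc = block0.find_var(str(old_name))  # find_var binding, also used by the transpiler
    print(var_desc.name())                     # old name
    var_desc.set_name(str(new_name))           # new binding: rename the underlying VarDesc proto
    print(var_desc.name())                     # new name
    return var_desc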

python/paddle/v2/fluid/memory_optimization_transpiler.py

Lines changed: 130 additions & 50 deletions
@@ -31,10 +31,12 @@


 class ControlFlowGraph(object):
-    def __init__(self, Program):
+    def __init__(self, Program, ops, forward_num):
         self._program = Program
-        self._succesors = defaultdict(set)
-        self._presucessors = defaultdict(set)
+        self._ops = ops
+        self._forward_num = forward_num
+        self._successors = defaultdict(set)
+        self._presuccessors = defaultdict(set)
         self._uses = defaultdict(set)
         self._defs = defaultdict(set)
         self._live_in = defaultdict(set)
@@ -45,25 +47,16 @@ def _add_connections(self, connections):
             self._add(node1, node2)

     def _add(self, node1, node2):
-        self._succesors[node1].add(node2)
-        self._presucessors[node2].add(node1)
+        self._successors[node1].add(node2)
+        self._presuccessors[node2].add(node1)

     def _build_graph(self):
-        program_desc = self._program.get_desc()
-        block_size = program_desc.num_blocks()
-
-        # TODO(qijun) handle Program with if/while operators
-        self.global_block_desc = program_desc.block(0)
-        self.op_size = self.global_block_desc.op_size()
-
+        self.op_size = len(self._ops)
         op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
         self._add_connections(op_node_connections)
-
-        self.ops = [self.global_block_desc.op(i) for i in range(self.op_size)]
-
         for i in range(self.op_size):
-            self._uses[i].update(self.ops[i].input_arg_names())
-            self._defs[i].update(self.ops[i].output_arg_names())
+            self._uses[i].update(self._ops[i].input_arg_names())
+            self._defs[i].update(self._ops[i].output_arg_names())

     def _update_graph(self, old_name, new_name, begin_idx=0):
         for i in range(begin_idx, self.op_size):
@@ -103,7 +96,7 @@ def _dataflow_analyze(self):
                 live_out[i] = set(self._live_out[i])
                 self._live_in[i] = self._uses[i] | (
                     self._live_out[i] - self._defs[i])
-                for s in self._succesors[i]:
+                for s in self._successors[i]:
                     self._live_out[i] |= self._live_in[s]

             if self._reach_fixed_point(live_in, live_out):
@@ -113,60 +106,147 @@ def _get_diff(self, a, b):
         u = a & b
         return a - u, b - u

+    def _has_var(self, block_desc, var_name, is_forward):
+        if is_forward:
+            return block_desc.has_var(str(var_name))
+        else:
+            return block_desc.has_var_recursive(str(var_name))
+
+    def _find_var(self, block_desc, var_name, is_forward):
+        if is_forward:
+            return block_desc.find_var(str(var_name))
+        else:
+            return block_desc.find_var_recursive(str(var_name))
+
     def memory_optimize(self):
+        def check_var_validity(block_desc, x, is_forward):
+            if str(x) == "@EMPTY@":
+                return False
+            if not self._has_var(block_desc, x, is_forward):
+                return False
+            if self._find_var(block_desc, x, is_forward).persistable():
+                return False
+            if self._find_var(
+                    block_desc, x,
+                    is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
+                return False
+            return True
+
         self._build_graph()
         self._dataflow_analyze()
         self.pool = []
         for i in range(self.op_size):
+            op = self._ops[i]
+            if op.type() == "while" or op.type() == "while_grad":
+                continue
+            block_desc = op.block()
+            is_forward = i < self._forward_num
             if self.pool:
-                out_pair = [(x, self.global_block_desc.var(str(x)).shape())
-                            for x in self._defs[i]]
+                defs_can_optimize = filter(
+                    lambda x: check_var_validity(block_desc, x, is_forward),
+                    self._defs[i])
+                out_pair = [
+                    (x, self._find_var(block_desc, x, is_forward).shape())
+                    for x in defs_can_optimize
+                ]
                 for x, x_shape in out_pair:
-                    if not self.global_block_desc.var(str(x)).persistable():
-                        for index, cache_pair in enumerate(self.pool):
-                            cache_var = cache_pair[0]
-                            cache_shape = cache_pair[1]
-                            if x_shape == cache_shape:
-                                x_dtype = self.global_block_desc.var(str(
-                                    x)).dtype()
-                                cache_dtype = self.global_block_desc.var(
-                                    str(cache_var)).dtype()
+                    for index, cache_pair in enumerate(self.pool):
+                        cache_var = cache_pair[0]
+                        cache_shape = cache_pair[1]
+                        if x_shape == cache_shape:
+                            if self._has_var(block_desc, cache_var, is_forward):
+                                x_dtype = self._find_var(block_desc, x,
                                                         is_forward).dtype()
+                                cache_dtype = self._find_var(
+                                    block_desc, cache_var, is_forward).dtype()
                                 # TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
                                 # and dtype_to_size[cache_dtype]
                                 if x_dtype == cache_dtype:
-                                    print(
-                                        ("Hit Cache !!!! cache pool index "
-                                         "is %d, var name is %s, "
-                                         "cached var name is %s, "
-                                         "var shape is %s ") %
-                                        (index, x, cache_var, str(cache_shape)))
+                                    print(("Hit Cache !!!! cache pool index "
+                                           "is %d, var name is %s, "
+                                           "cached var name is %s, "
+                                           "var shape is %s ") %
+                                          (index, x, cache_var,
+                                           str(cache_shape)))
                                     self.pool.pop(index)
+                                    if x == cache_var:
+                                        break
                                     _rename_arg_(
-                                        self.ops, x, cache_var, begin_idx=i)
-                                    self._program.current_block().var(str(
-                                        x)).desc = self.global_block_desc.var(
-                                            str(cache_var))
+                                        self._ops, x, cache_var, begin_idx=i)
+                                    self._program.block(block_desc.id).var(
+                                        str(x)).desc = self._find_var(
+                                            block_desc, cache_var, is_forward)
                                     self._update_graph(
                                         x, cache_var, begin_idx=i)
                                     break

             in_diff, out_diff = self._get_diff(self._live_in[i],
                                                self._live_out[i])
             can_optimize = filter(
-                lambda x: not self.global_block_desc.var(str(x)).persistable(),
+                lambda x: check_var_validity(block_desc, x, is_forward),
                 in_diff)
             if can_optimize:
                 for var_name in can_optimize:
-                    self.pool.append(
-                        (var_name,
-                         self.global_block_desc.var(str(var_name)).shape()))
-
-    def get_program(self):
-        return self._program
+                    self.pool.append((var_name, self._find_var(
+                        block_desc, var_name, is_forward).shape()))
+
+
+def get_cfgs(input_program):
+    ops_list = []
+    pdesc = input_program.get_desc()
+    block_desc = pdesc.block(0)
+    op_size = block_desc.op_size()
+    # Get global block ops
+    ops_list.append(([block_desc.op(i) for i in range(op_size)], op_size))
+
+    while_sub_block_ids = []
+    while_grad_sub_block_ids = []
+    while_pair = []
+
+    for i in range(op_size):
+        op = block_desc.op(i)
+        if op.type() == "while":
+            while_sub_block_ids.append(op.attr("sub_block").id)
+        elif op.type() == "while_grad":
+            while_grad_sub_block_ids.append(op.attr("sub_block").id)
+
+    # Find while/while_grad block pair
+    for grad_id in while_grad_sub_block_ids:
+        parent_id = pdesc.block(grad_id).parent
+        if parent_id in while_sub_block_ids:
+            while_pair.append((parent_id, grad_id))
+            while_sub_block_ids.remove(parent_id)
+
+    # Get while/while_grad block ops
+    for parent_id, grad_id in while_pair:
+        while_block_ops = []
+        while_block = pdesc.block(parent_id)
+        while_block_op_size = while_block.op_size()
+        for i in range(while_block_op_size):
+            while_block_ops.append(while_block.op(i))
+
+        while_grad_block = pdesc.block(grad_id)
+        while_grad_block_op_size = while_grad_block.op_size()
+        for i in range(while_grad_block_op_size):
+            while_block_ops.append(while_grad_block.op(i))
+
+        ops_list.append((while_block_ops, while_block_op_size))
+
+    # Process rest while block ops
+    for parent_id in while_sub_block_ids:
+        while_block_ops = []
+        while_block = pdesc.block(parent_id)
+        while_block_op_size = while_block.op_size()
+        for i in range(while_block_op_size):
+            while_block_ops.append(while_block.op(i))

+        ops_list.append((while_block_ops, while_block_op_size))
+
+    cfgs = [ControlFlowGraph(input_program, i, j) for i, j in ops_list]
+    return cfgs


 def memory_optimize(input_program):
-    graph = ControlFlowGraph(input_program)
-    graph.memory_optimize()
-    result_program = graph.get_program()
-    return result_program
+    cfgs = get_cfgs(input_program)
+    for cfg in cfgs:
+        cfg.memory_optimize()
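
For intuition about what ControlFlowGraph computes, here is a tiny self-contained sketch of the same two steps, backward liveness iterated to a fixed point and a reuse pool fed by variables that die at each op, on a hypothetical three-op chain. It mirrors _dataflow_analyze and the pool bookkeeping above, but with plain dicts instead of OpDesc/BlockDesc and without the shape/dtype/persistable checks:

from collections import defaultdict

# Hypothetical straight-line program: op0 defines a; op1 reads a, defines b;
# op2 reads b, defines c.  All tensors are assumed to have equal shape/dtype.
uses = {0: set(), 1: {"a"}, 2: {"b"}}
defs = {0: {"a"}, 1: {"b"}, 2: {"c"}}
successors = {0: {1}, 1: {2}, 2: set()}

live_in, live_out = defaultdict(set), defaultdict(set)
changed = True
while changed:                      # iterate until a fixed point, like _dataflow_analyze
    changed = False
    for i in range(3):
        out = set()
        for s in successors[i]:     # live_out[i] = union of successors' live_in
            out |= live_in[s]
        inn = uses[i] | (out - defs[i])
        if out != live_out[i] or inn != live_in[i]:
            live_out[i], live_in[i] = out, inn
            changed = True

pool = []                           # names whose storage is free for reuse
for i in range(3):
    for var in sorted(defs[i]):     # a newly defined var may take over a dead one
        if not pool:
            break
        print("op %d: %s can reuse the memory of %s" % (i, var, pool.pop(0)))
    for var in live_in[i] - live_out[i]:
        pool.append(var)            # dead after op i -> recycle its storage

# Expected output: "op 2: c can reuse the memory of a"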
