Skip to content

Commit 7a9098a

Browse files
reyoung authored and Yang Yang(Tony) committed
Add block.fwd_block_id (#8489)
* Add block.fwd_block_id * fix bug in memory optimization transpiler * Change DFS to BFS * Add comments
1 parent 8c0434c commit 7a9098a

File tree

9 files changed

+122
-41
lines changed

9 files changed

+122
-41
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ limitations under the License. */
1616
#include "paddle/fluid/framework/operator.h"
1717
#include "paddle/fluid/framework/program_desc.h"
1818

19+
#include <queue>
20+
1921
namespace paddle {
2022
namespace framework {
2123

@@ -64,12 +66,36 @@ VarDesc *BlockDesc::RenameVar(const std::string &old_name,
6466
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
6567
if (name == kEmptyVarName) return nullptr;
6668

67-
auto it = vars_.find(name);
68-
if (it == vars_.end()) {
69-
return Parent() == kNoneBlockIndex ? nullptr
70-
: ParentBlock()->FindVarRecursive(name);
69+
std::queue<const BlockDesc *> frontier;
70+
std::unordered_set<const BlockDesc *> visited;
71+
72+
frontier.push(this);
73+
74+
while (!frontier.empty()) { // BFS
75+
auto cur = frontier.front();
76+
frontier.pop();
77+
if (visited.count(cur) != 0) {
78+
continue;
79+
}
80+
auto var = cur->FindVar(name);
81+
if (var != nullptr) {
82+
return var;
83+
}
84+
85+
auto fwd = cur->ForwardBlock();
86+
auto parent = cur->ParentBlock();
87+
88+
if (fwd != nullptr) {
89+
frontier.push(fwd);
90+
}
91+
if (parent != nullptr) {
92+
frontier.push(parent);
93+
}
94+
95+
visited.insert(cur);
7196
}
72-
return it->second.get();
97+
98+
return nullptr;
7399
}
74100

75101
VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
@@ -155,10 +181,7 @@ void BlockDesc::Flush() {
155181
}
156182

157183
BlockDesc *BlockDesc::ParentBlock() const {
158-
if (this->desc_->parent_idx() == kNoneBlockIndex) {
159-
return nullptr;
160-
}
161-
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
184+
return prog_->MutableBlock(static_cast<size_t>(desc_->parent_idx()));
162185
}
163186

164187
proto::BlockDesc *BlockDesc::Proto() {
@@ -205,5 +228,16 @@ void BlockDesc::ClearPBVars() {
205228
}
206229
}
207230

231+
void BlockDesc::SetForwardBlockID(int32_t forward_block_id) {
232+
PADDLE_ENFORCE(!desc_->has_forward_block_idx(),
233+
"Parent block ID has been set to %d. Cannot set to %d",
234+
desc_->forward_block_idx(), forward_block_id);
235+
desc_->set_forward_block_idx(forward_block_id);
236+
}
237+
238+
BlockDesc *BlockDesc::ForwardBlock() const {
239+
return prog_->MutableBlock(static_cast<size_t>(desc_->forward_block_idx()));
240+
}
241+
208242
} // namespace framework
209243
} // namespace paddle

paddle/fluid/framework/block_desc.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class BlockDesc {
4949

5050
int32_t Parent() const { return desc_->parent_idx(); }
5151

52+
int32_t ForwardBlockID() const { return desc_->forward_block_idx(); }
53+
5254
VarDesc *Var(const std::string &name_bytes);
5355

5456
VarDesc *FindVar(const std::string &name_bytes) const;
@@ -75,6 +77,10 @@ class BlockDesc {
7577

7678
BlockDesc *ParentBlock() const;
7779

80+
BlockDesc *ForwardBlock() const;
81+
82+
void SetForwardBlockID(int32_t forward_block_id);
83+
7884
OpDesc *AppendOp();
7985

8086
void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
@@ -93,7 +99,7 @@ class BlockDesc {
9399

94100
proto::BlockDesc *Proto();
95101

96-
ProgramDesc *Program() { return this->prog_; }
102+
ProgramDesc *Program() const { return this->prog_; }
97103

98104
private:
99105
void ClearPBOps();

paddle/fluid/framework/framework.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ message BlockDesc {
158158
required int32 parent_idx = 2;
159159
repeated VarDesc vars = 3;
160160
repeated OpDesc ops = 4;
161+
optional int32 forward_block_idx = 5 [ default = -1 ];
161162
}
162163

163164
// Please refer to

paddle/fluid/framework/program_desc.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ class ProgramDesc {
3838

3939
BlockDesc *AppendBlock(const BlockDesc &parent);
4040

41-
BlockDesc *MutableBlock(size_t idx) { return blocks_[idx].get(); }
41+
BlockDesc *MutableBlock(size_t idx) {
42+
if (idx == static_cast<size_t>(kNoneBlockIndex)) {
43+
return nullptr;
44+
} else {
45+
return blocks_[idx].get();
46+
}
47+
}
4248

4349
const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; }
4450

paddle/fluid/operators/while_op.cc

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
231231
while_grad->SetInput(kStepScopes, Output(kStepScopes));
232232

233233
auto *grad_block = this->grad_block_[0];
234-
auto *fwd_block = grad_block->ParentBlock();
234+
auto *fwd_block = grad_block->ForwardBlock();
235+
auto *parent_block = grad_block->ParentBlock();
235236

236237
// Not all of IGs will be generated by inner gradient operators of while op.
237238
// Ignore IGs that is not generated by the inside block.
@@ -260,33 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
260261
for (auto &o : Output(kOutputs)) {
261262
block_ins.insert(o);
262263
}
263-
std::unordered_set<std::string> extra_inputs;
264+
std::unordered_set<std::string> output_grads;
264265
for (const auto *op : grad_block->AllOps()) {
265266
for (auto &input_name : op->InputArgumentNames()) {
266267
// If the input of Op has been recorded or is generated by the forward
267268
// block, do not make it as input again.
269+
270+
// The input is located in I/O or other op's outputs or the variable is
271+
// located in grad_block's parents
268272
if (block_ins.find(input_name) != block_ins.end() ||
269-
fwd_block->FindVar(input_name) != nullptr) {
273+
(fwd_block->FindVarRecursive(input_name) != nullptr ||
274+
parent_block->FindVarRecursive(input_name) != nullptr)) {
270275
continue;
271276
}
272-
extra_inputs.insert(input_name);
277+
output_grads.insert(input_name);
273278
}
274279
for (auto &output_name : op->OutputArgumentNames()) {
275280
block_ins.insert(output_name);
276281
}
277282
}
278283

279-
std::vector<std::string> extra_inputs_list;
280-
extra_inputs_list.resize(extra_inputs.size());
281-
std::copy(extra_inputs.begin(), extra_inputs.end(),
282-
extra_inputs_list.begin());
283-
while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
284+
std::vector<std::string> output_grads_list;
285+
output_grads_list.resize(output_grads.size());
286+
std::copy(output_grads.begin(), output_grads.end(),
287+
output_grads_list.begin());
288+
while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);
284289

285290
while_grad->SetAttrMap(this->Attrs());
286291
while_grad->SetBlockAttr(kStepBlock, *grad_block);
287292
// record the original output gradient names, since the gradient name of
288293
// while operator could be renamed.
289-
while_grad->SetAttr("original_output_grad", extra_inputs_list);
294+
while_grad->SetAttr("original_output_grad", output_grads_list);
290295

291296
return std::unique_ptr<framework::OpDesc>(while_grad);
292297
}

paddle/fluid/pybind/protobuf.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ void BindBlockDesc(py::module &m) {
155155
py::class_<BlockDesc>(m, "BlockDesc", "")
156156
.def_property_readonly("id", &BlockDesc::ID)
157157
.def_property_readonly("parent", &BlockDesc::Parent)
158+
.def("get_forward_block_idx", &BlockDesc::ForwardBlockID)
159+
.def("set_forward_block_idx", &BlockDesc::SetForwardBlockID)
158160
.def("append_op", &BlockDesc::AppendOp,
159161
py::return_value_policy::reference)
160162
.def("prepend_op", &BlockDesc::PrependOp,

python/paddle/v2/fluid/backward.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ def _append_backward_ops_(block,
298298
# If the op has its own sub-block, deal with the sub-block first
299299
if op.has_attr("sub_block"):
300300
sub_block = program.block(op.block_attr("sub_block"))
301-
grad_sub_block = program.create_block(parent_idx=sub_block.idx)
301+
grad_sub_block = program.create_block()
302+
grad_sub_block.set_forward_block_idx(sub_block.idx)
302303
cb = _callback_lookup_(op)
303304
if cb is not None:
304305
if callbacks is None:
@@ -310,6 +311,8 @@ def _append_backward_ops_(block,
310311
else:
311312
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
312313
no_grad_dict, grad_to_var, callbacks)
314+
315+
program.rollback()
313316
grad_sub_block_list.append(grad_sub_block.desc)
314317

315318
# Getting op's corresponding grad_op

python/paddle/v2/fluid/framework.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,13 @@ def to_string(self, throw_on_error, with_details=False):
696696
def parent_idx(self):
697697
return self.desc.parent
698698

699+
@property
700+
def forward_block_idx(self):
701+
return self.desc.get_forward_block_idx()
702+
703+
def set_forward_block_idx(self, idx):
704+
self.desc.set_forward_block_idx(idx)
705+
699706
@property
700707
def idx(self):
701708
return self.desc.id
@@ -709,15 +716,32 @@ def var(self, name):
709716
return v
710717

711718
def var_recursive(self, name):
712-
if self.has_var(name):
713-
return self.var(name)
714-
else:
715-
if self.idx == 0:
716-
raise ValueError("var %s is not in block(%d) nor its parents." %
717-
name, self.idx)
718-
else:
719-
parent_block = self.program.block(self.parent_idx)
720-
return parent_block.var_recursive(name)
719+
frontier = list()
720+
visited = set()
721+
722+
frontier.append(self)
723+
724+
prog = self.program
725+
726+
while len(frontier) != 0: # BFS
727+
cur = frontier[0]
728+
frontier = frontier[1:]
729+
730+
if id(cur) in visited:
731+
continue
732+
733+
if cur.has_var(name):
734+
return cur.var(name)
735+
736+
if cur.parent_idx != -1:
737+
frontier.append(prog.block(cur.parent_idx))
738+
739+
if cur.forward_block_idx != -1:
740+
frontier.append(prog.block(cur.forward_block_idx))
741+
742+
visited.add(id(cur))
743+
744+
raise ValueError("Var {0} is not found recursively".format(name))
721745

722746
def all_parameters(self):
723747
return list(self.iter_parameters())

python/paddle/v2/fluid/memory_optimization_transpiler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -223,15 +223,15 @@ def get_cfgs(input_program):
223223

224224
# Find while/while_grad block pair
225225
for grad_id in while_grad_sub_block_ids:
226-
parent_id = pdesc.block(grad_id).parent
227-
if parent_id in while_sub_block_ids:
228-
while_block_id_pair.append((parent_id, grad_id))
229-
while_sub_block_ids.remove(parent_id)
226+
forward_id = pdesc.block(grad_id).get_forward_block_idx()
227+
if forward_id in while_sub_block_ids:
228+
while_block_id_pair.append((forward_id, grad_id))
229+
while_sub_block_ids.remove(forward_id)
230230

231231
# Get while/while_grad block ops
232-
for parent_id, grad_id in while_block_id_pair:
232+
for forward_id, grad_id in while_block_id_pair:
233233
while_block_ops = []
234-
while_block = pdesc.block(parent_id)
234+
while_block = pdesc.block(forward_id)
235235
while_block_op_size = while_block.op_size()
236236
for i in range(while_block_op_size):
237237
while_block_ops.append(while_block.op(i))
@@ -242,21 +242,21 @@ def get_cfgs(input_program):
242242
while_block_ops.append(while_grad_block.op(i))
243243

244244
while_op_output = set()
245-
while_op_output.update(while_op_dict[parent_id].output_arg_names())
245+
while_op_output.update(while_op_dict[forward_id].output_arg_names())
246246
while_op_output.update(while_op_dict[grad_id].output_arg_names())
247247

248248
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
249249

250250
# Process rest while block ops
251-
for parent_id in while_sub_block_ids:
251+
for forward_id in while_sub_block_ids:
252252
while_block_ops = []
253-
while_block = pdesc.block(parent_id)
253+
while_block = pdesc.block(forward_id)
254254
while_block_op_size = while_block.op_size()
255255
for i in range(while_block_op_size):
256256
while_block_ops.append(while_block.op(i))
257257

258258
while_op_output = set()
259-
while_op_output.update(while_op_dict[parent_id].output_arg_names())
259+
while_op_output.update(while_op_dict[forward_id].output_arg_names())
260260

261261
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
262262

0 commit comments

Comments
 (0)