Skip to content

Commit 14f8370

Browse files
committed
Add block.fwd_block_id
1 parent 78cc64a commit 14f8370

File tree

8 files changed

+78
-17
lines changed

8 files changed

+78
-17
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,25 @@ VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
4646
if (name == kEmptyVarName) return nullptr;
4747

4848
auto it = vars_.find(name);
49-
if (it == vars_.end()) {
50-
return Parent() == kNoneBlockIndex ? nullptr
51-
: ParentBlock()->FindVarRecursive(name);
49+
if (it != vars_.end()) {
50+
return it->second.get();
5251
}
53-
return it->second.get();
52+
53+
BlockDesc *tmp = ParentBlock();
54+
55+
if (tmp != nullptr) {
56+
auto ptr = tmp->FindVarRecursive(name);
57+
if (ptr != nullptr) {
58+
return ptr;
59+
}
60+
}
61+
62+
tmp = ForwardBlock();
63+
if (tmp != nullptr) {
64+
return tmp->FindVarRecursive(name);
65+
}
66+
67+
return nullptr;
5468
}
5569

5670
VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
@@ -136,10 +150,7 @@ void BlockDesc::Flush() {
136150
}
137151

138152
BlockDesc *BlockDesc::ParentBlock() const {
139-
if (this->desc_->parent_idx() == kNoneBlockIndex) {
140-
return nullptr;
141-
}
142-
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
153+
return prog_->MutableBlock(static_cast<size_t>(desc_->parent_idx()));
143154
}
144155

145156
proto::BlockDesc *BlockDesc::Proto() {
@@ -186,5 +197,16 @@ void BlockDesc::ClearPBVars() {
186197
}
187198
}
188199

200+
void BlockDesc::SetForwardBlockID(int32_t forward_block_id) {
201+
PADDLE_ENFORCE(!desc_->has_forward_block_idx(),
202+
"Parent block ID has been set to %d. Cannot set to %d",
203+
desc_->forward_block_idx(), forward_block_id);
204+
desc_->set_forward_block_idx(forward_block_id);
205+
}
206+
207+
BlockDesc *BlockDesc::ForwardBlock() const {
208+
return prog_->MutableBlock(static_cast<size_t>(desc_->forward_block_idx()));
209+
}
210+
189211
} // namespace framework
190212
} // namespace paddle

paddle/fluid/framework/block_desc.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class BlockDesc {
4949

5050
int32_t Parent() const { return desc_->parent_idx(); }
5151

52+
int32_t ForwardBlockID() const { return desc_->forward_block_idx(); }
53+
5254
VarDesc *Var(const std::string &name_bytes);
5355

5456
VarDesc *FindVar(const std::string &name_bytes) const;
@@ -73,6 +75,10 @@ class BlockDesc {
7375

7476
BlockDesc *ParentBlock() const;
7577

78+
BlockDesc *ForwardBlock() const;
79+
80+
void SetForwardBlockID(int32_t forward_block_id);
81+
7682
OpDesc *AppendOp();
7783

7884
void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
@@ -91,7 +97,7 @@ class BlockDesc {
9197

9298
proto::BlockDesc *Proto();
9399

94-
ProgramDesc *Program() { return this->prog_; }
100+
ProgramDesc *Program() const { return this->prog_; }
95101

96102
private:
97103
void ClearPBOps();

paddle/fluid/framework/framework.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ message BlockDesc {
158158
required int32 parent_idx = 2;
159159
repeated VarDesc vars = 3;
160160
repeated OpDesc ops = 4;
161+
optional int32 forward_block_idx = 5 [ default = -1 ];
161162
}
162163

163164
// Please refer to

paddle/fluid/framework/program_desc.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ class ProgramDesc {
3838

3939
BlockDesc *AppendBlock(const BlockDesc &parent);
4040

41-
BlockDesc *MutableBlock(size_t idx) { return blocks_[idx].get(); }
41+
BlockDesc *MutableBlock(size_t idx) {
42+
if (idx == static_cast<size_t>(kNoneBlockIndex)) {
43+
return nullptr;
44+
} else {
45+
return blocks_[idx].get();
46+
}
47+
}
4248

4349
const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; }
4450

paddle/fluid/operators/while_op.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
231231
while_grad->SetInput(kStepScopes, Output(kStepScopes));
232232

233233
auto *grad_block = this->grad_block_[0];
234-
auto *fwd_block = grad_block->ParentBlock();
234+
auto *fwd_block = grad_block->ForwardBlock();
235+
auto *parent_block = grad_block->ParentBlock();
235236

236237
// Not all of IGs will be generated by inner gradient operators of while op.
237238
// Ignore IGs that is not generated by the inside block.
@@ -265,8 +266,10 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
265266
for (auto &input_name : op->InputArgumentNames()) {
266267
// If the input of Op has been recorded or is generated by the forward
267268
// block, do not make it as input again.
269+
268270
if (block_ins.find(input_name) != block_ins.end() ||
269-
fwd_block->FindVar(input_name) != nullptr) {
271+
fwd_block->FindVar(input_name) != nullptr ||
272+
parent_block->FindVar(input_name) != nullptr) {
270273
continue;
271274
}
272275
extra_inputs.insert(input_name);

paddle/fluid/pybind/protobuf.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ void BindBlockDesc(py::module &m) {
155155
py::class_<BlockDesc>(m, "BlockDesc", "")
156156
.def_property_readonly("id", &BlockDesc::ID)
157157
.def_property_readonly("parent", &BlockDesc::Parent)
158+
.def("get_forward_block_idx", &BlockDesc::ForwardBlockID)
159+
.def("set_forward_block_idx", &BlockDesc::SetForwardBlockID)
158160
.def("append_op", &BlockDesc::AppendOp,
159161
py::return_value_policy::reference)
160162
.def("prepend_op", &BlockDesc::PrependOp,

python/paddle/v2/fluid/backward.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ def _append_backward_ops_(block,
298298
# If the op has its own sub-block, deal with the sub-block first
299299
if op.has_attr("sub_block"):
300300
sub_block = program.block(op.block_attr("sub_block"))
301-
grad_sub_block = program.create_block(parent_idx=sub_block.idx)
301+
grad_sub_block = program.create_block()
302+
grad_sub_block.set_forward_block_idx(sub_block.idx)
302303
cb = _callback_lookup_(op)
303304
if cb is not None:
304305
if callbacks is None:
@@ -310,6 +311,8 @@ def _append_backward_ops_(block,
310311
else:
311312
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
312313
no_grad_dict, grad_to_var, callbacks)
314+
315+
program.rollback()
313316
grad_sub_block_list.append(grad_sub_block.desc)
314317

315318
# Getting op's corresponding grad_op

python/paddle/v2/fluid/framework.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,13 @@ def to_string(self, throw_on_error, with_details=False):
678678
def parent_idx(self):
679679
return self.desc.parent
680680

681+
@property
682+
def forward_block_idx(self):
683+
return self.desc.get_forward_block_idx()
684+
685+
def set_forward_block_idx(self, idx):
686+
self.desc.set_forward_block_idx(idx)
687+
681688
@property
682689
def idx(self):
683690
return self.desc.id
@@ -695,11 +702,22 @@ def var_recursive(self, name):
695702
return self.var(name)
696703
else:
697704
if self.idx == 0:
698-
raise ValueError("var %s is not in block(%d) nor its parents." %
699-
name, self.idx)
705+
raise ValueError(
706+
"var {0} is not in block({1}) nor its parents.".format(
707+
name, self.idx))
700708
else:
701-
parent_block = self.program.block(self.parent_idx)
702-
return parent_block.var_recursive(name)
709+
# DFS
710+
try:
711+
parent_block = self.program.block(self.parent_idx)
712+
return parent_block.var_recursive(name)
713+
except ValueError:
714+
fwd_block = self.program.block(
715+
self.forward_block_idx
716+
) if self.forward_block_idx != -1 else None
717+
if fwd_block is not None:
718+
return fwd_block.var_recursive(name)
719+
else:
720+
raise
703721

704722
def all_parameters(self):
705723
return list(self.iter_parameters())

0 commit comments

Comments
 (0)