Skip to content

Commit a3aca2a

Browse files
committed
fix bugs
1 parent 2a3c58d commit a3aca2a

File tree

4 files changed

+14
-3
lines changed

4 files changed

+14
-3
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
200200
vars_[var_desc.name()].reset(new VarDesc(var_desc));
201201
}
202202
for (const proto::OpDesc &op_desc : desc_->ops()) {
203-
ops_.emplace_back(new OpDesc(op_desc, prog, this));
203+
ops_.emplace_back(new OpDesc(op_desc, this));
204204
}
205205
}
206206

paddle/fluid/framework/op_desc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
103103
need_update_ = true;
104104
}
105105

106-
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block)
106+
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
107107
: desc_(desc), need_update_(false) {
108108
// restore inputs_
109109
int input_size = desc_.inputs_size();

paddle/fluid/framework/op_desc.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@ class OpDesc {
3333
OpDesc(const std::string &type, const VariableNameMap &inputs,
3434
const VariableNameMap &outputs, const AttributeMap &attrs);
3535

36-
OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block);
36+
OpDesc(const proto::OpDesc &desc, BlockDesc *block);
3737

3838
explicit OpDesc(BlockDesc *block) : block_(block) {}
3939

4040
OpDesc(const OpDesc &other, BlockDesc *block) {
4141
*this = other;
4242
block_ = block;
43+
need_update_ = true;
4344
}
4445

4546
void CopyFrom(const OpDesc &op_desc);

paddle/fluid/framework/program_desc.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
8989
for (auto &block_desc : *desc_.mutable_blocks()) {
9090
blocks_.emplace_back(new BlockDesc(this, &block_desc));
9191
}
92+
for (auto &block : blocks_) {
93+
for (auto *op : block->AllOps()) {
94+
for (const auto &attr : op->Proto()->attrs()) {
95+
if (attr.type() == proto::AttrType::BLOCK) {
96+
size_t blk_idx = attr.block_idx();
97+
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
98+
}
99+
}
100+
}
101+
}
92102
}
93103

94104
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {

0 commit comments

Comments
 (0)