Skip to content

Commit 2a3c58d

Browse files
committed
refine programdesc copy
1 parent ccf61b3 commit 2a3c58d

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
209209
: prog_(prog), desc_(desc) {
210210
need_update_ = true;
211211
for (auto &op : other.ops_) {
212-
ops_.emplace_back(new OpDesc(*op->Proto(), prog, this));
212+
ops_.emplace_back(new OpDesc(*op, this));
213213
}
214214
for (auto &it : other.vars_) {
215215
auto *var = new VarDesc(*it.second);

paddle/fluid/framework/block_desc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class BlockDesc {
105105

106106
size_t OpSize() const { return ops_.size(); }
107107

108-
OpDesc *Op(int idx) { return ops_.at(idx).get(); }
108+
OpDesc *Op(int idx) const { return ops_.at(idx).get(); }
109109

110110
void Flush();
111111

paddle/fluid/framework/program_desc.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
5151
auto *block = desc_.mutable_blocks(i);
5252
blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
5353
}
54-
for (auto &block : blocks_) {
55-
for (auto *op : block->AllOps()) {
56-
for (const auto &attr : op->Proto()->attrs()) {
57-
if (attr.type() == proto::AttrType::BLOCK) {
58-
size_t blk_idx = attr.block_idx();
59-
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
54+
for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) {
55+
auto all_ops = blocks_[block_id]->AllOps();
56+
for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) {
57+
auto &op = all_ops[op_id];
58+
for (const std::string &attr_name : op->AttrNames()) {
59+
if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
60+
int sub_block_id =
61+
o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name);
62+
op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
6063
}
6164
}
6265
}

0 commit comments

Comments
 (0)