Skip to content

Commit d6997e5

Browse files
authored
Merge pull request #11083 from JiayiFeng/dev_refine_programdesc_copy
Refine ProgramDesc copy
2 parents 85c203b + 31f0533 commit d6997e5

File tree

8 files changed

+28
-14
lines changed

8 files changed

+28
-14
lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
200200
vars_[var_desc.name()].reset(new VarDesc(var_desc));
201201
}
202202
for (const proto::OpDesc &op_desc : desc_->ops()) {
203-
ops_.emplace_back(new OpDesc(op_desc, prog, this));
203+
ops_.emplace_back(new OpDesc(op_desc, this));
204204
}
205205
}
206206

@@ -209,7 +209,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
209209
: prog_(prog), desc_(desc) {
210210
need_update_ = true;
211211
for (auto &op : other.ops_) {
212-
ops_.emplace_back(new OpDesc(*op->Proto(), prog, this));
212+
ops_.emplace_back(new OpDesc(*op, this));
213213
}
214214
for (auto &it : other.vars_) {
215215
auto *var = new VarDesc(*it.second);

paddle/fluid/framework/block_desc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class BlockDesc {
105105

106106
size_t OpSize() const { return ops_.size(); }
107107

108-
OpDesc *Op(int idx) { return ops_.at(idx).get(); }
108+
OpDesc *Op(int idx) const { return ops_.at(idx).get(); }
109109

110110
void Flush();
111111

paddle/fluid/framework/op_desc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
103103
need_update_ = true;
104104
}
105105

106-
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block)
106+
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
107107
: desc_(desc), need_update_(false) {
108108
// restore inputs_
109109
int input_size = desc_.inputs_size();

paddle/fluid/framework/op_desc.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@ class OpDesc {
3333
OpDesc(const std::string &type, const VariableNameMap &inputs,
3434
const VariableNameMap &outputs, const AttributeMap &attrs);
3535

36-
OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block);
36+
OpDesc(const proto::OpDesc &desc, BlockDesc *block);
3737

3838
explicit OpDesc(BlockDesc *block) : block_(block) {}
3939

4040
OpDesc(const OpDesc &other, BlockDesc *block) {
4141
*this = other;
4242
block_ = block;
43+
need_update_ = true;
4344
}
4445

4546
void CopyFrom(const OpDesc &op_desc);

paddle/fluid/framework/program_desc.cc

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,15 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
5151
auto *block = desc_.mutable_blocks(i);
5252
blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
5353
}
54-
for (auto &block : blocks_) {
55-
for (auto *op : block->AllOps()) {
56-
for (const auto &attr : op->Proto()->attrs()) {
57-
if (attr.type() == proto::AttrType::BLOCK) {
58-
size_t blk_idx = attr.block_idx();
59-
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
54+
for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) {
55+
auto all_ops = blocks_[block_id]->AllOps();
56+
for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) {
57+
auto &op = all_ops[op_id];
58+
for (const std::string &attr_name : op->AttrNames()) {
59+
if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
60+
int sub_block_id =
61+
o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name);
62+
op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
6063
}
6164
}
6265
}
@@ -86,6 +89,16 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
8689
for (auto &block_desc : *desc_.mutable_blocks()) {
8790
blocks_.emplace_back(new BlockDesc(this, &block_desc));
8891
}
92+
for (auto &block : blocks_) {
93+
for (auto *op : block->AllOps()) {
94+
for (const auto &attr : op->Proto()->attrs()) {
95+
if (attr.type() == proto::AttrType::BLOCK) {
96+
size_t blk_idx = attr.block_idx();
97+
op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
98+
}
99+
}
100+
}
101+
}
89102
}
90103

91104
const std::vector<std::string> ProgramDesc::GetFeedTargetNames() {

paddle/fluid/inference/tensorrt/convert/activation_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class ReluOpConverter : public OpConverter {
2424
void operator()(const framework::proto::OpDesc& op) override {
2525
// Here the two nullptr looks strange, that's because the
2626
// framework::OpDesc's constructor is strange.
27-
framework::OpDesc op_desc(op, nullptr, nullptr);
27+
framework::OpDesc op_desc(op, nullptr);
2828
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
2929
"type is Relu";
3030
const nvinfer1::ITensor* input_tensor =

paddle/fluid/inference/tensorrt/convert/mul_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class MulOpConverter : public OpConverter {
2727
void operator()(const framework::proto::OpDesc& op) override {
2828
VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias";
2929

30-
framework::OpDesc op_desc(op, nullptr, nullptr);
30+
framework::OpDesc op_desc(op, nullptr);
3131
// Declare inputs
3232
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
3333
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);

paddle/fluid/inference/tensorrt/convert/ut_helper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class TRTConvertValidation {
104104
engine_->FreezeNetwork();
105105

106106
// Declare outputs.
107-
op_desc_.reset(new framework::OpDesc(desc, nullptr, nullptr));
107+
op_desc_.reset(new framework::OpDesc(desc, nullptr));
108108

109109
// Set Inputs.
110110
for (const auto& input : op_desc_->InputArgumentNames()) {

0 commit comments

Comments
 (0)