Skip to content

Commit 1363ddb

Browse files
authored
Feature/executor use program bind (#5196)
* Init commit * Make executor use ProgramDescBind * Change Attribute from BlockDesc to BlockDescBind * Since we will get the program desc in RNN, just BlockDesc is not enough.
1 parent ee11f00 commit 1363ddb

20 files changed

+94
-92
lines changed

paddle/framework/attribute.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace framework {
2121

22-
Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
22+
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
2323
switch (attr_desc.type()) {
2424
case framework::AttrType::BOOLEAN: {
2525
return attr_desc.b();
@@ -61,13 +61,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
6161
}
6262
return val;
6363
}
64-
case framework::AttrType::BLOCK: {
65-
PADDLE_ENFORCE(program != nullptr,
66-
"Need to specify ProgramDesc when get a block attr");
67-
return program->mutable_blocks(attr_desc.block_idx());
68-
}
64+
default:
65+
PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
6966
}
70-
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
7167
return boost::blank();
7268
}
7369

paddle/framework/attribute.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ inline AttrType AttrTypeID() {
3232
return static_cast<AttrType>(tmp.which() - 1);
3333
}
3434

35-
Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* desc);
35+
Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
3636

3737
class AttrReader {
3838
public:

paddle/framework/backward.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
368368
ProgramDescBind& program_desc, int block_idx,
369369
std::unordered_set<std::string>* no_grad_vars,
370370
std::unordered_map<std::string, std::string>* grad_to_var) {
371-
BlockDescBind* cur_block = program_desc.Block(block_idx);
371+
BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
372372
std::vector<OpDescBind*> op_descs = cur_block->AllOps();
373373
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
374374
size_t grad_desc_idx = 0;
@@ -443,7 +443,7 @@ ParamGradInfoMap AppendBackward(
443443
}
444444

445445
const int root_block_idx = 0;
446-
auto root_block = program_desc.Block(root_block_idx);
446+
auto root_block = program_desc.MutableBlock(root_block_idx);
447447

448448
// insert fill one op for target
449449
// TODO(qiao) add some check to the target.
@@ -492,7 +492,7 @@ ParamGradInfoMap AppendBackward(
492492
CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
493493
for (size_t block_index = forward_block_num;
494494
block_index < program_desc.Size(); ++block_index) {
495-
CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
495+
CreateGradVarInBlock(0, grad_to_var, program_desc.MutableBlock(block_index),
496496
&retv);
497497
}
498498
return retv;

paddle/framework/backward_test.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
499499

500500
TEST(Backward, simple_single_op) {
501501
f::ProgramDescBind program;
502-
f::BlockDescBind *block = program.Block(0);
502+
f::BlockDescBind *block = program.MutableBlock(0);
503503

504504
f::OpDescBind *op = block->AppendOp();
505505
op->SetType("rowwise_add");
@@ -535,7 +535,7 @@ TEST(Backward, simple_single_op) {
535535

536536
TEST(Backward, default_attribute) {
537537
f::ProgramDescBind program;
538-
f::BlockDescBind *block = program.Block(0);
538+
f::BlockDescBind *block = program.MutableBlock(0);
539539
f::OpDescBind *op = block->AppendOp();
540540
op->SetType("mul");
541541
op->SetInput("X", {"x"});
@@ -561,7 +561,7 @@ TEST(Backward, default_attribute) {
561561

562562
TEST(Backward, simple_mult_op) {
563563
f::ProgramDescBind program;
564-
f::BlockDescBind *block = program.Block(0);
564+
f::BlockDescBind *block = program.MutableBlock(0);
565565
f::OpDescBind *op1 = block->AppendOp();
566566
op1->SetType("rowwise_add");
567567
op1->SetInput("X", {"x1"});
@@ -644,7 +644,7 @@ TEST(Backward, simple_mult_op) {
644644

645645
TEST(Backward, intermedia_var_no_grad) {
646646
f::ProgramDescBind program;
647-
f::BlockDescBind *block = program.Block(0);
647+
f::BlockDescBind *block = program.MutableBlock(0);
648648
f::OpDescBind *op1 = block->AppendOp();
649649
op1->SetType("rowwise_add");
650650
op1->SetInput("X", {"x1"});
@@ -714,7 +714,7 @@ TEST(Backward, intermedia_var_no_grad) {
714714

715715
TEST(Backward, var_no_grad) {
716716
f::ProgramDescBind program;
717-
f::BlockDescBind *block = program.Block(0);
717+
f::BlockDescBind *block = program.MutableBlock(0);
718718
f::OpDescBind *op1 = block->AppendOp();
719719
op1->SetType("mult_in_out");
720720
op1->SetInput("X", {"x1"});
@@ -790,7 +790,7 @@ TEST(Backward, var_no_grad) {
790790

791791
TEST(Backward, shared_var) {
792792
f::ProgramDescBind program;
793-
f::BlockDescBind *block = program.Block(0);
793+
f::BlockDescBind *block = program.MutableBlock(0);
794794
f::OpDescBind *op1 = block->AppendOp();
795795
op1->SetType("rowwise_add");
796796
op1->SetInput("X", {"x1"});
@@ -880,7 +880,7 @@ TEST(Backward, shared_var) {
880880

881881
TEST(Backward, half_backward) {
882882
f::ProgramDescBind program;
883-
f::BlockDescBind *block = program.Block(0);
883+
f::BlockDescBind *block = program.MutableBlock(0);
884884
auto *op1 = block->AppendOp();
885885
op1->SetType("minus");
886886
op1->SetInput("X", {"a"});

paddle/framework/block_desc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
113113
if (this->desc_->parent_idx() == kNoneBlockIndex) {
114114
return nullptr;
115115
}
116-
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
116+
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
117117
}
118118

119119
BlockDesc *BlockDescBind::Proto() {

paddle/framework/executor.cc

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,33 +73,32 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
7373
}
7474
}
7575

76-
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
76+
void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id) {
7777
// TODO(tonyyang-svail):
7878
// - only runs on the first device (i.e. no interdevice communication)
7979
// - will change to use multiple blocks for RNN op and Cond Op
80-
PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id);
81-
auto& block = pdesc.blocks(block_id);
80+
PADDLE_ENFORCE_LT(block_id, pdesc.Size());
81+
auto& block = pdesc.Block(block_id);
8282
auto& device = device_contexts_[0];
8383

8484
Scope& local_scope = scope->NewScope();
8585

86-
for (auto& var : block.vars()) {
87-
if (var.persistable()) {
88-
auto* ptr = scope->Var(var.name());
89-
CreateTensor(ptr, var.type());
90-
VLOG(3) << "Create Variable " << var.name()
86+
for (auto& var : block.AllVars()) {
87+
if (var->Persistable()) {
88+
auto* ptr = scope->Var(var->Name());
89+
CreateTensor(ptr, var->GetType());
90+
VLOG(3) << "Create Variable " << var->Name()
9191
<< " global, which pointer is " << ptr;
9292
} else {
93-
auto* ptr = local_scope.Var(var.name());
94-
CreateTensor(ptr, var.type());
95-
VLOG(3) << "Create Variable " << var.name()
93+
auto* ptr = local_scope.Var(var->Name());
94+
CreateTensor(ptr, var->GetType());
95+
VLOG(3) << "Create Variable " << var->Name()
9696
<< " locally, which pointer is " << ptr;
9797
}
9898
}
9999

100-
for (auto& op_desc : block.ops()) {
101-
auto op = paddle::framework::OpRegistry::CreateOp(
102-
op_desc, const_cast<ProgramDesc*>(&pdesc));
100+
for (auto& op_desc : block.AllOps()) {
101+
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
103102
op->Run(local_scope, *device);
104103
}
105104

paddle/framework/executor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include "paddle/framework/framework.pb.h"
1817
#include "paddle/framework/op_info.h"
18+
#include "paddle/framework/program_desc.h"
1919
#include "paddle/framework/scope.h"
2020
#include "paddle/framework/tensor.h"
2121

@@ -34,7 +34,7 @@ class Executor {
3434
* ProgramDesc
3535
* Scope
3636
*/
37-
void Run(const ProgramDesc&, Scope*, int);
37+
void Run(const ProgramDescBind&, Scope*, int);
3838

3939
private:
4040
std::vector<platform::DeviceContext*> device_contexts_;

paddle/framework/op_desc.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,12 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
114114
// restore attrs_
115115
for (const OpDesc::Attr &attr : desc_.attrs()) {
116116
std::string attr_name = attr.name();
117-
attrs_[attr_name] = GetAttrValue(attr, prog->Proto());
117+
if (attr.type() != AttrType::BLOCK) {
118+
attrs_[attr_name] = GetAttrValue(attr);
119+
} else {
120+
auto bid = attr.block_idx();
121+
attrs_[attr_name] = prog->MutableBlock(bid);
122+
}
118123
}
119124
}
120125

@@ -188,8 +193,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
188193
}
189194

190195
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
191-
BlockDesc *desc = block.Proto();
192-
this->attrs_[name] = desc;
196+
this->attrs_[name] = &block;
193197
need_update_ = true;
194198
}
195199

@@ -208,7 +212,7 @@ Attribute OpDescBind::GetAttr(const std::string &name) const {
208212
int OpDescBind::GetBlockAttr(const std::string &name) const {
209213
auto it = attrs_.find(name);
210214
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
211-
return boost::get<BlockDesc *>(it->second)->idx();
215+
return boost::get<BlockDescBind *>(it->second)->ID();
212216
}
213217

214218
const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()

paddle/framework/op_registry.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,15 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
4343
return ret_val;
4444
}
4545

46-
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc,
47-
ProgramDesc* program) {
46+
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
47+
VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
48+
"used in unit tests. Use CreateOp(const OpDescBind& op_desc) "
49+
"instead.";
4850
VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
4951
VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
5052
AttributeMap attrs;
5153
for (auto& attr : op_desc.attrs()) {
52-
attrs[attr.name()] = GetAttrValue(attr, program);
54+
attrs[attr.name()] = GetAttrValue(attr);
5355
}
5456

5557
return CreateOp(op_desc.type(), inputs, outputs, attrs);

paddle/framework/op_registry.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,7 @@ class OpRegistry {
7777
const VariableNameMap& outputs,
7878
AttributeMap attrs);
7979

80-
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc,
81-
ProgramDesc* program);
80+
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
8281

8382
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
8483
};

0 commit comments

Comments
 (0)