
Commit 488320a

reyoung authored and QiJune committed

Conditional Block Forward (#5530)

* Conditional Block Forward
* Assign Operator: Out = X when the type is one of [LoDTensor / SelectedRows / LoDTensorArray]
* Stash
* Add Scope::Rename; it is useful in the gradient phase of an operator with a block
* ConditionalBlock Grad Done
* Add comments
* yapf format code
1 parent f07a226 commit 488320a
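For orientation, here is a minimal usage sketch of the Python-side ConditionalBlock API introduced in layers.py below. Only ConditionalBlock, block(), create_tensor and assign come from this commit; the variable names and the data/fc layer calls are illustrative assumptions, not taken from this diff.

import paddle.v2.framework.layers as layers

# Illustrative condition variable; the sub-block runs only when every tensor
# passed through 'X' is non-empty (numel() != 0 in ConditionalBlockOp::Run).
cond = layers.data(name='cond', shape=[1], data_type='float32')  # assumed pre-existing layer
out = layers.create_tensor(dtype='float32')                      # added by this commit

cond_block = layers.ConditionalBlock(inputs=[cond])              # added by this commit
with cond_block.block():
    # Ops appended here land in the sub-block; on guard exit, complete()
    # appends a conditional_block op with inferred Params/Out and a Scope var.
    hidden = layers.fc(input=cond, size=10)                      # assumed pre-existing layer
    layers.assign(input=hidden, output=out)                      # added by this commit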

File tree

5 files changed: +350 additions, −7 deletions


paddle/framework/backward.cc

Lines changed: 27 additions & 6 deletions
@@ -377,6 +377,12 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
   return grad_op_descs;
 }
 
+static BlockDescBind* CreateStepBlock(
+    ProgramDescBind& program_desc,
+    std::unordered_set<std::string>* no_grad_vars,
+    std::unordered_map<std::string, std::string>* grad_to_var,
+    int step_block_idx);
+
 std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
     ProgramDescBind& program_desc, int block_idx,
     std::unordered_set<std::string>* no_grad_vars,
@@ -392,13 +398,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
 
     if ((*it)->Type() == "recurrent") {
       int step_block_idx = (*it)->GetBlockAttr("step_block");
-      auto backward_block_op_descs = MakeBlockBackward(
-          program_desc, step_block_idx, no_grad_vars, grad_to_var);
+      BlockDescBind* backward_block = CreateStepBlock(
+          program_desc, no_grad_vars, grad_to_var, step_block_idx);
+      op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
+    } else if ((*it)->Type() == "conditional_block") {
       BlockDescBind* backward_block =
-          program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
-      for (auto& ptr : backward_block_op_descs) {
-        backward_block->AppendAllocatedOp(std::move(ptr));
-      }
+          CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
+                          (*it)->GetBlockAttr("block"));
       op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
     } else {
       op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
@@ -449,6 +455,21 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
   return backward_descs;
 }
 
+static BlockDescBind* CreateStepBlock(
+    ProgramDescBind& program_desc,
+    std::unordered_set<std::string>* no_grad_vars,
+    std::unordered_map<std::string, std::string>* grad_to_var,
+    int step_block_idx) {
+  auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
+                                                   no_grad_vars, grad_to_var);
+  BlockDescBind* backward_block =
+      program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
+  for (auto& ptr : backward_block_op_descs) {
+    backward_block->AppendAllocatedOp(move(ptr));
+  }
+  return backward_block;
+}
+
 ParamGradInfoMap AppendBackward(
     ProgramDescBind& program_desc, const VarDescBind& target,
     const std::unordered_set<std::string>& no_grad_vars) {
paddle/operators/conditional_block_op.cc (new file)

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/framework/executor.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

class ConditionalOp : public framework::OperatorBase {
 public:
  ConditionalOp(const std::string &type,
                const framework::VariableNameMap &inputs,
                const framework::VariableNameMap &outputs,
                const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 protected:
  std::vector<const framework::LoDTensor *> InputTensors(
      const framework::Scope &scope) const {
    std::vector<const framework::LoDTensor *> retv;
    auto xs = Inputs("X");
    retv.resize(xs.size(), nullptr);
    std::transform(
        xs.begin(), xs.end(), retv.begin(),
        [&scope](const std::string &var_name) -> const framework::LoDTensor * {
          auto *var = scope.FindVar(var_name);
          PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", var_name);
          return &var->Get<framework::LoDTensor>();
        });
    return retv;
  }
};

class ConditionalBlockOp : public ConditionalOp {
 public:
  ConditionalBlockOp(const std::string &type,
                     const framework::VariableNameMap &inputs,
                     const framework::VariableNameMap &outputs,
                     const framework::AttributeMap &attrs)
      : ConditionalOp(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {
    auto xs = InputTensors(scope);
    bool need_run = std::all_of(
        xs.begin(), xs.end(),
        [](const framework::LoDTensor *t) { return t->numel() != 0; });

    if (need_run) {
      auto *scope_var = scope.FindVar(Output("Scope"));
      PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
      auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
      scopes->resize(1);
      scopes->front() = &scope.NewScope();
      auto &cur_scope = *scopes->front();

      auto *block = Attr<framework::BlockDescBind *>("block");
      framework::Executor exec(dev_ctx);
      exec.Run(*block->Program(), &cur_scope, block->ID(), false);
    }
  }
};

class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ConditionalBlockOpProtoMaker(framework::OpProto *proto,
                               framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "The conditional variable of this operator. If X is empty, the "
             "whole sub-block will not be executed.")
        .AsDuplicable();
    AddInput("Params", "The input variables of the sub-block.").AsDuplicable();
    AddOutput("Out", "The output variables of the sub-block.").AsDuplicable();
    AddOutput("Scope",
              "(std::vector<Scope*>) The step scope of conditional block. To "
              "unify the conditional block, rnn and while op, the type of "
              "scope is std::vector<Scope*>");
    AddAttr<framework::BlockDescBind *>(
        "block", "The step block of conditional block operator");
    AddComment(R"DOC(Conditional block operator

Run the sub-block if X is not empty. Params is the other inputs and Out is the
outputs of the sub-block.
)DOC");
  }
};

class ConditionalBlockGradOp : public ConditionalOp {
 public:
  ConditionalBlockGradOp(const std::string &type,
                         const framework::VariableNameMap &inputs,
                         const framework::VariableNameMap &outputs,
                         const framework::AttributeMap &attrs)
      : ConditionalOp(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {
    auto xs = this->InputTensors(scope);
    bool need_run = std::all_of(
        xs.begin(), xs.end(),
        [](const framework::LoDTensor *t) { return t->numel() != 0; });

    if (need_run) {
      auto *scope_var = scope.FindVar(Input("Scope"));
      PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
      auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
      framework::Scope &cur_scope = *scopes[0];

      auto *block = Attr<framework::BlockDescBind *>("block");
      framework::Executor exec(dev_ctx);
      exec.Run(*block->Program(), &cur_scope, block->ID(), false);

      AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("Params"),
                                  Outputs(framework::GradVarName("Params")));

      AssignLocalGradientToGlobal(dev_ctx, cur_scope, Inputs("X"),
                                  Outputs(framework::GradVarName("X")));
    }
  }

 private:
  void AssignLocalGradientToGlobal(
      const platform::DeviceContext &dev_ctx, const framework::Scope &cur_scope,
      const std::vector<std::string> &p_names,
      const std::vector<std::string> &pg_names) const {
    for (size_t i = 0; i < p_names.size(); ++i) {
      auto out_grad_name = pg_names[i];
      auto in_grad_name = framework::GradVarName(p_names[i]);
      auto *in_var = cur_scope.FindVar(in_grad_name);
      if (in_var == nullptr) {
        continue;
      }
      auto new_in_grad_name = cur_scope.Rename(in_grad_name);
      auto assign =
          framework::OpRegistry::CreateOp("assign", {{"X", {new_in_grad_name}}},
                                          {{"Out", {out_grad_name}}}, {});
      assign->Run(cur_scope, dev_ctx);
      cur_scope.Rename(new_in_grad_name, in_grad_name);
    }
  }
};

class ConditionalBlockGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInputs("X"));
    if (context->HasInputs("Params")) {
      PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("Params")));
      context->SetOutputsDim(framework::GradVarName("Params"),
                             context->GetInputsDim("Params"));
    }
    PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("X")));
    context->SetOutputsDim(framework::GradVarName("X"),
                           context->GetInputsDim("X"));
  }
};

class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDescBind> Apply() const override {
    auto grad_op = new framework::OpDescBind();
    grad_op->SetType("conditional_block_grad");
    grad_op->SetInput("X", Input("X"));
    grad_op->SetInput("Params", Input("Params"));
    grad_op->SetInput("Out", Output("Out"));
    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetInput("Scope", Output("Scope"));
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params"));
    grad_op->SetBlockAttr("block", *this->grad_block_[0]);
    return std::unique_ptr<framework::OpDescBind>(grad_op);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
                  ops::ConditionalBlockOpProtoMaker,
                  ops::ConditionalBlockGradMaker);
REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp,
                  ops::ConditionalBlockGradInferShape);
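Note on the gradient op above: the backward sub-block runs inside the step scope saved during the forward pass, and each local gradient is then copied out to the enclosing scope. Scope::Rename (added in this commit) gives the local gradient variable a temporary unique name so that an assign op can write it to the parent-scope gradient variable, which would otherwise be shadowed by the local name; the original name is restored afterwards.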

python/paddle/v2/framework/framework.py

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ def find_name(var_list, name):
         self.desc.check_attrs()
         no_kernel_op_set = {
             'feed', 'fetch', 'save', 'load', 'recurrent',
-            'rnn_memory_helper_grad', 'while'
+            'rnn_memory_helper_grad', 'conditional_block', 'while'
         }
         if type not in no_kernel_op_set:
             self.desc.infer_var_type(self.block.desc)

python/paddle/v2/framework/layers.py

Lines changed: 85 additions & 0 deletions
@@ -226,6 +226,11 @@ def data(name,
         stop_gradient=stop_gradient)


+def create_tensor(dtype, name=None, main_program=None):
+    helper = LayerHelper("create_tensor", **locals())
+    return helper.create_variable(name=helper.name, dtype=dtype)
+
+
 def _convert_(name):
     """
     Formatting.
@@ -451,6 +456,16 @@ def sums(input, main_program=None, startup_program=None):
     return out


+def assign(input, output, main_program=None):
+    helper = LayerHelper('assign', **locals())
+    helper.append_op(
+        type='scale',
+        inputs={'X': [input]},
+        outputs={'Out': [output]},
+        attrs={'scale': 1.0})
+    return output
+
+
 def split_lod_tensor(input,
                      mask,
                      level,
@@ -1415,3 +1430,73 @@ def array_length(array, main_program=None):
     helper.append_op(
         type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]})
     return tmp
+
+
+class ConditionalBlockGuard(BlockGuard):
+    def __init__(self, block):
+        if not isinstance(block, ConditionalBlock):
+            raise TypeError("block should be conditional block")
+        super(ConditionalBlockGuard, self).__init__(block.helper.main_program)
+        self.block = block
+
+    def __enter__(self):
+        return super(ConditionalBlockGuard, self).__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.block.complete()
+        return super(ConditionalBlockGuard, self).__exit__(exc_type, exc_val,
+                                                           exc_tb)
+
+
+class ConditionalBlock(object):
+    def __init__(self, inputs, name=None, main_program=None):
+        for each_input in inputs:
+            if not isinstance(each_input, Variable):
+                raise TypeError("Each input should be variable")
+        self.inputs = inputs
+        self.helper = LayerHelper(
+            'conditional_block', name=name, main_program=main_program)
+
+    def block(self):
+        return ConditionalBlockGuard(self)
+
+    def complete(self):
+        inside_block = self.helper.main_program.current_block()
+        parent_block = self.helper.main_program.block(inside_block.parent_idx)
+
+        intermediate = set()
+        params = set()
+
+        for each_op in inside_block.ops:
+            assert isinstance(each_op, Operator)
+            for iname in each_op.input_names:
+                for in_var_name in each_op.input(iname):
+                    if in_var_name not in intermediate:
+                        params.add(in_var_name)
+
+            for oname in each_op.output_names:
+                for out_var_name in each_op.output(oname):
+                    intermediate.add(out_var_name)
+        input_set = set([ipt.name for ipt in self.inputs])
+
+        param_list = [
+            parent_block.var(each_name) for each_name in params
+            if each_name not in input_set
+        ]
+
+        out_list = [
+            parent_block.var(var_name) for var_name in parent_block.vars
+            if var_name not in intermediate
+        ]
+
+        step_scope = parent_block.create_var(
+            type=core.VarDesc.VarType.STEP_SCOPES)
+        parent_block.append_op(
+            type='conditional_block',
+            inputs={
+                'X': self.inputs,
+                'Params': param_list,
+            },
+            outputs={'Out': out_list,
+                     'Scope': [step_scope]},
+            attrs={'block': inside_block})
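In ConditionalBlock.complete() above, Params is inferred as every variable read inside the sub-block but not produced there (minus the conditional inputs X), Out is taken from the parent block's variables whose names are not among the sub-block ops' outputs, and a STEP_SCOPES variable is created so the forward scope can be reused by conditional_block_grad.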
