Commit c11afdb

Merge pull request #15516 from panyx0718/imperative3
imperative supports multi grad ops
2 parents: b919190 + 42e61af

File tree: 4 files changed (+118, -85 lines)
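
Summary of the change: in the imperative tracer, a forward op can now own more than one generated backward op. OpBase replaces the single grad_op_desc_ pointer with a vector grad_op_descs_, paired with per-backward-op grad_input_vars_ and grad_output_vars_ maps; OpBase::ApplyGrad runs every backward op and accumulates each result into the origin gradients; CreateGradOp drops the old "Only support 1 grad op now." restriction; and a new test_sum_op unit test sums ten variables and checks the accumulated gradient.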

paddle/fluid/imperative/layer.cc

Lines changed: 46 additions & 37 deletions
@@ -204,59 +204,68 @@ framework::LoDTensor& VarBase::GradValue() {
 }
 
 std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
-  if (!grad_op_desc_ && backward_id_ <= 0) {
+  if (grad_op_descs_.empty() && backward_id_ <= 0) {
     LOG(WARNING) << "op with no grad: " << op_desc_->Type();
     return {};
   }
 
-  std::map<std::string, std::vector<framework::Variable*>> grad_outputs;
+  std::vector<framework::VariableValueMap> grad_outputs;
   if (backward_id_ > 0) {
     VLOG(3) << "py_layer_grad";
-    grad_outputs[framework::GradVarName(PyLayer::kFwdOut)] = PyLayer::ApplyGrad(
-        backward_id_,
-        grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)]);
+    grad_outputs.resize(1);
+    grad_outputs[0][framework::GradVarName(PyLayer::kFwdOut)] =
+        PyLayer::ApplyGrad(
+            backward_id_,
+            grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)]);
   } else {
-    VLOG(3) << "op grad " << grad_op_desc_->Type();
-    for (auto it : grad_output_vars_) {
-      auto& outputs = grad_outputs[it.first];
-      for (size_t i = 0; i < it.second.size(); ++i) {
-        // Allocate a new variable
-        Variable* tmp_var = new framework::Variable();
-        tmp_var->GetMutable<framework::LoDTensor>();
-        outputs.push_back(tmp_var);
+    grad_outputs.resize(grad_op_descs_.size());
+    for (size_t k = 0; k < grad_op_descs_.size(); ++k) {
+      framework::OpDesc* grad_op_desc = grad_op_descs_[k];
+      VLOG(3) << "op grad " << grad_op_desc->Type();
+      for (auto it : grad_output_vars_[k]) {
+        auto& outputs = grad_outputs[k][it.first];
+        for (size_t i = 0; i < it.second.size(); ++i) {
+          // Allocate a new variable
+          Variable* tmp_var = new framework::Variable();
+          tmp_var->GetMutable<framework::LoDTensor>();
+          outputs.push_back(tmp_var);
+        }
       }
-    }
 
-    framework::RuntimeContext ctx(grad_input_vars_, grad_outputs);
+      framework::RuntimeContext ctx(grad_input_vars_[k], grad_outputs[k]);
 
-    // No need to do compile time infer shape here.
-    // grad_op_desc_->InferShape(*block_);
-    grad_op_desc_->InferVarType(block_);
+      // No need to do compile time infer shape here.
+      // grad_op_desc_->InferShape(*block_);
+      grad_op_desc->InferVarType(block_);
 
-    std::unique_ptr<framework::OperatorBase> opbase =
-        framework::OpRegistry::CreateOp(*grad_op_desc_);
-    framework::OperatorWithKernel* op_kernel =
-        dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
-    PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
+      std::unique_ptr<framework::OperatorBase> opbase =
+          framework::OpRegistry::CreateOp(*grad_op_desc);
+      framework::OperatorWithKernel* op_kernel =
+          dynamic_cast<framework::OperatorWithKernel*>(opbase.get());
+      PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
 
-    framework::Scope scope;
-    PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
-    p.op.RuntimeInferShape(scope, place_, ctx);
-    p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
+      framework::Scope scope;
+      PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_);
+      p.op.RuntimeInferShape(scope, place_, ctx);
+      p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
+    }
   }
 
-  for (auto it : grad_output_vars_) {
-    auto& outputs = grad_outputs[it.first];
-    auto& origin_outputs = it.second;
-    PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
-
-    for (size_t i = 0; i < outputs.size(); ++i) {
-      framework::Variable* grad = outputs[i];
-      framework::Variable* orig_grad = origin_outputs[i];
-      AddTo(grad, orig_grad, place_);
-      delete grad;
+  for (size_t k = 0; k < grad_output_vars_.size(); ++k) {
+    for (auto it : grad_output_vars_[k]) {
+      auto& outputs = grad_outputs[k][it.first];
+      auto& origin_outputs = it.second;
+      PADDLE_ENFORCE_EQ(outputs.size(), origin_outputs.size());
+
+      for (size_t i = 0; i < outputs.size(); ++i) {
+        framework::Variable* grad = outputs[i];
+        framework::Variable* orig_grad = origin_outputs[i];
+        AddTo(grad, orig_grad, place_);
+        delete grad;
+      }
     }
   }
+
   return input_vars_;
 }
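
Note: the behavioral core of the new ApplyGrad is the final double loop, where every gradient produced by any of the backward ops is added into the same persistent (origin) gradient buffer, so contributions from multiple backward ops accumulate instead of overwriting each other. Below is a minimal standalone sketch of that accumulation pattern; it uses toy types (a std::map of float vectors) rather than Paddle's Variable/VariableValueMap, so treat it as an illustration, not Paddle code.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for framework::VariableValueMap: gradient name -> values.
using VarMap = std::map<std::string, std::vector<float>>;

int main() {
  // Pretend two backward ops each produced a gradient for "x@GRAD",
  // mirroring grad_outputs[k] in ApplyGrad.
  std::vector<VarMap> grad_outputs = {{{"x@GRAD", {1.0f}}},
                                      {{"x@GRAD", {2.0f}}}};
  VarMap origin = {{"x@GRAD", {0.0f}}};  // persistent gradient storage

  for (size_t k = 0; k < grad_outputs.size(); ++k) {
    for (const auto& it : grad_outputs[k]) {
      const std::vector<float>& outs = it.second;
      std::vector<float>& orig = origin[it.first];
      for (size_t i = 0; i < outs.size(); ++i) {
        orig[i] += outs[i];  // stands in for AddTo(grad, orig_grad, place_)
      }
    }
  }
  std::cout << origin["x@GRAD"][0] << std::endl;  // 1 + 2 = 3
  return 0;
}

In the real code each temporary Variable is also freed after the add (delete grad), since grad_outputs holds freshly allocated variables for every backward op.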

paddle/fluid/imperative/layer.h

Lines changed: 12 additions & 6 deletions
@@ -184,12 +184,13 @@ class OpBase {
   OpBase()
       : op_desc_(nullptr),
         forward_id_(-1),
-        grad_op_desc_(nullptr),
         backward_id_(-1),
         place_(platform::CPUPlace()) {}
 
   virtual ~OpBase() {
-    if (grad_op_desc_) delete grad_op_desc_;
+    for (framework::OpDesc* desc : grad_op_descs_) {
+      delete desc;
+    }
   }
 
   std::map<std::string, std::vector<VarBase*>> ApplyGrad();
@@ -198,9 +199,11 @@ class OpBase {
   // For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
   framework::OpDesc* op_desc_;
   int forward_id_;
-  // When has backward, one of `grad_op_desc_` or `backward_id_` is set,
+
+  // When has backward, one of `grad_op_descs_` or `backward_id_` is set,
   // not both.
-  framework::OpDesc* grad_op_desc_;
+  // Note: each fwd op corresponds to a vector of bwd ops.
+  std::vector<framework::OpDesc*> grad_op_descs_;
   int backward_id_;
 
   platform::Place place_;
@@ -210,8 +213,11 @@ class OpBase {
   OpBasePtrMap pre_ops_;
   std::map<std::string, std::vector<int>> pre_ops_out_idx_;
 
-  framework::VariableValueMap grad_input_vars_;
-  framework::VariableValueMap grad_output_vars_;
+  // Inputs to a vector of bwd ops.
+  std::vector<framework::VariableValueMap> grad_input_vars_;
+  // Outputs to a vector of bwd ops.
+  std::vector<framework::VariableValueMap> grad_output_vars_;
+
   framework::BlockDesc* block_;
 };
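
Note: the header change replaces the single grad_op_desc_ pointer with three parallel vectors indexed by backward-op position: grad_op_descs_[k], grad_input_vars_[k], and grad_output_vars_[k] together describe the k-th backward op, and the destructor now deletes every owned OpDesc. A toy illustration of that layout and ownership, using hypothetical FakeOpDesc/FakeOpBase types rather than Paddle's:

#include <cassert>
#include <map>
#include <string>
#include <vector>

struct FakeOpDesc {
  std::string type;
};
using VarMap = std::map<std::string, std::vector<int>>;

struct FakeOpBase {
  std::vector<FakeOpDesc*> grad_op_descs_;  // one desc per backward op
  std::vector<VarMap> grad_input_vars_;     // inputs of the k-th backward op
  std::vector<VarMap> grad_output_vars_;    // outputs of the k-th backward op

  ~FakeOpBase() {
    for (FakeOpDesc* desc : grad_op_descs_) delete desc;  // mirrors ~OpBase()
  }
};

int main() {
  FakeOpBase op;
  op.grad_op_descs_.push_back(new FakeOpDesc{"scale"});
  op.grad_op_descs_.push_back(new FakeOpDesc{"scale"});
  // The invariant the tracer maintains: one variable map per backward op.
  op.grad_input_vars_.resize(op.grad_op_descs_.size());
  op.grad_output_vars_.resize(op.grad_op_descs_.size());
  assert(op.grad_input_vars_.size() == 2 && op.grad_output_vars_.size() == 2);
  return 0;
}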

paddle/fluid/imperative/tracer.cc

Lines changed: 48 additions & 42 deletions
@@ -24,15 +24,16 @@ namespace imperative {
 void CreateGradOp(const framework::OpDesc& op_desc,
                   const std::unordered_set<std::string>& no_grad_set,
                   const std::vector<framework::BlockDesc*>& grad_sub_block,
-                  framework::OpDesc** grad_op_desc,
+                  std::vector<framework::OpDesc*>* grad_op_descs,
                   std::unordered_map<std::string, std::string>* grad_to_var) {
-  std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
+  PADDLE_ENFORCE(grad_op_descs->empty());
+  std::vector<std::unique_ptr<framework::OpDesc>> descs =
       framework::OpInfoMap::Instance()
           .Get(op_desc.Type())
           .GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
-  PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now.");
-  // TODO(panyx0718): Leak?
-  *grad_op_desc = grad_op_descs[0].release();
+  for (auto& desc : descs) {
+    grad_op_descs->emplace_back(desc.release());
+  }
 }
 
 void InitVar(framework::Variable* var, framework::Variable* grad_var,
@@ -138,49 +139,52 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
 
   if (!stop_gradient) {
-    framework::OpDesc* grad_op_desc;
-    // TODO(panyx): Is this leaked?
     std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
         new std::unordered_map<std::string, std::string>());
-    CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var.get());
-    op->grad_op_desc_ = grad_op_desc;
-
-    for (auto it : grad_op_desc->Inputs()) {
-      auto& grad_in_vars = op->grad_input_vars_[it.first];
-      for (const std::string& grad_invar : it.second) {
-        block->FindRecursiveOrCreateVar(grad_invar);
-        auto var_it = grad_to_var->find(grad_invar);
-        if (var_it == grad_to_var->end()) {
-          auto fwd_var_it = vars.find(grad_invar);
-          PADDLE_ENFORCE(fwd_var_it != vars.end());
-          // Forward inputs or outputs.
-          grad_in_vars.push_back(fwd_var_it->second->var_);
-        } else {
+    CreateGradOp(*op_desc, {}, {block}, &op->grad_op_descs_, grad_to_var.get());
+
+    op->grad_input_vars_.resize(op->grad_op_descs_.size());
+    op->grad_output_vars_.resize(op->grad_op_descs_.size());
+    for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) {
+      framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
+      for (auto it : grad_op_desc->Inputs()) {
+        auto& grad_in_vars = op->grad_input_vars_[i][it.first];
+        for (const std::string& grad_invar : it.second) {
+          block->FindRecursiveOrCreateVar(grad_invar);
+          auto var_it = grad_to_var->find(grad_invar);
+          if (var_it == grad_to_var->end()) {
+            auto fwd_var_it = vars.find(grad_invar);
+            PADDLE_ENFORCE(fwd_var_it != vars.end());
+            // Forward inputs or outputs.
+            grad_in_vars.push_back(fwd_var_it->second->var_);
+          } else {
+            VarBase* var = vars[var_it->second];
+            if (!var->grads_->var_->IsInitialized()) {
+              InitVar(var->var_, var->grads_->var_,
+                      prepared_op.GetDeviceContext());
+            }
+            // Douts.
+            grad_in_vars.push_back(var->grads_->var_);
+          }
+        }
+      }
+
+      for (auto it : grad_op_desc->Outputs()) {
+        auto& grad_out_vars = op->grad_output_vars_[i][it.first];
+        for (const std::string& grad_outvar : it.second) {
+          block->FindRecursiveOrCreateVar(grad_outvar);
+          auto var_it = grad_to_var->find(grad_outvar);
+          PADDLE_ENFORCE(var_it != grad_to_var->end(),
+                         "Could not found the grad op output var, should this "
+                         "operator %s's stop gradient be True",
+                         op_desc->Type());
           VarBase* var = vars[var_it->second];
           if (!var->grads_->var_->IsInitialized()) {
             InitVar(var->var_, var->grads_->var_,
                     prepared_op.GetDeviceContext());
           }
-          // Douts.
-          grad_in_vars.push_back(var->grads_->var_);
-        }
-      }
-    }
-
-    for (auto it : grad_op_desc->Outputs()) {
-      auto& grad_out_vars = op->grad_output_vars_[it.first];
-      for (const std::string& grad_outvar : it.second) {
-        block->FindRecursiveOrCreateVar(grad_outvar);
-        auto var_it = grad_to_var->find(grad_outvar);
-        PADDLE_ENFORCE(var_it != grad_to_var->end(),
-                       "Could not found the grad op output var, should this "
-                       "operator %s's stop gradient be True",
-                       op_desc->Type());
-        VarBase* var = vars[var_it->second];
-        if (!var->grads_->var_->IsInitialized()) {
-          InitVar(var->var_, var->grads_->var_, prepared_op.GetDeviceContext());
+          grad_out_vars.push_back(var->grads_->var_);
         }
-        grad_out_vars.push_back(var->grads_->var_);
       }
     }
   }
@@ -209,10 +213,12 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
     out->TrackPreOp(op, PyLayer::kFwdOut, i, stop_gradient);
   }
   if (!stop_gradient) {
+    op->grad_input_vars_.resize(1);
+    op->grad_output_vars_.resize(1);
     auto& grad_input_vars =
-        op->grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)];
+        op->grad_input_vars_[0][framework::GradVarName(PyLayer::kFwdInp)];
     auto& grad_output_vars =
-        op->grad_output_vars_[framework::GradVarName(PyLayer::kFwdOut)];
+        op->grad_output_vars_[0][framework::GradVarName(PyLayer::kFwdOut)];
 
     for (const VarBase* inp : inputs) {
       grad_input_vars.push_back(inp->var_);
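
Note: CreateGradOp no longer enforces that the grad-op maker returns exactly one OpDesc; it releases every generated unique_ptr into the caller-owned vector, and the raw pointers are later freed by ~OpBase(). Tracer::Trace then resizes grad_input_vars_/grad_output_vars_ to the same length and fills one variable map per backward op (PyTrace uses a fixed size of 1). A minimal sketch of that ownership handoff, with a hypothetical Desc type standing in for framework::OpDesc:

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct Desc {
  int id;
};

// Mirrors the new CreateGradOp contract: the output vector must start empty,
// and ownership of every generated desc moves into it.
void CollectGradOps(std::vector<std::unique_ptr<Desc>> generated,
                    std::vector<Desc*>* grad_op_descs) {
  assert(grad_op_descs->empty());  // PADDLE_ENFORCE(grad_op_descs->empty())
  for (auto& desc : generated) {
    grad_op_descs->emplace_back(desc.release());
  }
}

int main() {
  std::vector<std::unique_ptr<Desc>> generated;
  generated.emplace_back(new Desc{0});
  generated.emplace_back(new Desc{1});  // more than one grad op is now allowed

  std::vector<Desc*> owned;
  CollectGradOps(std::move(generated), &owned);
  assert(owned.size() == 2);
  for (Desc* d : owned) delete d;  // in Paddle, ~OpBase() frees these
  return 0;
}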

python/paddle/fluid/tests/unittests/test_imperative.py

Lines changed: 12 additions & 0 deletions
@@ -67,6 +67,18 @@ def forward(self, inputs):
 
 
 class TestImperative(unittest.TestCase):
+    def test_sum_op(self):
+        x = np.ones([2, 2], np.float32)
+        with fluid.imperative.guard():
+            inputs = []
+            for _ in range(10):
+                inputs.append(fluid.imperative.base.to_variable(x))
+            ret = fluid.layers.sums(inputs)
+            loss = fluid.layers.reduce_sum(ret)
+            loss._backward()
+            self.assertTrue(np.allclose(ret._numpy(), x * 10))
+            self.assertTrue(np.allclose(inputs[0]._gradient(), x))
+
     def test_layer(self):
         with fluid.imperative.guard():
             cl = core.Layer()
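
Note: this test only passes once a forward op may generate several backward ops; the previous CreateGradOp rejected that case with "Only support 1 grad op now.", and sums over ten inputs appears to be exactly such a case, so the output equals x * 10 while each input's gradient comes back as a tensor of ones.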
