Skip to content

Commit c3b46d1

Browse files
authored
Merge pull request #4573 from Canpio/dev_backward_for_op_desc
Backward for op desc
2 parents cb1baa3 + bd7b669 commit c3b46d1

File tree

8 files changed

+520
-1
lines changed

8 files changed

+520
-1
lines changed

paddle/framework/backward.cc

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
#include "paddle/framework/backward.h"
1616
#include "paddle/operators/net_op.h"
1717

18+
#include <deque>
1819
#include <list>
1920
#include <memory>
2021

22+
#include "paddle/framework/block_desc.h"
2123
#include "paddle/framework/op_registry.h"
2224
#include "paddle/operators/net_op.h"
2325
#include "paddle/operators/recurrent_op.h"
@@ -270,5 +272,145 @@ std::unique_ptr<OperatorBase> Backward(
270272
return BackwardRecursive(forwardOp, no_grad_names, uid);
271273
}
272274

275+
// ==================================== //
276+
277+
static bool AllGradInSet(const std::vector<std::string>& names,
278+
const std::unordered_set<std::string>& set) {
279+
for (const std::string& name : names) {
280+
if (!set.count(GradVarName(name))) {
281+
return false;
282+
}
283+
}
284+
return true;
285+
}
286+
287+
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
288+
const std::unique_ptr<OpDescBind>& op_desc,
289+
std::unordered_set<std::string>& no_grad_vars) {
290+
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
291+
// All input gradients of forwarding operator do not need to calculat.
292+
const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
293+
if (AllGradInSet(inputs, no_grad_vars)) {
294+
return grad_op_descs; // empty vector
295+
}
296+
// All output gradients of forwarding operator do not need to calculate.
297+
const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
298+
if (AllGradInSet(outputs, no_grad_vars)) {
299+
for (const std::string& name : inputs) {
300+
no_grad_vars.insert(GradVarName(name));
301+
}
302+
return grad_op_descs; // empty vector
303+
}
304+
305+
grad_op_descs = OpRegistry::CreateGradOpDescs(*op_desc);
306+
307+
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
308+
for (auto& desc : grad_op_descs) {
309+
for (const std::string& in_name : desc->InputArgumentNames()) {
310+
if (no_grad_vars.count(in_name)) {
311+
std::string prefix = in_name.substr(
312+
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
313+
std::string new_name = prefix + kZeroVarSuffix;
314+
desc->Rename(in_name, new_name);
315+
std::unique_ptr<OpDescBind> fill_zeros_op(new OpDescBind(
316+
"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}));
317+
pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
318+
}
319+
}
320+
for (const std::string& out_name : desc->OutputArgumentNames()) {
321+
if (no_grad_vars.count(out_name)) {
322+
desc->Rename(out_name, kEmptyVarName);
323+
}
324+
}
325+
}
326+
327+
for (auto& p : pending_fill_zeros_ops) {
328+
grad_op_descs.insert(grad_op_descs.begin(), std::move(p));
329+
}
330+
return grad_op_descs;
331+
}
332+
333+
// Builds the backward ops for one block, recursing into RNN step blocks.
//
// @param program_desc  The program owning all blocks.
// @param block_idx     Index of the block whose backward pass is generated.
// @param no_grad_vars  In/out set of gradient names that must not be created.
// @return The backward op descriptions, in execution order.
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
    ProgramDescBind& program_desc, int block_idx,
    std::unordered_set<std::string>& no_grad_vars) {
  BlockDescBind* cur_block = program_desc.Block(block_idx);
  std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
  // Maps each gradient output name to the indices (into backward_descs) of
  // every op that writes it, so duplicate writes can be merged by a "sum" op.
  std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
  size_t grad_desc_idx = 0;
  std::vector<std::unique_ptr<OpDescBind>> backward_descs;
  // Walk the forward ops in reverse to generate the backward pass.
  for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
    std::vector<std::unique_ptr<OpDescBind>> op_grads =
        MakeOpGrad(*it, no_grad_vars);

    if ((*it)->Type() == "recurrent") {
      // An RNN's gradient is one op that owns its own backward step block.
      PADDLE_ENFORCE_EQ(
          op_grads.size(), size_t(1),
          "rnn_op's gradient process should contain only one op.");
      // BUGFIX: the forward RNN op stores its block under "step_block" (the
      // same attribute name written via SetBlockAttr below); "stop_block"
      // was a typo that read a nonexistent attribute.
      int step_block_idx = (*it)->GetBlockAttr("step_block");
      auto backward_block_op_descs =
          MakeBlockBackward(program_desc, step_block_idx, no_grad_vars);
      BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
      for (auto& ptr : backward_block_op_descs) {
        backward_block->ops_.push_back(std::move(ptr));
      }
      op_grads[0]->SetBlockAttr("step_block", *backward_block);
    }

    for (const auto& desc : op_grads) {
      for (const std::string& out_name : desc->OutputArgumentNames()) {
        dup_out_ops[out_name].emplace_back(grad_desc_idx);
      }
      ++grad_desc_idx;
    }
    std::transform(
        op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
        [](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
  }
  // Check whether some variables are written more than once; if so, rename
  // each write to a unique variable and add a "sum" op to merge them.
  std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
  for (const auto& dup : dup_out_ops) {
    const std::string& out_name = dup.first;
    // Bind by reference: copying the index vector here was needless work.
    const std::vector<size_t>& dup_op = dup.second;
    if (out_name != kEmptyVarName && dup_op.size() > 1) {
      std::vector<std::string> sum_op_inputs;
      for (size_t i = 0; i < dup_op.size(); ++i) {
        std::string new_name = out_name + "@RENAME@" + std::to_string(i);
        backward_descs[dup_op[i]]->Rename(out_name, new_name);
        sum_op_inputs.emplace_back(new_name);
      }
      std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
          "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
      // The sum op must run right after the last op that writes a duplicate.
      pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
    }
  }
  // Insert from the highest index down so earlier insertion positions stay
  // valid as the vector grows.
  pending_sum_ops.sort(
      [](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
         const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
        return a.first > b.first;
      });
  for (auto& p : pending_sum_ops) {
    backward_descs.insert(backward_descs.begin() + p.first + 1,
                          std::move(p.second));
  }
  return backward_descs;
}
397+
398+
// Appends the backward pass for the whole program to its root block.
//
// @param program_desc  The program to extend in place.
// @param no_grad_vars  Forward variable names whose gradients are not needed.
void AppendBackward(ProgramDescBind& program_desc,
                    const std::unordered_set<std::string>& no_grad_vars) {
  // Translate forward variable names into the gradient names to suppress,
  // seeding the set with the gradient of the empty variable.
  std::unordered_set<std::string> excluded_grad_names;
  excluded_grad_names.reserve(no_grad_vars.size() + 1);
  excluded_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
  for (const auto& fwd_name : no_grad_vars) {
    excluded_grad_names.insert(GradVarName(fwd_name));
  }
  // Generate the backward ops for the root block and append them to it.
  const int root_block_idx = 0;
  auto grad_op_descs =
      MakeBlockBackward(program_desc, root_block_idx, excluded_grad_names);
  auto& root_ops = program_desc.Block(root_block_idx)->ops_;
  for (auto& grad_op : grad_op_descs) {
    root_ops.push_back(std::move(grad_op));
  }
}
414+
273415
} // namespace framework
274416
} // namespace paddle

paddle/framework/backward.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
limitations under the License. */
1414

1515
#pragma once
16+
1617
#include <unordered_set>
17-
#include "operator.h"
18+
#include "paddle/framework/operator.h"
19+
#include "paddle/framework/program_desc.h"
20+
1821
namespace paddle {
1922
namespace framework {
2023

@@ -23,5 +26,9 @@ namespace framework {
2326
extern std::unique_ptr<OperatorBase> Backward(
2427
const OperatorBase& forwardOp,
2528
const std::unordered_set<std::string>& no_grad_vars);
29+
30+
void AppendBackward(ProgramDescBind& program_desc,
31+
const std::unordered_set<std::string>& no_grad_vars);
32+
2633
} // namespace framework
2734
} // namespace paddle

0 commit comments

Comments
 (0)