Commit ac46018

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_bn_eq

2 parents 8a49f7f + 9cf6036

25 files changed: +380 -106 lines

paddle/framework/backward.cc

Lines changed: 49 additions & 19 deletions
@@ -270,6 +270,19 @@ static bool AllGradInSet(const std::vector<std::string>& names,
       return false;
     }
   }
+  if (VLOG_IS_ON(10)) {
+    std::ostringstream sout;
+    sout << "All input {";
+    for (auto& name : names) {
+      sout << name << ",";
+    }
+    sout << "} is in {";
+    for (auto& name : set) {
+      sout << name << ",";
+    }
+    sout << "}";
+    VLOG(10) << sout.str();
+  }
   return true;
 }
 
@@ -290,14 +303,12 @@ static void CreateGradVarInBlock(
   auto ops = block_desc->AllOps();
   for (size_t op_index = grad_op_start_index; op_index < ops.size();
        ++op_index) {
-    bool need_infer_shape = false;
     std::unordered_set<std::string> new_vars;
     ForEachVarName(ops[op_index]->Outputs(),
                    [&](const std::string& grad_var_name) {
                      if (block_desc->HasVar(grad_var_name)) {
                        return false;
                      }
-                     need_infer_shape = true;
                      auto var = block_desc->Var(grad_var_name);
                      new_vars.insert(var->Name());
                      auto it = param_name_map.find(grad_var_name);
@@ -311,23 +322,21 @@ static void CreateGradVarInBlock(
                      grad_record.op_idx_ = static_cast<int>(op_index);
                      return false; /* not break */
                    });
-    if (need_infer_shape) {
-      ops[op_index]->InferVarType(block_desc);
-      for (auto& arg : ops[op_index]->OutputArgumentNames()) {
-        if (new_vars.find(arg) == new_vars.end()) {
-          continue;
-        }
-        auto pname = FwdName(arg);
-        auto* param = block_desc->FindVarRecursive(pname);
-        auto* grad = block_desc->FindVar(arg);
-        if (param == nullptr) {
-          grad->SetDataType(DataType::FP32);
-        } else {
-          grad->SetDataType(param->GetDataType());
-        }
+    ops[op_index]->InferVarType(block_desc);
+    for (auto& arg : ops[op_index]->OutputArgumentNames()) {
+      if (new_vars.find(arg) == new_vars.end()) {
+        continue;
+      }
+      auto pname = FwdName(arg);
+      auto* param = block_desc->FindVarRecursive(pname);
+      auto* grad = block_desc->FindVar(arg);
+      if (param == nullptr) {
+        grad->SetDataType(DataType::FP32);
+      } else {
+        grad->SetDataType(param->GetDataType());
       }
-      ops[op_index]->InferShape(*block_desc);
     }
+    ops[op_index]->InferShape(*block_desc);
   }
 }
 
@@ -387,16 +396,18 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
     ProgramDescBind& program_desc, int block_idx,
     std::unordered_set<std::string>* no_grad_vars,
     std::unordered_map<std::string, std::string>* grad_to_var) {
+  VLOG(5) << "MakeBlockBackward";
   BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
   std::vector<OpDescBind*> op_descs = cur_block->AllOps();
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
   size_t grad_desc_idx = 0;
   std::vector<std::unique_ptr<OpDescBind>> backward_descs;
 
   for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
+    VLOG(5) << "Making backward " << (*it)->Type() << " op";
     std::vector<std::unique_ptr<OpDescBind>> op_grads;
 
-    if ((*it)->Type() == "recurrent") {
+    if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
       int step_block_idx = (*it)->GetBlockAttr("step_block");
       BlockDescBind* backward_block = CreateStepBlock(
           program_desc, no_grad_vars, grad_to_var, step_block_idx);
@@ -410,6 +421,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
       op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
     }
 
+    if (VLOG_IS_ON(10)) {
+      std::ostringstream sout;
+      sout << "Made ";
+      for (auto& op_grad : op_grads) {
+        sout << op_grad->Type() << " ";
+      }
+      VLOG(10) << sout.str();
+    }
+
     for (const auto& desc : op_grads) {
       for (const std::string& out_name : desc->OutputArgumentNames()) {
         if (out_name.find("@GRAD") == std::string::npos) {
@@ -425,23 +445,31 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
         op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
         [](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
   }
+
+  VLOG(5) << "Appending Sums";
   // Check whether some variables are written more than once
   std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
   for (const auto& dup : dup_out_ops) {
     const std::string& out_name = dup.first;
     const std::vector<size_t> dup_op = dup.second;
     if (out_name != kEmptyVarName && dup_op.size() > 1) {
       std::vector<std::string> sum_op_inputs;
+      std::string next_g_name = out_name;
       for (size_t i = 0; i < dup_op.size(); ++i) {
+        VLOG(10) << backward_descs[dup_op[i]]->Type() << " has " << out_name
+                 << " duplicated";
         std::string new_name = out_name + "@RENAME@" + std::to_string(i);
-        backward_descs[dup_op[i]]->Rename(out_name, new_name);
+        backward_descs[dup_op[i]]->RenameOutput(out_name, new_name);
+        backward_descs[dup_op[i]]->RenameInput(out_name, next_g_name);
         sum_op_inputs.emplace_back(new_name);
+        next_g_name = sum_op_inputs.back();
      }
       std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
           "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
       pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
     }
   }
+
   pending_sum_ops.sort(
       [](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
          const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
@@ -452,6 +480,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
                std::move(p.second));
   }
 
+  VLOG(5) << "MakeBlockBackward Finished";
+
   return backward_descs;
 }

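A note on the duplicated-gradient handling in MakeBlockBackward above: the new RenameOutput/RenameInput pair replaces the old blanket Rename so that a backward op which both reads and writes a duplicated gradient consumes the previous op's partial result instead of the not-yet-summed name. Below is a minimal standalone sketch of the renaming order with hypothetical variable names; it mirrors the loop's bookkeeping but is not Paddle code.

#include <iostream>
#include <string>
#include <vector>

// Sketch of the @RENAME@ chaining: each duplicate writer gets a unique
// renamed output, each reader is rewired to the previous writer's output,
// and a final "sum" op folds all partial gradients back into out_name.
int main() {
  const std::string out_name = "x@GRAD";
  std::vector<std::string> sum_op_inputs;
  std::string next_g_name = out_name;
  for (int i = 0; i < 2; ++i) {  // two backward ops both write x@GRAD
    std::string new_name = out_name + "@RENAME@" + std::to_string(i);
    std::cout << "op" << i << ": output " << out_name << " -> " << new_name
              << ", input " << out_name << " -> " << next_g_name << "\n";
    sum_op_inputs.emplace_back(new_name);
    next_g_name = sum_op_inputs.back();
  }
  std::cout << "sum op: X = {x@GRAD@RENAME@0, x@GRAD@RENAME@1}, Out = "
            << out_name << "\n";
  return 0;
}
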
paddle/framework/data_type.h

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@ inline DataType ToDataType(std::type_index type) {
     return DataType::INT32;
   } else if (typeid(int64_t).hash_code() == type.hash_code()) {
     return DataType::INT64;
+  } else if (typeid(bool).hash_code() == type.hash_code()) {
+    return DataType::BOOL;
   } else {
     PADDLE_THROW("Not supported");
   }

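The new branch lets ToDataType map C++ bool to the framework's enum instead of falling through to PADDLE_THROW("Not supported"). A minimal usage sketch, assuming the include path and namespaces of this header:

#include <typeindex>
#include <typeinfo>
#include "paddle/framework/data_type.h"

// bool now resolves to DataType::BOOL rather than throwing.
paddle::framework::DataType dt =
    paddle::framework::ToDataType(std::type_index(typeid(bool)));
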
paddle/framework/ddim.cc

Lines changed: 1 addition & 2 deletions
@@ -60,8 +60,7 @@ void make_ddim(DDim& ddim, const int64_t* dims, int n) {
       ddim = make_dim<9>(dims);
       break;
     default:
-      throw std::invalid_argument(
-          "Dynamic dimensions must have between [1, 9] dimensions.");
+      PADDLE_THROW("Dynamic dimensions must have between [1, 9] dimensions.");
   }
 }

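Swapping std::invalid_argument for PADDLE_THROW routes this failure through the framework's own error machinery. A hedged sketch of the caller-visible difference, assuming the EnforceNotMet exception type from paddle/platform/enforce.h and a make_ddim overload taking an initializer list:

#include "paddle/framework/ddim.h"
#include "paddle/platform/enforce.h"

void Demo() {
  try {
    // rank 10 is outside the supported [1, 9] range
    paddle::framework::make_ddim({1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  } catch (const paddle::platform::EnforceNotMet&) {
    // before this commit, this path required catching std::invalid_argument
  }
}
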
paddle/framework/executor.cc

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
 
   for (auto& op_desc : block.AllOps()) {
     auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
+    VLOG(10) << op->DebugString();
     op->Run(*local_scope, *device);
   }
   if (create_local_scope) {

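The executor now logs every created op's DebugString at verbose level 10. All VLOG calls added in this commit follow glog's gating rules: VLOG(n) emits only when the active verbosity is at least n, typically set with the --v flag or the GLOG_v environment variable. A small sketch of the pattern:

#include <glog/logging.h>

void Demo() {
  VLOG(10) << "emitted only when verbosity >= 10";
  if (VLOG_IS_ON(10)) {
    // guard for expensive message construction, as backward.cc does above,
    // so the string building is skipped entirely at lower verbosity
  }
}
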
paddle/framework/op_desc.cc

Lines changed: 23 additions & 1 deletion
@@ -235,6 +235,23 @@ void OpDescBind::Rename(const std::string &old_name,
   need_update_ = true;
 }
 
+void OpDescBind::RenameOutput(const std::string &old_name,
+                              const std::string &new_name) {
+  for (auto &output : outputs_) {
+    std::replace(output.second.begin(), output.second.end(), old_name,
+                 new_name);
+  }
+  need_update_ = true;
+}
+
+void OpDescBind::RenameInput(const std::string &old_name,
+                             const std::string &new_name) {
+  for (auto &input : inputs_) {
+    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
+  }
+  need_update_ = true;
+}
+
 struct SetAttrDescVisitor : public boost::static_visitor<void> {
   explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
   mutable OpDesc::Attr *attr_;
@@ -448,7 +465,12 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
 DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
   auto var = block_.FindVarRecursive(name);
   PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
-  return framework::make_ddim(var->Shape());
+  try {
+    return framework::make_ddim(var->Shape());
+  } catch (...) {
+    VLOG(5) << "GetDim of variable " << name << " error";
+    std::rethrow_exception(std::current_exception());
+  }
 }
 
 void CompileTimeInferShapeContext::SetDim(const std::string &name,

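RenameOutput and RenameInput are thin wrappers over std::replace applied to each argument list. A standalone sketch of that call's semantics (plain C++, not Paddle code):

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> args = {"x@GRAD", "y@GRAD", "x@GRAD"};
  // every occurrence of the old name is rewritten; other entries untouched
  std::replace(args.begin(), args.end(), std::string("x@GRAD"),
               std::string("x@GRAD@RENAME@0"));
  assert(args[0] == "x@GRAD@RENAME@0");
  assert(args[1] == "y@GRAD");
  assert(args[2] == "x@GRAD@RENAME@0");
  return 0;
}
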
paddle/framework/op_desc.h

Lines changed: 4 additions & 0 deletions
@@ -73,6 +73,10 @@ class OpDescBind {
 
   void Rename(const std::string &old_name, const std::string &new_name);
 
+  void RenameOutput(const std::string &old_name, const std::string &new_name);
+
+  void RenameInput(const std::string &old_name, const std::string &new_name);
+
   // Only be used in C++
   const AttributeMap &GetAttrMap() const;

paddle/framework/operator.cc

Lines changed: 0 additions & 13 deletions
@@ -403,19 +403,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
 
 void OperatorWithKernel::Run(const Scope& scope,
                              const platform::DeviceContext& dev_ctx) const {
-  if (VLOG_IS_ON(1)) {
-    auto inputs = this->InputVars();
-    auto outputs = this->OutputVars(true);
-    std::ostringstream sout;
-    sout << "Run operator " << this->Type() << " From [";
-    std::ostream_iterator<std::string> out_it(sout, ",");
-    std::copy(inputs.begin(), inputs.end(), out_it);
-    sout << "] to [";
-    std::copy(outputs.begin(), outputs.end(), out_it);
-    sout << "]";
-    VLOG(1) << sout.str();
-  }
-
   RuntimeInferShapeContext infer_shape_ctx(*this, scope);
   this->InferShape(&infer_shape_ctx);

paddle/framework/scope.cc

Lines changed: 2 additions & 1 deletion
@@ -38,11 +38,12 @@ Scope& Scope::NewScope() const {
 Variable* Scope::Var(const std::string& name) {
   auto iter = vars_.find(name);
   if (iter != vars_.end()) {
+    VLOG(3) << "Get existing variable " << name;
     return iter->second;
   }
   Variable* v = new Variable();
   vars_[name] = v;
-  VLOG(3) << "Create variable " << name << " on scope";
+  VLOG(3) << "Create variable " << name;
   v->name_ = &(vars_.find(name)->first);
   return v;
 }

paddle/framework/shape_inference.h

Lines changed: 4 additions & 3 deletions
@@ -53,16 +53,17 @@ class InferShapeContext {
 
   virtual bool IsRuntime() const = 0;
 
+  // Note: In while op, we need this to be public
+  void SetDims(const std::vector<std::string> &names,
+               const std::vector<framework::DDim> &dims);
+
  protected:
   virtual framework::DDim GetDim(const std::string &name) const = 0;
   virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
 
   std::vector<framework::DDim> GetDims(
       const std::vector<std::string> &names) const;
 
-  void SetDims(const std::vector<std::string> &names,
-               const std::vector<framework::DDim> &dims);
-
   std::vector<VarDesc::VarType> GetVarTypes(
       const std::vector<std::string> &names) const;

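Making SetDims public matches the in-code note about the while op: code outside the InferShapeContext class hierarchy can now batch-assign dims. A hedged sketch with an illustrative helper (not the actual while-op code):

#include <string>
#include <vector>
#include "paddle/framework/shape_inference.h"

// Hypothetical helper: forwards a batch of dims onto the context in one
// call, which only compiles against this header once SetDims is public.
void CopyDims(paddle::framework::InferShapeContext* ctx,
              const std::vector<std::string>& names,
              const std::vector<paddle::framework::DDim>& dims) {
  ctx->SetDims(names, dims);
}
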
paddle/operators/array_operator.h

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ class ArrayOp : public framework::OperatorBase {
     } else {
       offset = static_cast<size_t>(*i_tensor.data<int64_t>());
     }
+    VLOG(10) << " Offset = " << offset;
     return offset;
   }
 };
