Skip to content

Commit edb22c2

Browse files
reyoung (Yang Yang (Tony))
authored and committed
Add Scope::Rename (#5534)
It is useful in the gradient phase of an operator with a block.
1 parent 2378679 commit edb22c2

File tree

3 files changed

+34
-17
lines changed

3 files changed

+34
-17
lines changed

paddle/framework/scope.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,23 @@ void Scope::DeleteScope(Scope* scope) {
9898
delete scope;
9999
}
100100

101+
void Scope::Rename(const std::string& origin_name,
102+
const std::string& new_name) const {
103+
auto origin_it = vars_.find(origin_name);
104+
PADDLE_ENFORCE(origin_it != vars_.end(),
105+
"Cannot find original variable with name %s", origin_name);
106+
auto new_it = vars_.find(new_name);
107+
PADDLE_ENFORCE(new_it == vars_.end(),
108+
"The variable with name %s is already in the scope", new_name);
109+
vars_[new_name] = origin_it->second;
110+
vars_.erase(origin_it);
111+
}
112+
113+
std::string Scope::Rename(const std::string& origin_name) const {
114+
auto var_name = string::Sprintf("%p.%d", this, vars_.size());
115+
Rename(origin_name, var_name);
116+
return var_name;
117+
}
118+
101119
} // namespace framework
102120
} // namespace paddle

paddle/framework/scope.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,18 @@ class Scope {
6868
// enumerate all the variables current contains.
6969
std::vector<std::string> GetAllNames(bool recursive = false) const;
7070

71+
// Rename variable to a new name
72+
void Rename(const std::string& origin_name,
73+
const std::string& new_name) const;
74+
75+
// Rename variable to a new name and return the new name
76+
std::string Rename(const std::string& origin_name) const;
77+
7178
private:
7279
// Call Scope::NewScope for a sub-scope.
7380
explicit Scope(Scope const* parent) : parent_(parent) {}
7481

75-
std::unordered_map<std::string, Variable*> vars_;
82+
mutable std::unordered_map<std::string, Variable*> vars_;
7683
mutable std::list<Scope*> kids_;
7784
Scope const* parent_{nullptr};
7885

paddle/operators/recurrent_op.cc

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,8 @@ class RecurrentGradOp : public RecurrentBase {
387387
auto &p_names = Inputs(kParameters);
388388
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
389389

390-
for (size_t prog_id = 0; prog_id < pg_names.size(); ++prog_id) {
391-
auto inside_grad_name = framework::GradVarName(p_names[prog_id]);
390+
for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
391+
auto inside_grad_name = framework::GradVarName(p_names[param_id]);
392392

393393
// If does not compute gradient of that variable inside rnn, just
394394
// continue
@@ -406,27 +406,19 @@ class RecurrentGradOp : public RecurrentBase {
406406
attrs["value"] = 0.0f;
407407

408408
auto zero_op = framework::OpRegistry::CreateOp(
409-
"fill_constant", {}, {{"Out", {pg_names[prog_id]}}}, attrs);
409+
"fill_constant", {}, {{"Out", {pg_names[param_id]}}}, attrs);
410410
zero_op->Run(scope, dev_ctx);
411411
}
412412

413+
auto new_inside_name = cur_scope.Rename(inside_grad_name);
413414
// sum gradient
414-
auto *outside_var = scope.FindVar(pg_names[prog_id]);
415-
PADDLE_ENFORCE(outside_var != nullptr);
416-
auto &outside_tensor =
417-
*outside_var->GetMutable<framework::LoDTensor>();
418-
419-
std::string result_var_name;
420-
auto *local_result_var = cur_scope.Var(&result_var_name);
421-
auto &local_result_tensor =
422-
*local_result_var->GetMutable<framework::LoDTensor>();
423-
424-
local_result_tensor.ShareDataWith(outside_tensor);
425415

426416
auto sum_op = framework::OpRegistry::CreateOp(
427-
"sum", {{"X", {result_var_name, inside_grad_name}}},
428-
{{"Out", {result_var_name}}}, {});
417+
"sum", {{"X", {pg_names[param_id], new_inside_name}}},
418+
{{"Out", {pg_names[param_id]}}}, {});
429419
sum_op->Run(cur_scope, dev_ctx);
420+
421+
cur_scope.Rename(new_inside_name, inside_grad_name);
430422
}
431423
}
432424
VLOG(5) << "Accumulate Parameter finished ";

0 commit comments

Comments
 (0)