Skip to content

Commit f790b96

Browse files
author
Yang Yang(Tony)
authored
make variable->Grad() a weak_ptr (#11453)
* fix #11416 * make sgd check tape has been backwarded_ * add error message
1 parent a59c3b7 commit f790b96

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

paddle/contrib/tape/function.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ class SGD {
112112
}
113113

114114
void operator()(VariableHandle input) {
115+
PADDLE_ENFORCE(get_global_tape().HasBeenBackwarded(),
116+
"optimization must happen after the backward");
115117
Tape temp_tape;
116118
temp_tape.AddOp("sgd",
117119
{{"Param", {input}},
@@ -120,7 +122,6 @@ class SGD {
120122
{{"ParamOut", {input}}},
121123
{});
122124
temp_tape.Forward();
123-
input->ResetGrad();
124125
}
125126

126127
private:

paddle/contrib/tape/tape.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class Tape {
4747
void Forward();
4848
void Backward(VariableHandle target);
4949

50+
bool HasBeenBackwarded() { return has_been_backwarded_; }
51+
5052
private:
5153
bool has_been_backwarded_ = false;
5254
size_t current_position_ = 0;

paddle/contrib/tape/variable.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ class Variable {
4545
void InitializeVariable();
4646

4747
VariableHandle Grad() {
48-
if (grad_ == nullptr) {
49-
grad_.reset(new Variable(desc_.Name(), true));
48+
if (grad_.expired()) {
49+
VariableHandle new_grad(new Variable(desc_.Name(), true));
50+
grad_ = new_grad;
51+
return new_grad;
52+
} else {
53+
return VariableHandle(grad_);
5054
}
51-
52-
return grad_;
5355
}
5456

55-
void ResetGrad() { grad_ = nullptr; }
56-
5757
// Stochastic Gradient Descent with Momentum
5858
// VariableHandle Momentum ();
5959

@@ -79,7 +79,7 @@ class Variable {
7979
framework::VarDesc desc_;
8080
framework::Variable var_;
8181

82-
VariableHandle grad_;
82+
std::weak_ptr<Variable> grad_;
8383
};
8484
}
8585
}

0 commit comments

Comments
 (0)