3 files changed, +11 -8 lines

@@ -112,6 +112,8 @@ class SGD {
   }
 
   void operator()(VariableHandle input) {
+    PADDLE_ENFORCE(get_global_tape().HasBeenBackwarded(),
+                   "optimization must happen after the backward");
     Tape temp_tape;
     temp_tape.AddOp("sgd",
                     {{"Param", {input}},
@@ -120,7 +122,6 @@ class SGD {
                     {{"ParamOut", {input}}},
                     {});
     temp_tape.Forward();
-    input->ResetGrad();
   }
 
  private:
@@ -47,6 +47,8 @@ class Tape {
   void Forward();
   void Backward(VariableHandle target);
 
+  bool HasBeenBackwarded() { return has_been_backwarded_; }
+
  private:
   bool has_been_backwarded_ = false;
   size_t current_position_ = 0;
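Taken together, the two hunks above make the call-ordering contract explicit: SGD::operator() now refuses to run unless the global tape has already been backwarded, which it checks through the new HasBeenBackwarded() accessor. The following is a minimal standalone sketch of that guard pattern; MiniTape, MiniSGD, and the plain assert are hypothetical stand-ins for illustration, not the real Paddle classes or PADDLE_ENFORCE.

// Standalone illustration of the "optimize only after backward" guard.
// MiniTape and MiniSGD are hypothetical stand-ins, not the Paddle API.
#include <cassert>
#include <iostream>

class MiniTape {
 public:
  void Forward() { std::cout << "forward pass\n"; }
  void Backward() {
    std::cout << "backward pass\n";
    has_been_backwarded_ = true;
  }
  // Mirrors the accessor added in the Tape hunk above.
  bool HasBeenBackwarded() const { return has_been_backwarded_; }

 private:
  bool has_been_backwarded_ = false;
};

MiniTape& get_global_tape() {
  static MiniTape tape;
  return tape;
}

class MiniSGD {
 public:
  void operator()() {
    // Stand-in for the PADDLE_ENFORCE check: refuse to optimize too early.
    assert(get_global_tape().HasBeenBackwarded() &&
           "optimization must happen after the backward");
    std::cout << "apply sgd update\n";
  }
};

int main() {
  MiniSGD sgd;
  get_global_tape().Forward();
  get_global_tape().Backward();  // without this call, sgd() would assert
  sgd();
  return 0;
}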
@@ -45,15 +45,15 @@ class Variable {
   void InitializeVariable();
 
   VariableHandle Grad() {
-    if (grad_ == nullptr) {
-      grad_.reset(new Variable(desc_.Name(), true));
+    if (grad_.expired()) {
+      VariableHandle new_grad(new Variable(desc_.Name(), true));
+      grad_ = new_grad;
+      return new_grad;
+    } else {
+      return VariableHandle(grad_);
     }
-
-    return grad_;
   }
 
-  void ResetGrad() { grad_ = nullptr; }
-
   // Stochastic Gradient Descent with Momentum
   // VariableHandle Momentum();
 
@@ -79,7 +79,7 @@ class Variable {
   framework::VarDesc desc_;
   framework::Variable var_;
 
-  VariableHandle grad_;
+  std::weak_ptr<Variable> grad_;
 };
 }
 }
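The Variable hunk replaces the owning VariableHandle grad_ plus the explicit ResetGrad() with a non-owning std::weak_ptr<Variable>: Grad() creates a fresh shared handle when the cached gradient has expired and otherwise re-shares the live one, so the gradient buffer is freed automatically once the last handle (for example the one held during the SGD update) goes away, and the input->ResetGrad() call removed in the first hunk is no longer needed. Below is a minimal standalone sketch of that weak_ptr caching pattern, using a hypothetical Buffer/Owner pair instead of Variable; lock() plays the role of the VariableHandle(grad_) construction in the hunk above.

// Standalone sketch of caching via std::weak_ptr: the owner hands out
// shared_ptr handles but keeps only a weak reference, so the cached
// object is destroyed as soon as the last outside handle is released.
#include <cassert>
#include <memory>
#include <string>

struct Buffer {  // hypothetical stand-in for the gradient Variable
  explicit Buffer(std::string n) : name(std::move(n)) {}
  std::string name;
};

class Owner {
 public:
  std::shared_ptr<Buffer> Grad() {
    if (grad_.expired()) {                        // no live handle anywhere
      auto fresh = std::make_shared<Buffer>("grad_of_x");
      grad_ = fresh;                              // remember it weakly
      return fresh;
    }
    return grad_.lock();                          // reuse the live buffer
  }

 private:
  std::weak_ptr<Buffer> grad_;  // non-owning, like the new grad_ member
};

int main() {
  Owner var;
  {
    auto g1 = var.Grad();
    auto g2 = var.Grad();
    assert(g1 == g2);          // same buffer while a handle is alive
  }                            // last handle dropped here -> buffer freed
  auto g3 = var.Grad();        // a new buffer is allocated lazily
  assert(g3 != nullptr);
  return 0;
}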