
Commit 187cffd

Merge pull request #15928 from velconia/imperative_backward_hooks
Imperative backward hooks
2 parents 1616c32 + e5f3435 commit 187cffd

File tree

18 files changed: +261, -87 lines changed


paddle/fluid/framework/block_desc.cc

Lines changed: 14 additions & 14 deletions
@@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/block_desc.h"
+
 #include <queue>
+#include <unordered_set>
+#include <utility>
+
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 
@@ -155,6 +159,16 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
   ops_.erase(ops_.begin() + s, ops_.begin() + e);
 }
 
+void BlockDesc::RemoveOpInternal(const OpDesc *op_desc) {
+  // TODO(minqiyang): make this faster
+  for (auto it = ops_.begin(); it != ops_.end(); ++it) {
+    if (it->get() == op_desc) {
+      ops_.erase(it);
+      break;
+    }
+  }
+}
+
 std::vector<OpDesc *> BlockDesc::AllOps() const {
   std::vector<OpDesc *> res;
   for (const auto &op : ops_) {
@@ -163,20 +177,6 @@ std::vector<OpDesc *> BlockDesc::AllOps() const {
   return res;
 }
 
-void BlockDesc::Clear() {
-  // clear all ops
-  ops_.clear();
-
-  // clear all vars which are not persistable
-  for (auto it = vars_.begin(); it != vars_.end();) {
-    if (it->second->Persistable()) {
-      ++it;
-    } else {
-      vars_.erase(it++);
-    }
-  }
-}
-
 void BlockDesc::Flush() {
   for (auto &op_desc : ops_) {
     op_desc->Flush();

paddle/fluid/framework/block_desc.h

Lines changed: 2 additions & 2 deletions
@@ -93,12 +93,12 @@ class BlockDesc {
    */
   void RemoveOp(size_t s, size_t e);
 
+  void RemoveOpInternal(const OpDesc *op_desc);
+
   void RemoveVar(const std::string &name) { vars_.erase(name); }
 
   std::vector<OpDesc *> AllOps() const;
 
-  void Clear();
-
   size_t OpSize() const { return ops_.size(); }
 
   OpDesc *Op(int idx) const { return ops_.at(idx).get(); }

paddle/fluid/framework/python_headers.h

Lines changed: 8 additions & 0 deletions
@@ -24,3 +24,11 @@ limitations under the License. */
 
 #pragma pop_macro("_XOPEN_SOURCE")
 #pragma pop_macro("_POSIX_C_SOURCE")
+
+#if !defined(PYBIND11_HIDDEN)
+#ifdef _WIN32
+#define PYBIND11_HIDDEN __declspec(dllexport)
+#else
+#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
+#endif
+#endif
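
For context: the PYBIND11_HIDDEN macro added above gives hidden symbol visibility to classes that hold pybind11 objects (pybind11 declares its own types with hidden visibility), and later changes in this commit apply it to OpBase and PyOpBase. Below is a minimal standalone sketch of the same pattern; CallbackHolder is a hypothetical name used only for illustration and is not part of Paddle:

#include <pybind11/pybind11.h>

#include <vector>

// Same macro as introduced in python_headers.h above.
#if !defined(PYBIND11_HIDDEN)
#ifdef _WIN32
#define PYBIND11_HIDDEN __declspec(dllexport)
#else
#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
#endif
#endif

namespace py = pybind11;

// Hidden visibility keeps a class that stores py::object members from being
// exported with greater visibility than the pybind11 types it holds.
class PYBIND11_HIDDEN CallbackHolder {
 public:
  void Add(const py::object& cb) { callbacks_.push_back(cb); }

 private:
  std::vector<py::object> callbacks_;
};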

paddle/fluid/imperative/layer.cc

Lines changed: 24 additions & 2 deletions
@@ -18,6 +18,7 @@
 #include <limits>
 #include <map>
 #include <random>
+#include <unordered_set>
 #include <utility>
 
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -139,6 +140,8 @@ class Autograd {
           }
         }
       }
+
+      ready_op->InvokeBackwardHooks();
     }
   }
 
@@ -156,8 +159,10 @@ class Autograd {
       for (auto it : candidate->pre_ops_) {
         for (OpBase* pre_op : it.second) {
           if (!pre_op) continue;
-          VLOG(5) << "op dep " << candidate->op_desc_->Type() << " <---- "
-                  << it.first << " <---- " << pre_op->op_desc_->Type();
+          VLOG(5) << "op dep " << candidate->op_desc_->Type() << " trace id "
+                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
+                  << pre_op->op_desc_->Type() << " trace id "
+                  << pre_op->trace_id_;
           if (visited.find(pre_op) == visited.end()) {
             visited.insert(pre_op);
             queue.push_back(pre_op);
@@ -211,6 +216,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
     return {};
   }
 
+  VLOG(3) << "apply op grad: " << op_desc_->Type();
  std::vector<framework::VariableValueMap> grad_outputs;
  if (backward_id_ > 0) {
    VLOG(3) << "py_layer_grad";
@@ -272,6 +278,22 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
   return input_vars_;
 }
 
+void OpBase::InvokeBackwardHooks() {
+  VLOG(3) << "call backward hooks, hooks num: " << backward_hooks_.size();
+
+  // call backward hooks
+  for (py::object& callable : backward_hooks_) {
+    callable(this);
+  }
+}
+
+void OpBase::RegisterBackwardHooks(const py::object& callable) {
+  VLOG(3) << "Register backward hooks " << trace_id_;
+
+  // TODO(minqiyang): check the callable format
+  backward_hooks_.push_back(callable);
+}
+
 void VarBase::RunBackward() {
   if (!pre_op_) return;
 
paddle/fluid/imperative/layer.h

Lines changed: 44 additions & 3 deletions
@@ -123,22 +123,32 @@ class VarBase {
 
  private:
   VarBase(framework::Variable* var, VarBase* grad, bool stop_gradient)
-      : var_desc_(nullptr),
+      : name_(),
+        var_desc_(nullptr),
         var_(var),
         grads_(grad),
+        block_(nullptr),
+        persistable_(false),
         stop_gradient_(stop_gradient),
         pre_op_(nullptr),
+        pre_op_out_name_(),
         pre_op_out_idx_(-1) {}
 
  public:
   virtual ~VarBase() {
+    // TODO(minqiyang): remove var desc from block desc
     if (var_) {
       delete var_;
+      var_ = nullptr;
     }
 
     if (grads_) {
       delete grads_;
+      grads_ = nullptr;
     }
+
+    pre_op_ = nullptr;
+    pre_op_out_idx_ = -1;
   }
 
   inline OpBase* PreOp() const { return pre_op_; }
@@ -151,6 +161,14 @@ class VarBase {
 
   void RunBackward();
 
+  inline void ResetPreOp(OpBase* op) {
+    if (op == pre_op_) {
+      // clear pre_op info when op equals to var's pre_op
+      pre_op_ = nullptr;
+      pre_op_out_idx_ = -1;
+    }
+  }
+
   void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
                   int pre_op_out_idx, bool pre_op_stop_gradient) {
     pre_op_ = pre_op;
@@ -184,11 +202,15 @@ class VarBase {
     return string::Sprintf("%s@IGrad", var_desc_->Name());
   }
 
+  std::string name_;
   framework::VarDesc* var_desc_;
 
   framework::Variable* var_;
   VarBase* grads_;
 
+  framework::BlockDesc* block_;
+  bool persistable_;
+
  private:
   bool stop_gradient_;
   OpBase* pre_op_;
@@ -199,22 +221,38 @@ class VarBase {
 /* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
  * gradient. This object should be managed totally by Python intepreter.
  */
-class OpBase {
+class PYBIND11_HIDDEN OpBase {
  public:
   OpBase()
       : op_desc_(nullptr),
         forward_id_(-1),
         backward_id_(-1),
-        place_(platform::CPUPlace()) {}
+        trace_id_(-1),
+        place_(platform::CPUPlace()),
+        backward_hooks_() {}
 
   virtual ~OpBase() {
+    // TODO(minqiyang): remove op_desc from block_desc in tracer
+    //
+    // reset all output vars' pre op
+    for (auto iter : output_vars_) {
+      for (VarBase* var : iter.second) {
+        var->ResetPreOp(this);
+      }
+    }
+
+    // release resource
     for (framework::OpDesc* desc : grad_op_descs_) {
       delete desc;
     }
   }
 
   std::map<std::string, std::vector<VarBase*>> ApplyGrad();
 
+  void RegisterBackwardHooks(const py::object& callable);
+
+  void InvokeBackwardHooks();
+
   // One of `op_desc_` or `forward_id_` is set, not both.
   // For pure python PyLayer, use `forward_id_`, otherwise, use op_desc_.
   framework::OpDesc* op_desc_;
@@ -225,6 +263,7 @@ class OpBase {
   // Note: each fwd op corresponds to a vector of bwd ops.
   std::vector<framework::OpDesc*> grad_op_descs_;
   int backward_id_;
+  int trace_id_;
 
   platform::Place place_;
 
@@ -239,6 +278,8 @@ class OpBase {
   std::vector<framework::VariableValueMap> grad_output_vars_;
 
   framework::BlockDesc* block_;
+
+  std::vector<py::object> backward_hooks_;
 };
 
 class Layer {
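
To illustrate the mechanism declared above, here is a minimal standalone sketch of the register/invoke pattern, using hypothetical names and plain std::function callbacks in place of the py::object callables that OpBase stores: hooks are kept on the op and invoked once, after the op's gradient has been applied, mirroring Autograd calling ready_op->InvokeBackwardHooks() in layer.cc. This is an illustration only, not Paddle API:

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for OpBase, showing only the hook bookkeeping.
class Op {
 public:
  explicit Op(std::string type) : type_(std::move(type)) {}

  void RegisterBackwardHook(std::function<void(Op*)> hook) {
    backward_hooks_.push_back(std::move(hook));
  }

  void InvokeBackwardHooks() {
    // Every registered callable is invoked with the op itself, once the
    // op's gradient has been applied in the backward pass.
    for (auto& hook : backward_hooks_) hook(this);
  }

  const std::string& Type() const { return type_; }

 private:
  std::string type_;
  std::vector<std::function<void(Op*)>> backward_hooks_;
};

int main() {
  Op op("mul");
  op.RegisterBackwardHook([](Op* o) {
    std::cout << "backward hook fired for " << o->Type() << "\n";
  });
  // In the real flow this call happens inside the backward pass, right
  // after the gradient of this op has been applied.
  op.InvokeBackwardHooks();
  return 0;
}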

paddle/fluid/imperative/tracer.cc

Lines changed: 11 additions & 3 deletions
@@ -14,7 +14,10 @@
 
 #include "paddle/fluid/imperative/tracer.h"
 
+#include <memory>
 #include <set>
+#include <unordered_map>
+#include <unordered_set>
 
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -110,7 +113,8 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   std::map<std::string, VarBase*> vars;
 
   framework::OpDesc* op_desc = op->op_desc_;
-  VLOG(3) << "tracer tracing " << op_desc->Type();
+  VLOG(3) << "tracer tracing " << op_desc->Type() << " trace id "
+          << op->trace_id_;
   op_desc->InferShape(*block);
   op_desc->InferVarType(block);
 
@@ -133,11 +137,13 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       if (inp->PreOp() && !inp->IsStopGradient()) {
         op->pre_ops_[it.first].push_back(inp->PreOp());
         op->pre_ops_out_idx_[it.first].push_back(inp->PreOpOutIdx());
+        VLOG(3) << "add pre op " << inp->PreOp()->op_desc_->Type();
       } else {
         op->pre_ops_[it.first].push_back(nullptr);
      }
      VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
-              << inp->var_->IsInitialized();
+              << inp->var_->IsInitialized() << " stop_gradient "
+              << inp->IsStopGradient();
    }
  }
 
@@ -189,6 +195,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
 
   op->grad_input_vars_.resize(op->grad_op_descs_.size());
   op->grad_output_vars_.resize(op->grad_op_descs_.size());
+
   for (size_t i = 0; i < op->grad_op_descs_.size(); ++i) {
     framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
     for (auto it : grad_op_desc->Inputs()) {
@@ -201,7 +208,6 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
          PADDLE_ENFORCE(fwd_var_it != vars.end());
          // Forward inputs or outputs.
          grad_in_vars.push_back(fwd_var_it->second->var_);
-          vars_saved_for_backward.insert(it.first);
        } else {
          VarBase* var = vars[var_it->second];
          if (!var->grads_->var_->IsInitialized()) {
@@ -211,6 +217,8 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
          // Douts.
          grad_in_vars.push_back(var->grads_->var_);
        }
+
+        vars_saved_for_backward.insert(it.first);
      }
    }
 
paddle/fluid/pybind/imperative.h

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ class Layer : public imperative::Layer {
   }
 };
 
-class PyOpBase : public imperative::OpBase {
+class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
  public:
   using imperative::OpBase::OpBase;  // Inherit constructors
 };

paddle/fluid/pybind/protobuf.cc

Lines changed: 0 additions & 2 deletions
@@ -189,8 +189,6 @@ void BindBlockDesc(pybind11::module *m) {
             return self.HasVar(name);
           },
           pybind11::return_value_policy::reference)
-      .def("_clear_block", [](pd::BlockDesc &self) { return self.Clear(); },
-           pybind11::return_value_policy::reference)
       .def("_rename_var",
           [](pd::BlockDesc &self, const pybind11::bytes &byte_name,
              const pybind11::bytes &byte_name_new) {
