File tree Expand file tree Collapse file tree 3 files changed +18
-7
lines changed Expand file tree Collapse file tree 3 files changed +18
-7
lines changed Original file line number Diff line number Diff line change @@ -63,15 +63,16 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
63
63
}
64
64
}
65
65
66
- py::object PyLayerApply (const platform::Place& place, const py::object & cls,
66
+ py::object PyLayerApply (const platform::Place& place, const py::handle & cls,
67
67
const py::args args, const py::kwargs kwargs) {
68
+ py::gil_scoped_acquire guard;
68
69
auto bk_function = cls.attr (" _backward_function" );
69
70
auto context = bk_function ();
70
71
auto forward = cls.attr (" forward" );
71
72
72
73
auto result_forward = forward (context, *args, **kwargs);
73
74
std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
74
- std::make_shared<operators::PyLayerContext>(context.release (). ptr ());
75
+ std::make_shared<operators::PyLayerContext>(context.ptr ());
75
76
// make inputs to varbase
76
77
std::vector<std::shared_ptr<imperative::VarBase>> input_vars;
77
78
// process args,`input_vars` only collect `imperative::VarBase`
Original file line number Diff line number Diff line change @@ -157,9 +157,12 @@ class PyLayerOpKernel : public framework::OpKernel<T> {
157
157
public:
158
158
void Compute (const framework::ExecutionContext &ctx) const override {
159
159
auto &op_ = ctx.GetOp ();
160
- auto pylayer_op = dynamic_cast <const PyLayerOp *>(&op_);
161
- if (pylayer_op) {
162
- auto py_layer_context = pylayer_op->GetPyLayerContext ();
160
+ auto const_pylayer_op = dynamic_cast <const PyLayerOp *>(&op_);
161
+ if (const_pylayer_op) {
162
+ auto pylayer_op = const_cast <PyLayerOp *>(const_pylayer_op);
163
+
164
+ // Release contex after executing the compute
165
+ auto py_layer_context = pylayer_op->ReleasePyLayerContext ();
163
166
py::object bk_ctx (py::handle (py_layer_context->GetMutableCtx ()), true );
164
167
auto &input_vars = ctx.MultiInputVar (" X" );
165
168
auto output_vars = ctx.MultiOutputVar (" Out" );
Original file line number Diff line number Diff line change @@ -34,6 +34,10 @@ class PyLayerContext {
34
34
PyLayerContext () = delete ;
35
35
36
36
PyObject* GetMutableCtx () { return context_; }
37
+ ~PyLayerContext () {
38
+ py::gil_scoped_acquire guard;
39
+ Py_XDECREF (context_);
40
+ }
37
41
38
42
private:
39
43
PyObject* context_;
@@ -58,8 +62,11 @@ class PyLayerOp : public framework::OperatorWithKernel {
58
62
void SetPyLayerContext (const std::shared_ptr<PyLayerContext>& py_context) {
59
63
py_context_ = py_context;
60
64
}
61
- const std::shared_ptr<PyLayerContext>& GetPyLayerContext () const {
62
- return py_context_;
65
+ std::shared_ptr<PyLayerContext> ReleasePyLayerContext () {
66
+ auto temp = py_context_;
67
+ py_context_.reset ();
68
+ VLOG (3 ) << " `py_context_` in the PyLayerOp is released." ;
69
+ return temp;
63
70
}
64
71
65
72
private:
You can’t perform that action at this time.
0 commit comments