Skip to content

Commit c67a5d9

Browse files
authored
pylayer_op:release context after compute. (#32707) (#32744)
修复了py_layer_op由于没有析构PyLayerContext造成内存(显存)泄露的问题。 原始pr:#32707
1 parent 7e35ef3 commit c67a5d9

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed

paddle/fluid/imperative/py_layer_fwd.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,16 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
6363
}
6464
}
6565

66-
py::object PyLayerApply(const platform::Place& place, const py::object& cls,
66+
py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
6767
const py::args args, const py::kwargs kwargs) {
68+
py::gil_scoped_acquire guard;
6869
auto bk_function = cls.attr("_backward_function");
6970
auto context = bk_function();
7071
auto forward = cls.attr("forward");
7172

7273
auto result_forward = forward(context, *args, **kwargs);
7374
std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
74-
std::make_shared<operators::PyLayerContext>(context.release().ptr());
75+
std::make_shared<operators::PyLayerContext>(context.ptr());
7576
// make inputs to varbase
7677
std::vector<std::shared_ptr<imperative::VarBase>> input_vars;
7778
// process args,`input_vars` only collect `imperative::VarBase`

paddle/fluid/operators/py_layer_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,12 @@ class PyLayerOpKernel : public framework::OpKernel<T> {
157157
public:
158158
void Compute(const framework::ExecutionContext &ctx) const override {
159159
auto &op_ = ctx.GetOp();
160-
auto pylayer_op = dynamic_cast<const PyLayerOp *>(&op_);
161-
if (pylayer_op) {
162-
auto py_layer_context = pylayer_op->GetPyLayerContext();
160+
auto const_pylayer_op = dynamic_cast<const PyLayerOp *>(&op_);
161+
if (const_pylayer_op) {
162+
auto pylayer_op = const_cast<PyLayerOp *>(const_pylayer_op);
163+
164+
// Release contex after executing the compute
165+
auto py_layer_context = pylayer_op->ReleasePyLayerContext();
163166
py::object bk_ctx(py::handle(py_layer_context->GetMutableCtx()), true);
164167
auto &input_vars = ctx.MultiInputVar("X");
165168
auto output_vars = ctx.MultiOutputVar("Out");

paddle/fluid/operators/py_layer_op.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class PyLayerContext {
3434
PyLayerContext() = delete;
3535

3636
PyObject* GetMutableCtx() { return context_; }
37+
~PyLayerContext() {
38+
py::gil_scoped_acquire guard;
39+
Py_XDECREF(context_);
40+
}
3741

3842
private:
3943
PyObject* context_;
@@ -58,8 +62,11 @@ class PyLayerOp : public framework::OperatorWithKernel {
5862
void SetPyLayerContext(const std::shared_ptr<PyLayerContext>& py_context) {
5963
py_context_ = py_context;
6064
}
61-
const std::shared_ptr<PyLayerContext>& GetPyLayerContext() const {
62-
return py_context_;
65+
std::shared_ptr<PyLayerContext> ReleasePyLayerContext() {
66+
auto temp = py_context_;
67+
py_context_.reset();
68+
VLOG(3) << "`py_context_` in the PyLayerOp is released.";
69+
return temp;
6370
}
6471

6572
private:

0 commit comments

Comments
 (0)