Commit 0079e0b

[Cherry-Pick] Fix the segfault when using to_tensor in PyLayer. (#33303) (#33518)
Fix a segfault triggered when PyLayer returns the result of to_tensor. Cause: if the stop_gradient attribute is modified on the Python side, InnerSetOverridedStopGradient on the C++ side can no longer change it, so the C++ side now calls SetOverridedStopGradient to modify stop_gradient. In addition, the grad var of a tensor produced by to_tensor has the default DataType (-1), and a grad var's DataType must not remain the default (-1) during backward, so ForwardDataType is used to set the grad var's DataType. Original PR: #33303
1 parent f703461 commit 0079e0b
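
For context, a minimal Python sketch of the scenario this commit fixes, adapted from the test_return_to_tensor case added below (variable names and shapes are illustrative): a PyLayer.forward that returns a tensor freshly created with paddle.to_tensor used to crash in backward, because that tensor's grad var kept the default DataType (-1).

import paddle
from paddle.autograd import PyLayer

class Tanh(PyLayer):
    @staticmethod
    def forward(ctx, x1):
        y1 = paddle.tanh(x1)
        ctx.save_for_backward(y1)
        # Tensor created inside forward; before this fix its grad var
        # had no forward DataType, which segfaulted during backward.
        tensor_1 = paddle.to_tensor([1, 2], dtype='float32')
        return y1, tensor_1

    @staticmethod
    def backward(ctx, dy1, dy2):
        y1, = ctx.saved_tensor()
        return dy1 * (1 - paddle.square(y1))

input1 = paddle.randn([2, 3]).astype("float32")
input1.stop_gradient = False
z, tensor1 = Tanh.apply(x1=input1)
z.mean().backward()  # segfaulted before this commit; succeeds after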

File tree

3 files changed: +202 -35 lines

paddle/fluid/imperative/py_layer_fwd.h

Lines changed: 44 additions & 23 deletions
@@ -17,6 +17,7 @@
 #include <string>
 #include <vector>
 #include "paddle/fluid/imperative/layer.h"
+#include "paddle/fluid/imperative/prepared_operator.h"
 #include "paddle/fluid/imperative/tracer.h"
 
 #include "paddle/fluid/framework/op_registry.h"
@@ -32,7 +33,17 @@ bool RequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
   for (const auto& name_pair : ins) {
     for (const auto& var_base : name_pair.second) {
       if (!var_base->OverridedStopGradient()) {
-        PassStopGradient(outs, var_base->OverridedStopGradient());
+        for (const auto& pair : outs) {
+          for (const auto& var : pair.second) {
+            if (var) {
+              var->SetOverridedStopGradient(false);
+              SetForwardDataTypeOfGradVar(var);
+              VLOG(3) << "Set output: " << var->Name()
+                      << "'s OverridedStopGradient as "
+                      << var->OverridedStopGradient();
+            }
+          }
+        }
         return true;
       }
     }
@@ -78,28 +89,36 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
   // process args,`input_vars` only collect `imperative::VarBase`
   if (!args.empty()) {
     for (auto ptr = args.begin(); ptr != args.end(); ptr++) {
-      try {
-        if (Py_None != ptr->ptr()) {
+      // Only collect Tensor type in 'args' and pass them to backward. Ignore
+      // other types of input temporarily.
+      if (py::isinstance<imperative::VarBase>(*ptr)) {
+        try {
           auto a = ptr->cast<std::shared_ptr<VarBase>>();
           input_vars.push_back(a);
+        } catch (py::cast_error& err) {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "The `PyLayer.forward` function contains invalid argument, the "
+              "`%s` type argument can not be cast into `Tensor`.",
+              ptr->ptr()->ob_type->tp_name));
         }
-      } catch (py::cast_error& err) {
-        // Only collect Tensor type in 'args' and pass them to backward. Ignore
-        // other types of input temporarily.
      }
    }
  }
  // process kwargs, only collect `imperative::VarBase`
  if (!kwargs.empty()) {
    for (auto ptr = kwargs.begin(); ptr != kwargs.end(); ptr++) {
-      try {
-        if (Py_None != ptr->second.ptr()) {
+      // Only collect Tensor type in 'kwargs' and pass them to backward.
+      // Ignore other types of input temporarily.
+      if (py::isinstance<imperative::VarBase>(*ptr->second)) {
+        try {
          auto a = ptr->second.cast<std::shared_ptr<VarBase>>();
          input_vars.push_back(a);
+        } catch (py::cast_error&) {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "The `PyLayer.forward` function contains invalid argument, the "
+              "`%s` type argument can not be cast into `Tensor`.",
+              ptr->second.ptr()->ob_type->tp_name));
        }
-      } catch (py::cast_error&) {
-        // Only collect Tensor type in 'kwargs' and pass them to backward.
-        // Ignore other types of input temporarily.
      }
    }
  }
@@ -110,33 +129,35 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
       PyList_Check(result_forward.ptr())) {
     auto tuple_result = result_forward.cast<py::tuple>();
     for (size_t i = 0; i < tuple_result.size(); i++) {
-      if (Py_None != tuple_result[i].ptr()) {
+      // Only collect Tensor type of output and pass them to backward.
+      // Ignore other types of input temporarily.
+      if (py::isinstance<imperative::VarBase>(tuple_result[i])) {
         try {
           auto temp_out =
               tuple_result[i].cast<std::shared_ptr<imperative::VarBase>>();
           output_vars.push_back(temp_out);
         } catch (py::cast_error&) {
-          // Only collect Tensor type in 'kwargs' and pass them to backward.
-          // Ignore other types of input temporarily.
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "The `PyLayer.forward` function returns invalid argument, the "
+              "`%s` type argument can not be cast into `Tensor`.",
+              tuple_result[i].ptr()->ob_type->tp_name));
         }
-      } else {
-        // Only collect Tensor type in 'kwargs' and pass them to backward.
-        // Ignore other types of input temporarily.
      }
    }
  } else {
-    if (Py_None != result_forward.ptr()) {
+    // Only collect Tensor type of output and pass them to backward.
+    // Ignore other types of input temporarily.
+    if (py::isinstance<imperative::VarBase>(result_forward)) {
      try {
        auto temp_out =
            result_forward.cast<std::shared_ptr<imperative::VarBase>>();
        output_vars.push_back(temp_out);
      } catch (py::cast_error&) {
-        // Only collect Tensor type in 'kwargs' and pass them to backward.
-        // Ignore other types of input temporarily.
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "The `PyLayer.forward` function returns invalid argument, the `%s` "
+            "type argument can not be cast into `Tensor`.",
+            result_forward.ptr()->ob_type->tp_name));
      }
-    } else {
-      // Only collect Tensor type in 'kwargs' and pass them to backward.
-      // Ignore other types of input temporarily.
    }
  }
  if (output_vars.size() == 0) {
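
Note on the RequiredGrad change above: once any input has stop_gradient=False, every Tensor output now gets SetOverridedStopGradient(false) plus a forward DataType for its grad var, instead of going through PassStopGradient, whose InnerSetOverridedStopGradient cannot override a stop_gradient already set from the Python side. A sketch of the observable effect, reusing the hypothetical Tanh layer from the snippet near the top:

x = paddle.randn([2, 3]).astype("float32")
x.stop_gradient = False
z, tensor1 = Tanh.apply(x1=x)
# Every Tensor output should now take part in autograd, including tensor1,
# which paddle.to_tensor created with stop_gradient=True by default.
assert not z.stop_gradient
assert not tensor1.stop_gradient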

paddle/fluid/operators/py_layer_op.cc

Lines changed: 30 additions & 12 deletions
@@ -62,13 +62,22 @@ void RunPyObject(py::object *py_object,
   for (size_t i = 0; i < result_tuple.size(); i++) {
     if ((*outs)[i] != nullptr) {
       if (Py_None != result_tuple[i].ptr()) {
-        try {
-          auto result_var =
-              result_tuple[i].cast<std::shared_ptr<imperative::VarBase>>();
-          *(*outs)[i] = result_var->Var();
-        } catch (py::cast_error &) {
+        if (py::isinstance<imperative::VarBase>(result_tuple[i])) {
+          try {
+            auto result_var =
+                result_tuple[i].cast<std::shared_ptr<imperative::VarBase>>();
+            *(*outs)[i] = result_var->Var();
+          } catch (py::cast_error &) {
+            PADDLE_THROW(platform::errors::InvalidArgument(
+                "The `PyLayer.backward` function returns invalid argument, "
+                "the `%s` type argument can not be cast into `Tensor`.",
+                result_tuple[i].ptr()->ob_type->tp_name));
+          }
+        } else {
           PADDLE_THROW(platform::errors::InvalidArgument(
-              "The output of `PyLayer.backward` should be `Tensor`."));
+              "The output of `PyLayer.backward` should be `Tensor`, but "
+              "received `%s`.",
+              result_tuple[i].ptr()->ob_type->tp_name));
         }
       } else {
         PADDLE_THROW(platform::errors::InvalidArgument(
@@ -94,13 +103,22 @@ void RunPyObject(py::object *py_object,
     }
     if ((*outs)[0] != nullptr) {
       if (Py_None != py_result.ptr()) {
-        try {
-          auto result_var =
-              py_result.cast<std::shared_ptr<imperative::VarBase>>();
-          *((*outs)[0]) = result_var->Var();
-        } catch (py::cast_error &) {
+        if (py::isinstance<imperative::VarBase>(py_result)) {
+          try {
+            auto result_var =
+                py_result.cast<std::shared_ptr<imperative::VarBase>>();
+            *((*outs)[0]) = result_var->Var();
+          } catch (py::cast_error &) {
+            PADDLE_THROW(platform::errors::InvalidArgument(
+                "The `PyLayer.backward` function returns invalid argument, the "
+                "`%s` type argument can not be cast into `Tensor`.",
+                py_result.ptr()->ob_type->tp_name));
+          }
+        } else {
           PADDLE_THROW(platform::errors::InvalidArgument(
-              "The output of `PyLayer.backward` should be `Tensor`."));
+              "The output of `PyLayer.backward` should be `Tensor`, but "
+              "received `%s`",
+              py_result.ptr()->ob_type->tp_name));
         }
       } else {
         PADDLE_THROW(platform::errors::InvalidArgument(
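
The py_layer_op.cc changes above replace a silent bad cast with explicit InvalidArgument errors when PyLayer.backward returns something that is not a Tensor. A small sketch of what the new check surfaces (the BadGrad class name is made up for illustration; the error text is paraphrased from the diff):

import paddle
from paddle.autograd import PyLayer

class BadGrad(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return x + 1

    @staticmethod
    def backward(ctx, dy):
        return "not a tensor"  # non-Tensor return hits the new else branch

x = paddle.randn([3, 2])
x.stop_gradient = False
y = BadGrad.apply(x)
try:
    y.mean().backward()
except ValueError as e:
    # e.g. "The output of `PyLayer.backward` should be `Tensor`, but received `str`"
    print(e)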

python/paddle/fluid/tests/unittests/test_pylayer_op.py

Lines changed: 128 additions & 0 deletions
@@ -21,6 +21,11 @@
 from paddle.autograd import PyLayer
 
 
+class FakeTensor(paddle.fluid.core.VarBase):
+    def __init__(self):
+        pass
+
+
 class TestPyLayer(unittest.TestCase):
     def test_simple_pylayer_multiple_output(self):
         class tanh(PyLayer):
@@ -426,6 +431,129 @@ def backward(ctx, dy):
         z = paddle.tanh(data)
         z = cus_tanh.apply(data)
 
+    def test_return_to_tensor(self):
+        class Tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x1):
+                y1 = paddle.tanh(x1)
+                ctx.save_for_backward(y1)
+                tensor_1 = paddle.to_tensor([1, 2], dtype='float32')
+                return y1, 5, None, "helloworld", tensor_1
+
+            @staticmethod
+            def backward(ctx, dy1, dy2):
+                y1, = ctx.saved_tensor()
+                re1 = dy1 * (1 - paddle.square(y1))
+                return dy1
+
+        input1 = paddle.randn([2, 3]).astype("float32")
+        input2 = input1.detach().clone()
+        input1.stop_gradient = False
+        input2.stop_gradient = False
+        z, number, none_item, string_item, tensor1 = Tanh.apply(x1=input1)
+        z.mean().backward()
+
+
+class TestPyLayerReturnType(unittest.TestCase):
+    def test_forward_args_fake_tensor(self):
+        class Tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x1):
+                y1 = FakeTensor()
+                return y1, x1
+
+            @staticmethod
+            def backward(ctx, dy1, dy2):
+                return dy1
+
+        input1 = FakeTensor()
+
+        with self.assertRaises(ValueError):
+            y1, y2 = Tanh.apply(input1)
+
+    def test_forward_kwargs_fake_tensor(self):
+        class Tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x1):
+
+                return x1
+
+            @staticmethod
+            def backward(ctx, dy1, dy2):
+                return dy1
+
+        input1 = FakeTensor()
+
+        with self.assertRaises(ValueError):
+            y = Tanh.apply(x1=input1)
+
+    def test_forward_return_fake_tensor(self):
+        class Tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x1):
+
+                return FakeTensor()
+
+            @staticmethod
+            def backward(ctx, dy1, dy2):
+                return dy1
+
+        input1 = paddle.randn([3, 2])
+
+        with self.assertRaises(ValueError):
+            y = Tanh.apply(x1=input1)
+
+    def test_forward_return_fake_tensor_tuple(self):
+        class Tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x1):
+
+                return FakeTensor(), FakeTensor()
+
+            @staticmethod
+            def backward(ctx, dy1, dy2):
+                return dy1
+
+        input1 = paddle.randn([3, 2])
+
+        with self.assertRaises(ValueError):
+            y = Tanh.apply(x1=input1)
+
+    def test_backward_return_fake_tensor_tuple(self):
+        class Tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x1, x2):
+                return x1 + 1, x1 + 2
+
+            @staticmethod
+            def backward(ctx, dy1, dy2):
+
+                return FakeTensor(), 2
+
+        input1 = paddle.randn([3, 2])
+        input1.stop_gradient = False
+        y, _ = Tanh.apply(input1, 1 + input1)
+
+        with self.assertRaises(ValueError):
+            y.mean().backward()
+
+    def test_backward_return_fake_tensor(self):
+        class Tanh(PyLayer):
+            @staticmethod
+            def forward(ctx, x1):
+                return x1 + 1, x1 + 2
+
+            @staticmethod
+            def backward(ctx, dy1, dy2):
+                return FakeTensor()
+
+        input1 = paddle.randn([3, 2])
+        input1.stop_gradient = False
+        y, _ = Tanh.apply(input1)
+
+        with self.assertRaises(ValueError):
+            y.mean().backward()
+
 
 if __name__ == '__main__':
     unittest.main()
