 #include <string>
 #include <vector>
 #include "paddle/fluid/imperative/layer.h"
+#include "paddle/fluid/imperative/prepared_operator.h"
 #include "paddle/fluid/imperative/tracer.h"
 
 #include "paddle/fluid/framework/op_registry.h"
@@ -32,7 +33,17 @@ bool RequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
   for (const auto& name_pair : ins) {
     for (const auto& var_base : name_pair.second) {
       if (!var_base->OverridedStopGradient()) {
-        PassStopGradient(outs, var_base->OverridedStopGradient());
+        for (const auto& pair : outs) {
+          for (const auto& var : pair.second) {
+            if (var) {
+              var->SetOverridedStopGradient(false);
+              SetForwardDataTypeOfGradVar(var);
+              VLOG(3) << "Set output: " << var->Name()
+                      << "'s OverridedStopGradient as "
+                      << var->OverridedStopGradient();
+            }
+          }
+        }
         return true;
       }
     }
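
For reference, the stop-gradient propagation that this hunk inlines in place of `PassStopGradient` could also live in a small standalone helper. The sketch below is illustrative only and not part of the patch; the helper name `MarkOutputsRequireGrad` is hypothetical, and it assumes the `imperative::VarBase` API and `SetForwardDataTypeOfGradVar` from the headers included above.

// Illustrative sketch, not part of the patch. Assumes the Paddle imperative
// headers included above; `MarkOutputsRequireGrad` is a hypothetical name.
#include "paddle/fluid/imperative/layer.h"              // VarBase, NameVarBaseMap
#include "paddle/fluid/imperative/prepared_operator.h"  // SetForwardDataTypeOfGradVar

namespace paddle {
namespace imperative {

// Once any input requires a gradient, mark every output as requiring one too.
inline void MarkOutputsRequireGrad(const NameVarBaseMap& outs) {
  for (const auto& pair : outs) {
    for (const auto& var : pair.second) {
      if (var == nullptr) continue;
      var->SetOverridedStopGradient(false);  // output now tracks gradients
      SetForwardDataTypeOfGradVar(var);      // keep the grad var's dtype in sync
    }
  }
}

}  // namespace imperative
}  // namespace paddle
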
@@ -78,28 +89,36 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
   // process args,`input_vars` only collect `imperative::VarBase`
   if (!args.empty()) {
     for (auto ptr = args.begin(); ptr != args.end(); ptr++) {
-      try {
-        if (Py_None != ptr->ptr()) {
+      // Only collect Tensor type in 'args' and pass them to backward. Ignore
+      // other types of input temporarily.
+      if (py::isinstance<imperative::VarBase>(*ptr)) {
+        try {
           auto a = ptr->cast<std::shared_ptr<VarBase>>();
           input_vars.push_back(a);
+        } catch (py::cast_error& err) {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "The `PyLayer.forward` function contains invalid argument, the "
+              "`%s` type argument can not be cast into `Tensor`.",
+              ptr->ptr()->ob_type->tp_name));
         }
-      } catch (py::cast_error& err) {
-        // Only collect Tensor type in 'args' and pass them to backward. Ignore
-        // other types of input temporarily.
       }
     }
   }
   // process kwargs, only collect `imperative::VarBase`
   if (!kwargs.empty()) {
     for (auto ptr = kwargs.begin(); ptr != kwargs.end(); ptr++) {
-      try {
-        if (Py_None != ptr->second.ptr()) {
+      // Only collect Tensor type in 'kwargs' and pass them to backward.
+      // Ignore other types of input temporarily.
+      if (py::isinstance<imperative::VarBase>(*ptr->second)) {
+        try {
           auto a = ptr->second.cast<std::shared_ptr<VarBase>>();
           input_vars.push_back(a);
+        } catch (py::cast_error&) {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "The `PyLayer.forward` function contains invalid argument, the "
+              "`%s` type argument can not be cast into `Tensor`.",
+              ptr->second.ptr()->ob_type->tp_name));
         }
-      } catch (py::cast_error&) {
-        // Only collect Tensor type in 'kwargs' and pass them to backward.
-        // Ignore other types of input temporarily.
       }
     }
   }
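
The args/kwargs handling above follows a common pybind11 pattern: filter with `py::isinstance` before casting, and report the offending Python type name if the cast still fails. A minimal, self-contained sketch of that pattern (using a hypothetical `Tensor` class in place of `imperative::VarBase`, assumed to be bound with a `std::shared_ptr` holder) might look like:

// Minimal sketch of the isinstance-then-cast pattern; `Tensor` is a
// hypothetical stand-in for imperative::VarBase and must be exposed via
// py::class_<Tensor, std::shared_ptr<Tensor>> for these casts to work.
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Tensor {};

std::vector<std::shared_ptr<Tensor>> CollectTensors(const py::args& args) {
  std::vector<std::shared_ptr<Tensor>> collected;
  for (const auto& item : args) {
    // Silently skip anything that is not a Tensor, mirroring the code above.
    if (!py::isinstance<Tensor>(item)) continue;
    try {
      collected.push_back(item.cast<std::shared_ptr<Tensor>>());
    } catch (const py::cast_error&) {
      // Surface a readable Python type name for the bad argument.
      throw std::invalid_argument(std::string("argument of type `") +
                                  item.ptr()->ob_type->tp_name +
                                  "` can not be cast into `Tensor`");
    }
  }
  return collected;
}
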
@@ -110,33 +129,35 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
       PyList_Check(result_forward.ptr())) {
     auto tuple_result = result_forward.cast<py::tuple>();
     for (size_t i = 0; i < tuple_result.size(); i++) {
-      if (Py_None != tuple_result[i].ptr()) {
+      // Only collect Tensor type of output and pass them to backward.
+      // Ignore other types of input temporarily.
+      if (py::isinstance<imperative::VarBase>(tuple_result[i])) {
         try {
           auto temp_out =
               tuple_result[i].cast<std::shared_ptr<imperative::VarBase>>();
           output_vars.push_back(temp_out);
         } catch (py::cast_error&) {
-          // Only collect Tensor type in 'kwargs' and pass them to backward.
-          // Ignore other types of input temporarily.
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "The `PyLayer.forward` function returns invalid argument, the "
+              "`%s` type argument can not be cast into `Tensor`.",
+              tuple_result[i].ptr()->ob_type->tp_name));
         }
-      } else {
-        // Only collect Tensor type in 'kwargs' and pass them to backward.
-        // Ignore other types of input temporarily.
       }
     }
   } else {
-    if (Py_None != result_forward.ptr()) {
+    // Only collect Tensor type of output and pass them to backward.
+    // Ignore other types of input temporarily.
+    if (py::isinstance<imperative::VarBase>(result_forward)) {
       try {
         auto temp_out =
             result_forward.cast<std::shared_ptr<imperative::VarBase>>();
         output_vars.push_back(temp_out);
       } catch (py::cast_error&) {
-        // Only collect Tensor type in 'kwargs' and pass them to backward.
-        // Ignore other types of input temporarily.
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "The `PyLayer.forward` function returns invalid argument, the `%s` "
+            "type argument can not be cast into `Tensor`.",
+            result_forward.ptr()->ob_type->tp_name));
       }
-    } else {
-      // Only collect Tensor type in 'kwargs' and pass them to backward.
-      // Ignore other types of input temporarily.
     }
   }
   if (output_vars.size() == 0) {