
Commit 7e35ef3

[Cherry-Pick] Clear 'BasicEngine' when an exception occurs in the backward. (#32546) (#32615)
* clear 'BasicEngine' when an exception occurs in the backward. (#32546)
* clear 'BasicEngine' when an exception occurs in the backward.
* deal with conflict.
* deal with conflict.
* forward return any type. (#32661)

1 parent 4f06cd1 · commit 7e35ef3

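The second half of this change (#32661) relaxes what `PyLayer.forward` may return: non-Tensor values are now tolerated alongside Tensors. A minimal Python sketch of the relaxed contract, modeled on the updated tests below (illustrative code, not part of the commit):

import paddle
from paddle.autograd import PyLayer

class tanh(PyLayer):
    @staticmethod
    def forward(ctx, x1, x2, func1, func2=paddle.square):
        y1 = func1(x1)
        y2 = func1(x2)
        ctx.save_for_backward(y1, y2)
        # int and None entries are now skipped by the framework; only the
        # Tensor outputs (indices 0 and 2) take part in autograd.
        return y1, 1, y2, None

    @staticmethod
    def backward(ctx, dy1, dy2):
        y1, y2 = ctx.saved_tensor()
        # One gradient per Tensor output of forward (d tanh(x)/dx = 1 - y^2).
        return dy1 * (1 - paddle.square(y1)), dy2 * (1 - paddle.square(y2))

input1 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = False
z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
z = z[0] + z[2]          # skip the non-Tensor outputs
z.mean().backward()

Only the Tensor outputs receive gradients in `backward`; the int and None entries are simply passed through to the caller.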
File tree: 4 files changed (+80 −45 lines)

paddle/fluid/imperative/basic_engine.cc

Lines changed: 14 additions & 6 deletions
@@ -471,12 +471,20 @@ void BasicEngine::Execute() {

     {
       VLOG(3) << "Start to execute grad op " << cur_op.Type();
-      if (tmp_ins_ptr == nullptr) {
-        OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
-                    cur_op.place());
-      } else {
-        OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs, cur_op.Attrs(),
-                    cur_op.place());
+      try {
+        if (tmp_ins_ptr == nullptr) {
+          OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
+                      cur_op.place());
+        } else {
+          OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs,
+                      cur_op.Attrs(), cur_op.place());
+        }
+      } catch (platform::EnforceNotMet& exception) {
+        Clear();
+        throw std::move(exception);
+      } catch (std::exception& ex) {
+        Clear();
+        PADDLE_THROW(platform::errors::External("%s", ex.what()));
       }
     }

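The hunk above wraps the grad-op run in try/catch and calls `Clear()` before rethrowing, so a throwing backward releases the engine's per-run bookkeeping instead of leaving it behind for the next call. The user-visible effect is that a failed `backward()` can be caught and autograd keeps working in the same process. A hedged sketch of that behavior (the error type follows the updated tests; the layer here is hypothetical):

import paddle
from paddle.autograd import PyLayer

class BadBackward(PyLayer):
    @staticmethod
    def forward(ctx, x):
        return paddle.tanh(x)

    @staticmethod
    def backward(ctx, dy):
        # Wrong arity on purpose: one forward input, two gradients.
        return dy, dy

x = paddle.randn([2, 3])
x.stop_gradient = False
y = BadBackward.apply(x)
try:
    y.mean().backward()   # raises; BasicEngine::Execute() now Clear()s first
except ValueError as err:
    print("backward failed:", err)

# Because the engine state was cleared, an unrelated backward pass in the
# same process still works.
w = paddle.randn([2, 3])
w.stop_gradient = False
paddle.tanh(w).mean().backward()
assert w.grad is not None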
paddle/fluid/imperative/py_layer_fwd.h

Lines changed: 12 additions & 8 deletions
@@ -115,12 +115,12 @@ py::object PyLayerApply(const platform::Place& place, const py::object& cls,
               tuple_result[i].cast<std::shared_ptr<imperative::VarBase>>();
           output_vars.push_back(temp_out);
         } catch (py::cast_error&) {
-          PADDLE_THROW(platform::errors::Unimplemented(
-              "The output of `PyLayer.forward` should be `Tensor`."));
+          // Only collect Tensor type in 'kwargs' and pass them to backward.
+          // Ignore other types of input temporarily.
         }
       } else {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "The output of `PyLayer.forward` can not be `None`."));
+        // Only collect Tensor type in 'kwargs' and pass them to backward.
+        // Ignore other types of input temporarily.
       }
     }
   } else {
@@ -130,14 +130,18 @@ py::object PyLayerApply(const platform::Place& place, const py::object& cls,
             result_forward.cast<std::shared_ptr<imperative::VarBase>>();
         output_vars.push_back(temp_out);
       } catch (py::cast_error&) {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "The output of `PyLayer.forward` should be `Tensor`."));
+        // Only collect Tensor type in 'kwargs' and pass them to backward.
+        // Ignore other types of input temporarily.
       }
     } else {
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "The output of `PyLayer.forward` can not be `None`."));
+      // Only collect Tensor type in 'kwargs' and pass them to backward.
+      // Ignore other types of input temporarily.
     }
   }
+  if (output_vars.size() == 0) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "At least one output of `PyLayer.forward` is a `Tensor`."));
+  }

   NameVarBaseMap outs = {{"Out", output_vars}};

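With the two `Unimplemented` throws replaced by comments, `PyLayerApply` now skips non-Tensor entries (including `None`) while collecting `output_vars`; the one hard requirement left is the new `InvalidArgument` check that at least one output is a Tensor. A short sketch of the case that check rejects, mirroring the updated `Layer_one1` test (hypothetical layer name):

import paddle
from paddle.autograd import PyLayer

class NoTensorOut(PyLayer):
    @staticmethod
    def forward(ctx, *args):
        return 1            # nothing Tensor-typed in the output

    @staticmethod
    def backward(ctx, *args):
        return args

input1 = paddle.randn([2, 3]).astype("float64")
# Surfaces in Python as ValueError:
# "At least one output of `PyLayer.forward` is a `Tensor`."
raised = False
try:
    NoTensorOut.apply(input1)
except ValueError:
    raised = True
assert raised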
paddle/fluid/operators/py_layer_op.cc

Lines changed: 6 additions & 0 deletions
@@ -86,6 +86,12 @@ void RunPyObject(py::object *py_object,
       }
     }
   } else {
+    if (1 != outs->size()) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The number of outputs of `PyLayer.backward` should be %d, but "
+          "received 1.",
+          outs->size()));
+    }
     if ((*outs)[0] != nullptr) {
       if (Py_None != py_result.ptr()) {
         try {

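This branch of `RunPyObject` handles a `backward` that returned a single object rather than a tuple; the new guard rejects that shape whenever the grad op expects a different number of outputs. The added `test_pylayer_num_output_match` below triggers it with two inputs and a single returned Tensor, while returning one gradient per input passes, as in this sketch (illustrative, not from the commit):

import paddle
from paddle.autograd import PyLayer

class Add(PyLayer):
    @staticmethod
    def forward(ctx, x1, x2):
        return x1 + x2

    @staticmethod
    def backward(ctx, dy):
        # One gradient per forward input that requires grad.
        return dy, dy

x1 = paddle.randn([2, 3]).astype("float64")
x2 = x1.detach().clone()
x1.stop_gradient = False
x2.stop_gradient = False
z = Add.apply(x1, x2)
z.mean().backward()      # satisfies the new output-count check

# Returning a bare Tensor instead (e.g. `return dy + 1`, as the new test
# does) trips the InvalidArgument check above and surfaces as a ValueError.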
python/paddle/fluid/tests/unittests/test_pylayer_op.py

Lines changed: 48 additions & 31 deletions
@@ -30,7 +30,7 @@ def forward(ctx, x1, x2, func1, func2=paddle.square):
                 y1 = func1(x1)
                 y2 = func1(x2)
                 ctx.save_for_backward(y1, y2)
-                return y1, y2
+                return y1, 1, y2, None

             @staticmethod
             def backward(ctx, dy1, dy2):
@@ -44,7 +44,7 @@ def backward(ctx, dy1, dy2):
         input1.stop_gradient = False
         input2.stop_gradient = False
         z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
-        z = z[0] + z[1]
+        z = z[0] + z[2]
         z.mean().backward()

         z2 = paddle.tanh(input2) + paddle.tanh(input2)
@@ -61,7 +61,7 @@ def forward(ctx, x1, x2, func1, func2=paddle.square):
                 y1 = func1(x1)
                 y2 = func1(x2)
                 ctx.save_for_backward(y1, y2)
-                return y1, y2
+                return 1, None, y1, y2, ''

             @staticmethod
             def backward(ctx, dy1, dy2):
@@ -79,7 +79,7 @@ def backward(ctx, dy1, dy2):
         input3.stop_gradient = True
         input4.stop_gradient = True
         z = tanh.apply(input1, input3, paddle.tanh, paddle.square)
-        z = z[0] + z[1]
+        z = z[2] + z[3]
         z.mean().backward()

         z2 = paddle.tanh(input2) + paddle.tanh(input4)
@@ -115,6 +115,27 @@ def backward(ctx, dy1):
         self.assertTrue(
             np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10)

+    def test_pylayer_num_output_match(self):
+        class tanh(PyLayer):
+            @staticmethod
+            def forward(
+                    ctx,
+                    x1,
+                    x2, ):
+                return x1 + x2
+
+            @staticmethod
+            def backward(ctx, dy1):
+                return dy1 + 1
+
+        input1 = paddle.randn([2, 3]).astype("float64")
+        input2 = input1.detach().clone()
+        input1.stop_gradient = False
+        input2.stop_gradient = False
+        z = tanh.apply(input1, input2)
+        with self.assertRaises(ValueError):
+            z.mean().backward()
+
     def test_pylayer_dtype(self):
         class tanh(PyLayer):
             @staticmethod
@@ -150,21 +171,21 @@ def backward(ctx, *args):
                 return args

         input1 = paddle.randn([2, 3]).astype("float64")
-        with self.assertRaises(NotImplementedError):
+        with self.assertRaises(ValueError):
             z = Layer_None1.apply(input1)

         class Layer_None2(PyLayer):
             @staticmethod
             def forward(ctx, *args):
-                return [None, None]
+                return [None, args[0]]

             @staticmethod
             def backward(ctx, *args):
                 return args

         input1 = paddle.randn([2, 3]).astype("float64")
-        with self.assertRaises(NotImplementedError):
-            z = Layer_None2.apply(input1)
+        # return None
+        z = Layer_None2.apply(input1)

         class Layer_one1(PyLayer):
             @staticmethod
@@ -176,21 +197,22 @@ def backward(ctx, *args):
                 return args

         input1 = paddle.randn([2, 3]).astype("float64")
-        with self.assertRaises(NotImplementedError):
+        # At least one output of `PyLayer.backward` is a `Tensor`
+        with self.assertRaises(ValueError):
             z = Layer_one1.apply(input1)

         class Layer_one2(PyLayer):
             @staticmethod
             def forward(ctx, *args):
-                return [1, 2]
+                return [1, 2, args[0]]

             @staticmethod
             def backward(ctx, *args):
                 return args

         input1 = paddle.randn([2, 3]).astype("float64")
-        with self.assertRaises(NotImplementedError):
-            z = Layer_one2.apply(input1)
+        # return int
+        z = Layer_one2.apply(input1)

         class Layer_no_fw(PyLayer):
             @staticmethod
@@ -234,8 +256,7 @@ def backward(ctx, dy1):
         z = Layer_bk_none1.apply(input2)

         with self.assertRaises(ValueError):
-            with paddle.fluid.dygraph.guard():
-                z.sum().backward()
+            z.sum().backward()

         class Layer_bk_none2(PyLayer):
             @staticmethod
@@ -249,9 +270,9 @@ def backward(ctx, dy1):
         input1 = paddle.randn([2, 3]).astype("float64")
         input1.stop_gradient = False
         z = Layer_bk_none2.apply(input1, input1)
+
         with self.assertRaises(ValueError):
-            with paddle.fluid.dygraph.guard():
-                z.mean().backward()
+            z.mean().backward()

         class Layer_bk_one1(PyLayer):
             @staticmethod
@@ -265,9 +286,9 @@ def backward(ctx, dy):
         input1 = paddle.randn([2, 3]).astype("float64")
         input1.stop_gradient = False
         z = Layer_bk_one1.apply(input1)
+
         with self.assertRaises(ValueError):
-            with paddle.fluid.dygraph.guard():
-                z.mean().backward()
+            z.mean().backward()

         class Layer_bk_one2(PyLayer):
             @staticmethod
@@ -280,11 +301,11 @@ def backward(ctx, *args):

         input1 = paddle.randn([2, 3]).astype("float64")
         input1.stop_gradient = False
+
         y = Layer_bk_one2.apply(input1, input1)
         z = y[0] + y[1]
         with self.assertRaises(ValueError):
-            with paddle.fluid.dygraph.guard():
-                z.mean().backward()
+            z.mean().backward()

         class Layer_no_bk(PyLayer):
             @staticmethod
@@ -295,10 +316,9 @@ def forward(ctx, x):
         input1.stop_gradient = False
         z = Layer_no_bk.apply(input1)

-        with self.assertRaises(NotImplementedError):
-            with paddle.fluid.dygraph.guard():
-                z = z[0] + z[1]
-                z.mean().backward()
+        with self.assertRaises(OSError):
+            z = z[0] + z[1]
+            z.mean().backward()

         class Layer_bk_match(PyLayer):
             @staticmethod
@@ -313,9 +333,8 @@ def backward(ctx, dy1, dy2):
         input1.stop_gradient = False
         z = Layer_bk_match.apply(input1)
         with self.assertRaises(ValueError):
-            with paddle.fluid.dygraph.guard():
-                z = z[0] + z[1]
-                z.mean().backward()
+            z = z[0] + z[1]
+            z.mean().backward()

     def test_pylayer_bk_return_none(self):
         class Layer_bk_none1(PyLayer):
@@ -334,8 +353,7 @@ def backward(ctx, dy):
         z = Layer_bk_none1.apply(input1, input2)

         with self.assertRaises(ValueError):
-            with paddle.fluid.dygraph.guard():
-                z.mean().backward()
+            z.mean().backward()

         class Layer_bk_none2(PyLayer):
             @staticmethod
@@ -353,8 +371,7 @@ def backward(ctx, *args):
         z = Layer_bk_none2.apply(input1, input2)
         z = z[0] + z[1]
         with self.assertRaises(ValueError):
-            with paddle.fluid.dygraph.guard():
-                z.mean().backward()
+            z.mean().backward()

     def test_pylayer_inplace(self):
         class cus_tanh(PyLayer):
