
Commit 7f0ad62

Merge pull request #1149 from reyoung/feature/ErrorHandlingInPaddle
Feature/error handling in paddle
2 parents 3fff0af + 843fb2e commit 7f0ad62

12 files changed: 340 additions, 58 deletions


paddle/gserver/activations/ActivationFunction.cpp

Lines changed: 100 additions & 29 deletions
@@ -69,8 +69,14 @@ static ClassRegistrar<ActivationFunction> gActivationRegistrar;
 class IdentityActivation : public ActivationFunction {
 public:
   static const std::string name;
-  void forward(Argument& act) { (void)act; }
-  void backward(Argument& act) { (void)act; }
+  Error __must_check forward(Argument& act) {
+    (void)act;
+    return Error();
+  }
+  Error __must_check backward(Argument& act) {
+    (void)act;
+    return Error();
+  }
   const std::string& getName() const { return name; }
 };
 const std::string IdentityActivation::name = "";
@@ -86,8 +92,14 @@ static InitFunction __reg_activation__identity([] {
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(sigmoid)
-void forward(Argument& act) { act.value->sigmoid(*act.value); }
-void backward(Argument& act) { act.grad->sigmoidDerivative(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->sigmoid(*act.value);
+  return Error();
+}
+Error __must_check backward(Argument& act) {
+  act.grad->sigmoidDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(sigmoid)

 /**
@@ -103,9 +115,12 @@ MatrixPtr sftMaxDot_;
 MatrixPtr one_;

 public:
-void forward(Argument& act) { act.value->softmax(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->softmax(*act.value);
+  return Error();
+}

-void backward(Argument& act) {
+Error __must_check backward(Argument& act) {
   MatrixPtr outputV = act.value;
   MatrixPtr outputG = act.grad;

@@ -137,6 +152,7 @@ void backward(Argument& act) {

     act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
   }
+  return Error();
 }
 END_DEFINE_ACTIVATION(softmax)

@@ -151,8 +167,11 @@ ACTIVATION_CLASS_NAME(softmax) softmax_;
 Argument argument_;

 public:
-void forward(Argument& act) {
-  CHECK_EQ(act.value->getWidth(), 1UL);
+Error __must_check forward(Argument& act) {
+  if (act.value->getWidth() != 1UL) {
+    return Error(
+        "Input width for each timestep of sequence softmax should be 1");
+  }

   if (!argument_.value) {
     argument_.value = Matrix::create(nullptr,
@@ -169,10 +188,14 @@ void forward(Argument& act) {

   auto starts = act.sequenceStartPositions->getVector(useGpu(act.deviceId));
   act.value->sequenceSoftmax(*act.value, *starts);
+  return Error();
 }

-void backward(Argument& act) {
-  CHECK_EQ(act.grad->getWidth(), 1UL);
+Error __must_check backward(Argument& act) {
+  if (act.value->getWidth() != 1UL) {
+    return Error(
+        "Input width for each timestep of sequence softmax should be 1");
+  }

   size_t numSequences = act.getNumSequences();
   const int* starts = act.sequenceStartPositions->getData(false);
@@ -184,8 +207,10 @@ void backward(Argument& act) {
     argument_.value->setData(act.value->getData() + offset, 1UL, size);
     argument_.grad->setData(act.grad->getData() + offset, 1UL, size);

-    softmax_.backward(argument_);
+    Error status = softmax_.backward(argument_);
+    if (!status) return status;
   }
+  return Error();
 }
 END_DEFINE_ACTIVATION(sequence_softmax)

@@ -200,9 +225,15 @@ END_DEFINE_ACTIVATION(sequence_softmax)
  * 0 otherwise.
  */
 BEGIN_DEFINE_ACTIVATION(relu)
-void forward(Argument& act) { act.value->relu(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->relu(*act.value);
+  return Error();
+}

-void backward(Argument& act) { act.grad->reluDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->reluDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(relu)

 /**
@@ -219,9 +250,15 @@ END_DEFINE_ACTIVATION(relu)
  * TODO(yuyang18): Remove magic number 24 or make it configuable.
  */
 BEGIN_DEFINE_ACTIVATION(brelu)
-void forward(Argument& act) { act.value->brelu(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->brelu(*act.value);
+  return Error();
+}

-void backward(Argument& act) { act.grad->breluDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->breluDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(brelu)

 /**
@@ -231,9 +268,15 @@ END_DEFINE_ACTIVATION(brelu)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(tanh)
-void forward(Argument& act) { act.value->tanh(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->tanh(*act.value);
+  return Error();
+}

-void backward(Argument& act) { act.grad->tanhDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->tanhDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(tanh)

 /**
@@ -248,10 +291,14 @@ real a, b;

 public:
 ACTIVATION_CLASS_NAME(stanh)() : a(1.7159), b(2. / 3.) {}
-void forward(Argument& act) { act.value->scaledTanh(*act.value, a, b); }
+Error __must_check forward(Argument& act) {
+  act.value->scaledTanh(*act.value, a, b);
+  return Error();
+}

-void backward(Argument& act) {
+Error __must_check backward(Argument& act) {
   act.grad->scaledTanhDerivative(*act.value, a, b);
+  return Error();
 }
 END_DEFINE_ACTIVATION(stanh)

@@ -262,9 +309,15 @@ END_DEFINE_ACTIVATION(stanh)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(softrelu)
-void forward(Argument& act) { act.value->softrelu(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->softrelu(*act.value);
+  return Error();
+}

-void backward(Argument& act) { act.grad->softreluDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->softreluDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(softrelu)

 /**
@@ -280,7 +333,7 @@ END_DEFINE_ACTIVATION(softrelu)
  * 0 if z=0
  */
 BEGIN_DEFINE_ACTIVATION(abs)
-void forward(Argument& act) {
+Error __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -290,9 +343,13 @@ void forward(Argument& act) {

   act.in->copyFrom(*act.value);
   act.value->abs2(*act.value);
+  return Error();
 }

-void backward(Argument& act) { act.grad->absDerivative(*act.in); }
+Error __must_check backward(Argument& act) {
+  act.grad->absDerivative(*act.in);
+  return Error();
+}
 END_DEFINE_ACTIVATION(abs)

 /**
@@ -302,7 +359,7 @@ END_DEFINE_ACTIVATION(abs)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(square)
-void forward(Argument& act) {
+Error __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -312,9 +369,13 @@ void forward(Argument& act) {

   act.in->copyFrom(*act.value);
   act.value->square2(*act.value);
+  return Error();
 }

-void backward(Argument& act) { act.grad->squareDerivative(*act.in); }
+Error __must_check backward(Argument& act) {
+  act.grad->squareDerivative(*act.in);
+  return Error();
+}
 END_DEFINE_ACTIVATION(square)

 /**
@@ -324,9 +385,15 @@ END_DEFINE_ACTIVATION(square)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(exponential)
-void forward(Argument& act) { act.value->exp2(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->exp2(*act.value);
+  return Error();
+}

-void backward(Argument& act) { act.grad->expDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->expDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(exponential)

 /**
@@ -336,7 +403,7 @@ END_DEFINE_ACTIVATION(exponential)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(log)
-void forward(Argument& act) {
+Error __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -346,9 +413,13 @@ void forward(Argument& act) {

   act.in->copyFrom(*act.value);
   act.value->log2(*act.value);
+  return Error();
 }

-void backward(Argument& act) { act.grad->dotDiv(*act.grad, *act.in); }
+Error __must_check backward(Argument& act) {
+  act.grad->dotDiv(*act.grad, *act.in);
+  return Error();
+}
 END_DEFINE_ACTIVATION(log)

 ActivationFunction* ActivationFunction::create(const std::string& type) {
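
The pattern repeated throughout this file replaces in-activation CHECK_* aborts with a returned Error: a default-constructed Error() signals success, an Error constructed with a message signals failure (as in the sequence_softmax width check above), and callers either propagate it or check() it. Below is a minimal, self-contained sketch of that callee-side contract; it uses a simplified stand-in for paddle/utils/Error.h (not the real class) and a hypothetical sequenceSoftmaxForward function, purely for illustration.

#include <cstdio>
#include <memory>
#include <string>

// Simplified stand-in for paddle::Error (the real one lives in
// paddle/utils/Error.h and may differ in detail).
class Error {
public:
  Error() = default;  // default construction means "no error"
  explicit Error(const char* msg) : msg_(std::make_shared<std::string>(msg)) {}
  operator bool() const { return msg_ == nullptr; }  // true when OK
  const char* msg() const { return msg_ ? msg_->c_str() : ""; }

private:
  std::shared_ptr<std::string> msg_;
};

// Hypothetical activation-style function mirroring the diff: validate the
// input shape, return a message on failure, return Error() on success.
Error sequenceSoftmaxForward(size_t width) {
  if (width != 1UL) {
    return Error("Input width for each timestep of sequence softmax should be 1");
  }
  // ... compute the activation in place ...
  return Error();
}

int main() {
  Error ok = sequenceSoftmaxForward(1);
  Error bad = sequenceSoftmaxForward(8);
  std::printf("ok=%d bad=%d msg=%s\n", static_cast<int>(bool(ok)),
              static_cast<int>(bool(bad)), bad.msg());
  return 0;
}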

paddle/gserver/activations/ActivationFunction.h

Lines changed: 3 additions & 2 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include <string>
 #include <vector>
+#include "paddle/utils/Error.h"

 namespace paddle {

@@ -48,7 +49,7 @@ class ActivationFunction {
    *
    * Usually, act is Layer::output_
    */
-  virtual void forward(Argument& act) = 0;
+  virtual Error __must_check forward(Argument& act) = 0;

   /**
    * @brief Backward propagaion
@@ -57,7 +58,7 @@ class ActivationFunction {
    * - Before calling backward(), act.grad = dE / dy, where E is the error/cost
    * - After backward() returns, act.grad = dE / dx = (dE/dy) * (dy/dx)
    */
-  virtual void backward(Argument& act) = 0;
+  virtual Error __must_check backward(Argument& act) = 0;

   virtual const std::string& getName() const = 0;
 };
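
The __must_check annotation on the new virtuals is what makes ignoring the returned Error a compile-time diagnostic rather than a silent bug. This diff does not show how Paddle defines that macro; the sketch below assumes it wraps GCC/Clang's warn_unused_result attribute, with a locally named MUST_CHECK macro standing in.

// MUST_CHECK is a local stand-in for Paddle's __must_check macro, assumed
// here to expand to __attribute__((warn_unused_result)) on GCC/Clang.
#if defined(__GNUC__) || defined(__clang__)
#define MUST_CHECK __attribute__((warn_unused_result))
#else
#define MUST_CHECK
#endif

struct Error {
  bool ok = true;
};

MUST_CHECK Error forward() { return Error{}; }

int main() {
  forward();                 // -Wunused-result: return value ignored
  Error status = forward();  // fine: the result is consumed
  (void)status;
  return 0;
}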

paddle/gserver/layers/Layer.cpp

Lines changed: 5 additions & 2 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/utils/Util.h"

 #include "paddle/math/SparseMatrix.h"
+#include "paddle/utils/Error.h"
 #include "paddle/utils/Logging.h"

 #include "AddtoLayer.h"
@@ -334,7 +335,8 @@ void Layer::showOutputStats() {

 void Layer::forwardActivation() {
   /* activation */
-  activation_->forward(output_);
+  auto status = activation_->forward(output_);
+  status.check();

   /* dropout */
   if (config_.drop_rate() > 0) {
@@ -372,7 +374,8 @@ void Layer::backwardActivation() {
     oGrad->dotMul(*oGrad, *dropOutMask_);
   }

-  activation_->backward(output_);
+  auto status = activation_->backward(output_);
+  status.check();
 }

 void Layer::forwardDropOut() {
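
At the Layer level the returned status is consumed with check(), which keeps the old fail-fast behavior: the process still stops on a bad activation, but the failure is produced by the activation as data and only turned into a fatal error at the call site. The real check() lives in paddle/utils/Error.h and is not shown in this diff; the sketch below is a hedged approximation of how such a caller-side check might behave.

#include <cstdio>
#include <cstdlib>
#include <string>

// Simplified Error whose check() mimics the call sites above: do nothing on
// success, print the message and abort on failure. Not Paddle's actual class.
class Error {
public:
  Error() = default;
  explicit Error(std::string msg) : ok_(false), msg_(std::move(msg)) {}
  operator bool() const { return ok_; }
  void check() const {
    if (!ok_) {
      std::fprintf(stderr, "Error: %s\n", msg_.c_str());
      std::abort();
    }
  }

private:
  bool ok_ = true;
  std::string msg_;
};

// Hypothetical activation call used only to drive the example.
Error forwardActivation(bool valid) {
  return valid ? Error() : Error("bad activation input");
}

int main() {
  auto status = forwardActivation(true);
  status.check();                    // passes silently
  forwardActivation(false).check();  // prints the message and aborts
  return 0;
}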

paddle/gserver/layers/MDLstmLayer.cpp

Lines changed: 15 additions & 10 deletions
@@ -506,9 +506,12 @@ void MDLstmLayer::forwardGate2OutputSequence(int start,
           *frameState_[start + preOffsetV[i]].value, *checkFgOneDim, 1.0, 1.0);
     }
   }
-  activationGate_->forward(frameInputGate_[idxCurr]);
-  activationGate_->forward(frameForgetGate_[idxCurr]);
-  activation_->forward(frameInputNode_[idxCurr]);
+  auto status = activationGate_->forward(frameInputGate_[idxCurr]);
+  status.check();
+  status = activationGate_->forward(frameForgetGate_[idxCurr]);
+  status.check();
+  status = activation_->forward(frameInputNode_[idxCurr]);
+  status.check();

   frameState_[idxCurr].value->zeroMem();
   for (int i = 0; i < numDims_; i++) {
@@ -530,10 +533,12 @@ void MDLstmLayer::forwardGate2OutputSequence(int start,

   frameOutputGate_[idxCurr].value->addDotMul(
       *frameState_[idxCurr].value, *checkOg_, 1.0, 1.0);
-  activationGate_->forward(frameOutputGate_[idxCurr]);
+  status = activationGate_->forward(frameOutputGate_[idxCurr]);
+  status.check();

   framePreOutput_[idxCurr].value->copyFrom(*(frameState_[idxCurr].value));
-  activationState_->forward(framePreOutput_[idxCurr]);
+  status = activationState_->forward(framePreOutput_[idxCurr]);
+  status.check();

   frameOutput_[idxCurr].value->dotMul(*framePreOutput_[idxCurr].value,
                                       *frameOutputGate_[idxCurr].value);
@@ -640,12 +645,12 @@ void MDLstmLayer::backwardGate2OutputSequence(int start,

   framePreOutput_[idxCurr].grad->dotMul(*frameOutput_[idxCurr].grad,
                                         *frameOutputGate_[idxCurr].value);
-  activationState_->backward(framePreOutput_[idxCurr]);
+  activationState_->backward(framePreOutput_[idxCurr]).check();
   frameState_[idxCurr].grad->copyFrom(*(framePreOutput_[idxCurr].grad));

   frameOutputGate_[idxCurr].grad->dotMul(*frameOutput_[idxCurr].grad,
                                          *framePreOutput_[idxCurr].value);
-  activationGate_->backward(frameOutputGate_[idxCurr]);
+  activationGate_->backward(frameOutputGate_[idxCurr]).check();

   frameState_[idxCurr].grad->addDotMul(
       *frameOutputGate_[idxCurr].grad, *checkOg_, 1.0, 1.0);
@@ -702,9 +707,9 @@ void MDLstmLayer::backwardGate2OutputSequence(int start,
     }
   }

-  activationGate_->backward(frameInputGate_[idxCurr]);
-  activationGate_->backward(frameForgetGate_[idxCurr]);
-  activation_->backward(frameInputNode_[idxCurr]);
+  activationGate_->backward(frameInputGate_[idxCurr]).check();
+  activationGate_->backward(frameForgetGate_[idxCurr]).check();
+  activation_->backward(frameInputNode_[idxCurr]).check();

   if (bias_->getWGrad()) {
     for (int i = 0; i < numDims_; i++) {
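
MDLstmLayer's forward and backward passes return void, so they cannot propagate an Error upward the way sequence_softmax does; every activation call is therefore checked on the spot, either through a reused status variable or by chaining .check() directly on the call. A small sketch of the two caller idioms this commit uses, again with a simplified Error rather than Paddle's:

#include <cstdio>
#include <cstdlib>

// Minimal Error: true means OK; check() aborts on failure.
struct Error {
  bool ok = true;
  operator bool() const { return ok; }
  void check() const {
    if (!ok) {
      std::fprintf(stderr, "activation failed\n");
      std::abort();
    }
  }
};

Error gateForward() { return Error{}; }  // hypothetical activation call

// Idiom 1: the enclosing function itself returns Error, so failures propagate.
Error sequenceBackward() {
  Error status = gateForward();
  if (!status) return status;
  return Error{};
}

// Idiom 2: the enclosing function returns void (as in MDLstmLayer), so each
// call is checked immediately, via a reused variable or a chained .check().
void lstmForward() {
  auto status = gateForward();
  status.check();
  gateForward().check();
}

int main() {
  sequenceBackward().check();
  lstmForward();
  return 0;
}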
