@@ -69,8 +69,14 @@ static ClassRegistrar<ActivationFunction> gActivationRegistrar;
 class IdentityActivation : public ActivationFunction {
 public:
   static const std::string name;
-  void forward(Argument& act) { (void)act; }
-  void backward(Argument& act) { (void)act; }
+  Error __must_check forward(Argument& act) {
+    (void)act;
+    return Error();
+  }
+  Error __must_check backward(Argument& act) {
+    (void)act;
+    return Error();
+  }
   const std::string& getName() const { return name; }
 };
 const std::string IdentityActivation::name = "";
@@ -86,8 +92,14 @@ static InitFunction __reg_activation__identity([] {
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(sigmoid)
-void forward(Argument& act) { act.value->sigmoid(*act.value); }
-void backward(Argument& act) { act.grad->sigmoidDerivative(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->sigmoid(*act.value);
+  return Error();
+}
+Error __must_check backward(Argument& act) {
+  act.grad->sigmoidDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(sigmoid)
 
 /**
@@ -103,9 +115,12 @@ MatrixPtr sftMaxDot_;
 MatrixPtr one_;
 
 public:
-void forward(Argument& act) { act.value->softmax(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->softmax(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) {
+Error __must_check backward(Argument& act) {
   MatrixPtr outputV = act.value;
   MatrixPtr outputG = act.grad;
 
@@ -137,6 +152,7 @@ void backward(Argument& act) {
 
     act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
   }
+  return Error();
 }
 END_DEFINE_ACTIVATION(softmax)
 
@@ -151,8 +167,11 @@ ACTIVATION_CLASS_NAME(softmax) softmax_;
 Argument argument_;
 
 public:
-void forward(Argument& act) {
-  CHECK_EQ(act.value->getWidth(), 1UL);
+Error __must_check forward(Argument& act) {
+  if (act.value->getWidth() != 1UL) {
+    return Error(
+        "Input width for each timestep of sequence softmax should be 1");
+  }
 
   if (!argument_.value) {
     argument_.value = Matrix::create(nullptr,
@@ -169,10 +188,14 @@ void forward(Argument& act) {
 
   auto starts = act.sequenceStartPositions->getVector(useGpu(act.deviceId));
   act.value->sequenceSoftmax(*act.value, *starts);
+  return Error();
 }
 
-void backward(Argument& act) {
-  CHECK_EQ(act.grad->getWidth(), 1UL);
+Error __must_check backward(Argument& act) {
+  if (act.value->getWidth() != 1UL) {
+    return Error(
+        "Input width for each timestep of sequence softmax should be 1");
+  }
 
   size_t numSequences = act.getNumSequences();
   const int* starts = act.sequenceStartPositions->getData(false);
@@ -184,8 +207,10 @@ void backward(Argument& act) {
     argument_.value->setData(act.value->getData() + offset, 1UL, size);
     argument_.grad->setData(act.grad->getData() + offset, 1UL, size);
 
-    softmax_.backward(argument_);
+    Error status = softmax_.backward(argument_);
+    if (!status) return status;
   }
+  return Error();
 }
 END_DEFINE_ACTIVATION(sequence_softmax)
 
@@ -200,9 +225,15 @@ END_DEFINE_ACTIVATION(sequence_softmax)
  * 0 otherwise.
  */
 BEGIN_DEFINE_ACTIVATION(relu)
-void forward(Argument& act) { act.value->relu(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->relu(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->reluDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->reluDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(relu)
 
 /**
@@ -219,9 +250,15 @@ END_DEFINE_ACTIVATION(relu)
  * TODO(yuyang18): Remove magic number 24 or make it configuable.
  */
 BEGIN_DEFINE_ACTIVATION(brelu)
-void forward(Argument& act) { act.value->brelu(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->brelu(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->breluDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->breluDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(brelu)
 
 /**
@@ -231,9 +268,15 @@ END_DEFINE_ACTIVATION(brelu)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(tanh)
-void forward(Argument& act) { act.value->tanh(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->tanh(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->tanhDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->tanhDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(tanh)
 
 /**
@@ -248,10 +291,14 @@ real a, b;
 
 public:
 ACTIVATION_CLASS_NAME(stanh)() : a(1.7159), b(2. / 3.) {}
-void forward(Argument& act) { act.value->scaledTanh(*act.value, a, b); }
+Error __must_check forward(Argument& act) {
+  act.value->scaledTanh(*act.value, a, b);
+  return Error();
+}
 
-void backward(Argument& act) {
+Error __must_check backward(Argument& act) {
   act.grad->scaledTanhDerivative(*act.value, a, b);
+  return Error();
 }
 END_DEFINE_ACTIVATION(stanh)
 
@@ -262,9 +309,15 @@ END_DEFINE_ACTIVATION(stanh)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(softrelu)
-void forward(Argument& act) { act.value->softrelu(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->softrelu(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->softreluDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->softreluDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(softrelu)
 
 /**
@@ -280,7 +333,7 @@ END_DEFINE_ACTIVATION(softrelu)
  * 0 if z=0
  */
 BEGIN_DEFINE_ACTIVATION(abs)
-void forward(Argument& act) {
+Error __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -290,9 +343,13 @@ void forward(Argument& act) {
 
   act.in->copyFrom(*act.value);
   act.value->abs2(*act.value);
+  return Error();
 }
 
-void backward(Argument& act) { act.grad->absDerivative(*act.in); }
+Error __must_check backward(Argument& act) {
+  act.grad->absDerivative(*act.in);
+  return Error();
+}
 END_DEFINE_ACTIVATION(abs)
 
 /**
@@ -302,7 +359,7 @@ END_DEFINE_ACTIVATION(abs)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(square)
-void forward(Argument& act) {
+Error __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -312,9 +369,13 @@ void forward(Argument& act) {
 
   act.in->copyFrom(*act.value);
   act.value->square2(*act.value);
+  return Error();
 }
 
-void backward(Argument& act) { act.grad->squareDerivative(*act.in); }
+Error __must_check backward(Argument& act) {
+  act.grad->squareDerivative(*act.in);
+  return Error();
+}
 END_DEFINE_ACTIVATION(square)
 
 /**
@@ -324,9 +385,15 @@ END_DEFINE_ACTIVATION(square)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(exponential)
-void forward(Argument& act) { act.value->exp2(*act.value); }
+Error __must_check forward(Argument& act) {
+  act.value->exp2(*act.value);
+  return Error();
+}
 
-void backward(Argument& act) { act.grad->expDerivative(*act.value); }
+Error __must_check backward(Argument& act) {
+  act.grad->expDerivative(*act.value);
+  return Error();
+}
 END_DEFINE_ACTIVATION(exponential)
 
 /**
@@ -336,7 +403,7 @@ END_DEFINE_ACTIVATION(exponential)
  * \f]
  */
 BEGIN_DEFINE_ACTIVATION(log)
-void forward(Argument& act) {
+Error __must_check forward(Argument& act) {
   SetDevice device(act.deviceId);
   Matrix::resizeOrCreate(act.in,
                          act.value->getHeight(),
@@ -346,9 +413,13 @@ void forward(Argument& act) {
 
   act.in->copyFrom(*act.value);
   act.value->log2(*act.value);
+  return Error();
 }
 
-void backward(Argument& act) { act.grad->dotDiv(*act.grad, *act.in); }
+Error __must_check backward(Argument& act) {
+  act.grad->dotDiv(*act.grad, *act.in);
+  return Error();
+}
 END_DEFINE_ACTIVATION(log)
 
 ActivationFunction* ActivationFunction::create(const std::string& type) {
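
Note on usage: with forward() and backward() now returning Error (annotated __must_check), callers are expected to inspect the returned status instead of relying on CHECK aborts inside the activation. Below is a minimal sketch of that calling pattern, using only the Error semantics visible in this diff; the helper name applyActivation and its parameter names are illustrative, not part of this change.

// Illustrative caller, not from this patch: propagate the activation status
// upward, mirroring the `if (!status) return status;` pattern used in
// sequence_softmax above.
Error __must_check applyActivation(ActivationFunction& activation,
                                   Argument& output) {
  Error status = activation.forward(output);
  if (!status) return status;  // pass the error up instead of aborting
  return Error();              // default-constructed Error means success
}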