@@ -192,6 +192,65 @@ void SumOfSquaresCostLayer::backwardImp(Matrix& output,
192
192
outputG.sumOfSquaresBp (output, *label.value );
193
193
}
194
194
195
+ //
196
+ // class SmoothL1CostLayer
197
+ //
198
+
199
// NOTE(review): REGISTER_LAYER is defined outside this view; presumably it
// associates the layer type name `smooth_l1` with SmoothL1CostLayer so the
// layer can be created by name -- confirm against the macro definition.
+ REGISTER_LAYER (smooth_l1, SmoothL1CostLayer);
200
+
201
+ bool SmoothL1CostLayer::init (const LayerMap& layerMap,
202
+ const ParameterMap& parameterMap) {
203
+ return CostLayer::init (layerMap, parameterMap);
204
+ }
205
+
206
+ void SmoothL1CostLayer::forwardImp (Matrix& output,
207
+ Argument& label,
208
+ Matrix& target) {
209
+ MatrixPtr targetCpu, labelCpu, outputCpu;
210
+ if (useGpu_) {
211
+ Matrix::resizeOrCreate (
212
+ targetCpu, target.getHeight (), target.getWidth (), false , false );
213
+ Matrix::resizeOrCreate (
214
+ outputCpu, output.getHeight (), output.getWidth (), false , false );
215
+ Matrix::resizeOrCreate (labelCpu,
216
+ label.value ->getHeight (),
217
+ label.value ->getWidth (),
218
+ false ,
219
+ false );
220
+ targetCpu->copyFrom (target);
221
+ outputCpu->copyFrom (output);
222
+ labelCpu->copyFrom (*label.value );
223
+ targetCpu->smoothL1 (*outputCpu, *(labelCpu));
224
+ target.copyFrom (*targetCpu);
225
+ } else {
226
+ target.smoothL1 (output, *label.value );
227
+ }
228
+ }
229
+
230
+ void SmoothL1CostLayer::backwardImp (Matrix& output,
231
+ Argument& label,
232
+ Matrix& outputG) {
233
+ MatrixPtr outputGCpu, labelCpu, outputCpu;
234
+ if (useGpu_) {
235
+ Matrix::resizeOrCreate (
236
+ outputGCpu, outputG.getHeight (), outputG.getWidth (), false , false );
237
+ Matrix::resizeOrCreate (
238
+ outputCpu, output.getHeight (), output.getWidth (), false , false );
239
+ Matrix::resizeOrCreate (labelCpu,
240
+ label.value ->getHeight (),
241
+ label.value ->getWidth (),
242
+ false ,
243
+ false );
244
+ outputGCpu->copyFrom (outputG);
245
+ outputCpu->copyFrom (output);
246
+ labelCpu->copyFrom (*label.value );
247
+ outputGCpu->smoothL1Bp (*outputCpu, *labelCpu);
248
+ outputG.copyFrom (*outputGCpu);
249
+ } else {
250
+ outputG.smoothL1Bp (output, *label.value );
251
+ }
252
+ }
253
+
195
254
//
196
255
// class RankingCost
197
256
//
0 commit comments