@@ -196,8 +196,8 @@ class CrossMapNormalFunc : public FunctionBase {
196
196
}
197
197
198
198
void check (const BufferArgs& inputs, const BufferArgs& outputs) override {
199
- CHECK_EQ (( size_t ) numInputs_, inputs.size ());
200
- CHECK_EQ (( size_t ) numOutputs_, outputs.size ());
199
+ CHECK_EQ (numInputs_, inputs.size ());
200
+ CHECK_EQ (numOutputs_, outputs.size ());
201
201
202
202
CHECK_EQ (inputs[0 ].shape ().ndims (), (size_t )4 );
203
203
CHECK (inputs[0 ].shape () == outputs[0 ].shape ());
@@ -215,7 +215,7 @@ class CrossMapNormalFunc : public FunctionBase {
215
215
216
216
// number of floating-point operations
217
217
// an approximate value
218
- size_t ops = batchSize * maps * (( rows * columns) * size_);
218
+ size_t ops = batchSize * maps * rows * columns * ( size_ * 2 + 3 );
219
219
220
220
return ops;
221
221
}
@@ -273,15 +273,7 @@ class CrossMapNormalGradFunc : public FunctionBase {
273
273
}
274
274
275
275
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
276
- CHECK_EQ ((size_t )numInputs_, inputs.size ());
277
- CHECK_EQ ((size_t )numOutputs_, outputs.size ());
278
-
279
- CHECK_EQ (inputs[0 ].shape ().ndims (), (size_t )4 );
280
- CHECK (inputs[0 ].shape () == inputs[1 ].shape ());
281
- CHECK (inputs[0 ].shape () == inputs[2 ].shape ());
282
- CHECK (inputs[0 ].shape () == inputs[3 ].shape ());
283
- CHECK (inputs[0 ].shape () == outputs[0 ].shape ());
284
-
276
+ check (inputs, outputs);
285
277
if (outputs[0 ].getArgType () != ADD_TO) {
286
278
// Currently, some algorithm implementations are ASSIGN_TO mode,
287
279
// if need to support the ADD_TO calculation, need to clear the output.
@@ -290,25 +282,52 @@ class CrossMapNormalGradFunc : public FunctionBase {
290
282
tmp.zero ();
291
283
}
292
284
293
- size_t samples = inputs[0 ].shape ()[0 ];
294
- size_t channels = inputs[0 ].shape ()[1 ];
295
- size_t height = inputs[0 ].shape ()[2 ];
296
- size_t width = inputs[0 ].shape ()[3 ];
285
+ size_t batchSize = inputs[0 ].shape ()[0 ];
286
+ size_t maps = inputs[0 ].shape ()[1 ];
287
+ size_t rows = inputs[0 ].shape ()[2 ];
288
+ size_t columns = inputs[0 ].shape ()[3 ];
297
289
298
290
CrossMapNormalGrad<Device>(outputs[0 ].data <real>(),
299
291
inputs[0 ].data <real>(),
300
292
inputs[1 ].data <real>(),
301
293
inputs[2 ].data <real>(),
302
294
inputs[3 ].data <real>(),
303
- samples ,
304
- channels ,
305
- height ,
306
- width ,
295
+ batchSize ,
296
+ maps ,
297
+ rows ,
298
+ columns ,
307
299
size_,
308
300
scale_,
309
301
pow_);
310
302
}
311
303
304
+ void check (const BufferArgs& inputs, const BufferArgs& outputs) override {  // Validates argument counts and shapes; calc() calls this first.
305
+ CHECK_EQ (numInputs_, inputs.size ());  // number of inputs must equal numInputs_
306
+ CHECK_EQ (numOutputs_, outputs.size ());  // number of outputs must equal numOutputs_
307
+
308
+ CHECK_EQ (inputs[0 ].shape ().ndims (), (size_t )4 );  // inputs[0] must have exactly 4 dimensions
309
+ CHECK (inputs[0 ].shape () == inputs[1 ].shape ());  // inputs[1..3] must all share the shape of inputs[0]
310
+ CHECK (inputs[0 ].shape () == inputs[2 ].shape ());
311
+ CHECK (inputs[0 ].shape () == inputs[3 ].shape ());
312
+ CHECK (inputs[0 ].shape () == outputs[0 ].shape ());  // the output must have the same shape as inputs[0]
313
+ }
314
+
315
+ // Returns an approximate floating-point operation count for this function.
316
+ // Only the shape of inputs[0] is needed to compute the estimate.
317
+ size_t ops (const BufferArgs& inputs, const BufferArgs& outputs) override {  // outputs is not read here
318
+ CHECK_LT ((size_t )1 , inputs.size ());  // requires more than one input, although only inputs[0] is read below
319
+ size_t batchSize = inputs[0 ].shape ()[0 ];  // the four dimensions of inputs[0], in index order
320
+ size_t maps = inputs[0 ].shape ()[1 ];
321
+ size_t rows = inputs[0 ].shape ()[2 ];
322
+ size_t columns = inputs[0 ].shape ()[3 ];
323
+
324
+ // Approximate number of floating-point operations:
325
+ // (size_ * 4 + 2) operations per element of inputs[0].
326
+ size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2 );
327
+
328
+ return ops;
329
+ }
330
+
312
331
private:
313
332
size_t size_;
314
333
real scale_;
0 commit comments