@@ -162,38 +162,64 @@ template <DeviceType Device>
162
162
class CrossMapNormalFunc : public FunctionBase {
163
163
public:
164
164
void init (const FuncConfig& config) override {
165
+ // function arguments
165
166
size_ = config.get <size_t >(" size" );
166
167
scale_ = config.get <real>(" scale" );
167
168
pow_ = config.get <real>(" pow" );
169
+
170
+ // number of inputs and outputs
171
+ numInputs_ = 1 ;
172
+ numOutputs_ = 2 ;
168
173
}
169
174
170
175
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
171
- CHECK_EQ ((size_t )1 , inputs.size ());
172
- CHECK_EQ ((size_t )2 , outputs.size ());
173
-
174
- CHECK_EQ (inputs[0 ].shape ().ndims (), (size_t )4 );
175
- CHECK (inputs[0 ].shape () == outputs[0 ].shape ());
176
- CHECK (inputs[0 ].shape () == outputs[1 ].shape ());
177
-
176
+ check (inputs, outputs);
177
+ // ArgType check still on here,
178
+ // not sure whether it is better to put inside the check.
178
179
CHECK_EQ (outputs[0 ].getArgType (), ASSIGN_TO);
179
180
CHECK_EQ (outputs[1 ].getArgType (), ASSIGN_TO);
180
- size_t samples = inputs[0 ].shape ()[0 ];
181
- size_t channels = inputs[0 ].shape ()[1 ];
182
- size_t height = inputs[0 ].shape ()[2 ];
183
- size_t width = inputs[0 ].shape ()[3 ];
181
+ size_t batchSize = inputs[0 ].shape ()[0 ];
182
+ size_t maps = inputs[0 ].shape ()[1 ];
183
+ size_t rows = inputs[0 ].shape ()[2 ];
184
+ size_t columns = inputs[0 ].shape ()[3 ];
184
185
185
186
CrossMapNormal<Device>(outputs[0 ].data <real>(),
186
187
outputs[1 ].data <real>(),
187
188
inputs[0 ].data <real>(),
188
- samples ,
189
- channels ,
190
- height ,
191
- width ,
189
+ batchSize ,
190
+ maps ,
191
+ rows ,
192
+ columns ,
192
193
size_,
193
194
scale_,
194
195
pow_);
195
196
}
196
197
198
+ void check (const BufferArgs& inputs, const BufferArgs& outputs) override {
199
+ CHECK_EQ (numInputs_, inputs.size ());
200
+ CHECK_EQ (numOutputs_, outputs.size ());
201
+
202
+ CHECK_EQ (inputs[0 ].shape ().ndims (), (size_t )4 );
203
+ CHECK (inputs[0 ].shape () == outputs[0 ].shape ());
204
+ CHECK (inputs[0 ].shape () == outputs[1 ].shape ());
205
+ }
206
+
207
+ // Only need the shape of the input, can calculate the
208
+ // floating-point operation.
209
+ size_t ops (const BufferArgs& inputs, const BufferArgs& outputs) override {
210
+ CHECK_EQ ((size_t )numInputs_, inputs.size ());
211
+ size_t batchSize = inputs[0 ].shape ()[0 ];
212
+ size_t maps = inputs[0 ].shape ()[1 ];
213
+ size_t rows = inputs[0 ].shape ()[2 ];
214
+ size_t columns = inputs[0 ].shape ()[3 ];
215
+
216
+ // number of floating-point operations
217
+ // an approximate value
218
+ size_t ops = batchSize * maps * rows * columns * (size_ * 2 + 3 );
219
+
220
+ return ops;
221
+ }
222
+
197
223
private:
198
224
size_t size_;
199
225
real scale_;
@@ -236,21 +262,18 @@ template <DeviceType Device>
236
262
class CrossMapNormalGradFunc : public FunctionBase {
237
263
public:
238
264
void init (const FuncConfig& config) override {
265
+ // function arguments
239
266
size_ = config.get <size_t >(" size" );
240
267
scale_ = config.get <real>(" scale" );
241
268
pow_ = config.get <real>(" pow" );
269
+
270
+ // number of inputs and outputs
271
+ numInputs_ = 4 ;
272
+ numOutputs_ = 1 ;
242
273
}
243
274
244
275
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
245
- CHECK_EQ ((size_t )4 , inputs.size ());
246
- CHECK_EQ ((size_t )1 , outputs.size ());
247
-
248
- CHECK_EQ (inputs[0 ].shape ().ndims (), (size_t )4 );
249
- CHECK (inputs[0 ].shape () == inputs[1 ].shape ());
250
- CHECK (inputs[0 ].shape () == inputs[2 ].shape ());
251
- CHECK (inputs[0 ].shape () == inputs[3 ].shape ());
252
- CHECK (inputs[0 ].shape () == outputs[0 ].shape ());
253
-
276
+ check (inputs, outputs);
254
277
if (outputs[0 ].getArgType () != ADD_TO) {
255
278
// Currently, some algorithm implementations are ASSIGN_TO mode,
256
279
// if need to support the ADD_TO calculation, need to clear the output.
@@ -259,25 +282,52 @@ class CrossMapNormalGradFunc : public FunctionBase {
259
282
tmp.zero ();
260
283
}
261
284
262
- size_t samples = inputs[0 ].shape ()[0 ];
263
- size_t channels = inputs[0 ].shape ()[1 ];
264
- size_t height = inputs[0 ].shape ()[2 ];
265
- size_t width = inputs[0 ].shape ()[3 ];
285
+ size_t batchSize = inputs[0 ].shape ()[0 ];
286
+ size_t maps = inputs[0 ].shape ()[1 ];
287
+ size_t rows = inputs[0 ].shape ()[2 ];
288
+ size_t columns = inputs[0 ].shape ()[3 ];
266
289
267
290
CrossMapNormalGrad<Device>(outputs[0 ].data <real>(),
268
291
inputs[0 ].data <real>(),
269
292
inputs[1 ].data <real>(),
270
293
inputs[2 ].data <real>(),
271
294
inputs[3 ].data <real>(),
272
- samples ,
273
- channels ,
274
- height ,
275
- width ,
295
+ batchSize ,
296
+ maps ,
297
+ rows ,
298
+ columns ,
276
299
size_,
277
300
scale_,
278
301
pow_);
279
302
}
280
303
304
+ void check (const BufferArgs& inputs, const BufferArgs& outputs) override {
305
+ CHECK_EQ (numInputs_, inputs.size ());
306
+ CHECK_EQ (numOutputs_, outputs.size ());
307
+
308
+ CHECK_EQ (inputs[0 ].shape ().ndims (), (size_t )4 );
309
+ CHECK (inputs[0 ].shape () == inputs[1 ].shape ());
310
+ CHECK (inputs[0 ].shape () == inputs[2 ].shape ());
311
+ CHECK (inputs[0 ].shape () == inputs[3 ].shape ());
312
+ CHECK (inputs[0 ].shape () == outputs[0 ].shape ());
313
+ }
314
+
315
+ // Only need the shape of one input, can calculate the
316
+ // floating-point operation.
317
+ size_t ops (const BufferArgs& inputs, const BufferArgs& outputs) override {
318
+ CHECK_LT ((size_t )1 , inputs.size ());
319
+ size_t batchSize = inputs[0 ].shape ()[0 ];
320
+ size_t maps = inputs[0 ].shape ()[1 ];
321
+ size_t rows = inputs[0 ].shape ()[2 ];
322
+ size_t columns = inputs[0 ].shape ()[3 ];
323
+
324
+ // number of floating-point operations
325
+ // an approximate value
326
+ size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2 );
327
+
328
+ return ops;
329
+ }
330
+
281
331
private:
282
332
size_t size_;
283
333
real scale_;
0 commit comments