@@ -182,23 +182,37 @@ class CrossMapNormalFunc : public FunctionBase {
182
182
183
183
CHECK_EQ (outputs[0 ].getArgType (), ASSIGN_TO);
184
184
CHECK_EQ (outputs[1 ].getArgType (), ASSIGN_TO);
185
- size_t samples = inputs[0 ].shape ()[0 ];
186
- size_t channels = inputs[0 ].shape ()[1 ];
187
- size_t height = inputs[0 ].shape ()[2 ];
188
- size_t width = inputs[0 ].shape ()[3 ];
185
+ size_t batchSize = inputs[0 ].shape ()[0 ];
186
+ size_t maps = inputs[0 ].shape ()[1 ];
187
+ size_t rows = inputs[0 ].shape ()[2 ];
188
+ size_t columns = inputs[0 ].shape ()[3 ];
189
189
190
190
CrossMapNormal<Device>(outputs[0 ].data <real>(),
191
191
outputs[1 ].data <real>(),
192
192
inputs[0 ].data <real>(),
193
- samples ,
194
- channels ,
195
- height ,
196
- width ,
193
+ batchSize ,
194
+ maps ,
195
+ rows ,
196
+ columns ,
197
197
size_,
198
198
scale_,
199
199
pow_);
200
200
}
201
201
202
+ // Only need the shape of the input, can calculate the
203
+ // floating-point operation.
204
+ size_t ops (const BufferArgs& inputs, const BufferArgs& outputs) override {
205
+ CHECK_EQ ((size_t )numInputs_, inputs.size ());
206
+ size_t batchSize = inputs[0 ].shape ()[0 ];
207
+ size_t maps = inputs[0 ].shape ()[1 ];
208
+ size_t rows = inputs[0 ].shape ()[2 ];
209
+ size_t columns = inputs[0 ].shape ()[3 ];
210
+
211
+ // number of floating-point operations
212
+ // an approximate value
213
+ size_t ops = batchSize * maps * ((rows * columns) * size_);
214
+ }
215
+
202
216
private:
203
217
size_t size_;
204
218
real scale_;
0 commit comments