@@ -173,13 +173,9 @@ class CrossMapNormalFunc : public FunctionBase {
173
173
}
174
174
175
175
void calc (const BufferArgs& inputs, const BufferArgs& outputs) override {
176
- CHECK_EQ ((size_t )numInputs_, inputs.size ());
177
- CHECK_EQ ((size_t )numOutputs_, outputs.size ());
178
-
179
- CHECK_EQ (inputs[0 ].shape ().ndims (), (size_t )4 );
180
- CHECK (inputs[0 ].shape () == outputs[0 ].shape ());
181
- CHECK (inputs[0 ].shape () == outputs[1 ].shape ());
182
-
176
+ check (inputs, outputs);
177
+ // ArgType check still on here,
178
+ // not sure whether it is better to put inside the check.
183
179
CHECK_EQ (outputs[0 ].getArgType (), ASSIGN_TO);
184
180
CHECK_EQ (outputs[1 ].getArgType (), ASSIGN_TO);
185
181
size_t batchSize = inputs[0 ].shape ()[0 ];
@@ -199,6 +195,15 @@ class CrossMapNormalFunc : public FunctionBase {
199
195
pow_);
200
196
}
201
197
198
+ void check (const BufferArgs& inputs, const BufferArgs& outputs) override {
199
+ CHECK_EQ ((size_t )numInputs_, inputs.size ());
200
+ CHECK_EQ ((size_t )numOutputs_, outputs.size ());
201
+
202
+ CHECK_EQ (inputs[0 ].shape ().ndims (), (size_t )4 );
203
+ CHECK (inputs[0 ].shape () == outputs[0 ].shape ());
204
+ CHECK (inputs[0 ].shape () == outputs[1 ].shape ());
205
+ }
206
+
202
207
// Only need the shape of the input, can calculate the
203
208
// floating-point operation.
204
209
size_t ops (const BufferArgs& inputs, const BufferArgs& outputs) override {
@@ -211,6 +216,8 @@ class CrossMapNormalFunc : public FunctionBase {
211
216
// number of floating-point operations
212
217
// an approximate value
213
218
size_t ops = batchSize * maps * ((rows * columns) * size_);
219
+
220
+ return ops;
214
221
}
215
222
216
223
private:
0 commit comments