Skip to content

Commit c4437fa

Browse files
committed
Add FunctionBase::check()
1 parent 9896f15 commit c4437fa

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

paddle/function/CrossMapNormalOp.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,9 @@ class CrossMapNormalFunc : public FunctionBase {
173173
}
174174

175175
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
176-
CHECK_EQ((size_t)numInputs_, inputs.size());
177-
CHECK_EQ((size_t)numOutputs_, outputs.size());
178-
179-
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
180-
CHECK(inputs[0].shape() == outputs[0].shape());
181-
CHECK(inputs[0].shape() == outputs[1].shape());
182-
176+
check(inputs, outputs);
177+
// ArgType check still on here,
178+
// not sure whether it is better to put inside the check.
183179
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
184180
CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
185181
size_t batchSize = inputs[0].shape()[0];
@@ -199,6 +195,15 @@ class CrossMapNormalFunc : public FunctionBase {
199195
pow_);
200196
}
201197

198+
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
199+
CHECK_EQ((size_t)numInputs_, inputs.size());
200+
CHECK_EQ((size_t)numOutputs_, outputs.size());
201+
202+
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
203+
CHECK(inputs[0].shape() == outputs[0].shape());
204+
CHECK(inputs[0].shape() == outputs[1].shape());
205+
}
206+
202207
// Only need the shape of the input, can calculate the
203208
// floating-point operation.
204209
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
@@ -211,6 +216,8 @@ class CrossMapNormalFunc : public FunctionBase {
211216
// number of floating-point operations
212217
// an approximate value
213218
size_t ops = batchSize * maps * ((rows * columns) * size_);
219+
220+
return ops;
214221
}
215222

216223
private:

paddle/function/Function.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ class FunctionBase {
153153

154154
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
155155

156+
// This member function is used to check whether the BufferType and shape of
157+
// the inputs and outputs arguments of the Function are correct.
158+
// General calc function which will call this check to do arguments check.
159+
// Also before the call calc, the caller can also check their own arguments.
160+
virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {}
161+
156162
// Calculate the number of floating-point operations of this Function.
157163
// The inputs and outputs arguments do not need to contain the actual data,
158164
// only the shape.

0 commit comments

Comments
 (0)