Skip to content

Commit a9228e2

Browse files
committed
Fix CrossMapNormalGradFunc
1 parent c4437fa commit a9228e2

File tree

2 files changed

+43
-21
lines changed

2 files changed

+43
-21
lines changed

paddle/function/CrossMapNormalOp.cpp

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ class CrossMapNormalFunc : public FunctionBase {
196196
}
197197

198198
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
199-
CHECK_EQ((size_t)numInputs_, inputs.size());
200-
CHECK_EQ((size_t)numOutputs_, outputs.size());
199+
CHECK_EQ(numInputs_, inputs.size());
200+
CHECK_EQ(numOutputs_, outputs.size());
201201

202202
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
203203
CHECK(inputs[0].shape() == outputs[0].shape());
@@ -215,7 +215,7 @@ class CrossMapNormalFunc : public FunctionBase {
215215

216216
// number of floating-point operations
217217
// an approximate value
218-
size_t ops = batchSize * maps * ((rows * columns) * size_);
218+
size_t ops = batchSize * maps * rows * columns * (size_ * 2 + 3);
219219

220220
return ops;
221221
}
@@ -273,15 +273,7 @@ class CrossMapNormalGradFunc : public FunctionBase {
273273
}
274274

275275
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
276-
CHECK_EQ((size_t)numInputs_, inputs.size());
277-
CHECK_EQ((size_t)numOutputs_, outputs.size());
278-
279-
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
280-
CHECK(inputs[0].shape() == inputs[1].shape());
281-
CHECK(inputs[0].shape() == inputs[2].shape());
282-
CHECK(inputs[0].shape() == inputs[3].shape());
283-
CHECK(inputs[0].shape() == outputs[0].shape());
284-
276+
check(inputs, outputs);
285277
if (outputs[0].getArgType() != ADD_TO) {
286278
// Currently, some algorithm implementations are ASSIGN_TO mode,
287279
// if need to support the ADD_TO calculation, need to clear the output.
@@ -290,25 +282,52 @@ class CrossMapNormalGradFunc : public FunctionBase {
290282
tmp.zero();
291283
}
292284

293-
size_t samples = inputs[0].shape()[0];
294-
size_t channels = inputs[0].shape()[1];
295-
size_t height = inputs[0].shape()[2];
296-
size_t width = inputs[0].shape()[3];
285+
size_t batchSize = inputs[0].shape()[0];
286+
size_t maps = inputs[0].shape()[1];
287+
size_t rows = inputs[0].shape()[2];
288+
size_t columns = inputs[0].shape()[3];
297289

298290
CrossMapNormalGrad<Device>(outputs[0].data<real>(),
299291
inputs[0].data<real>(),
300292
inputs[1].data<real>(),
301293
inputs[2].data<real>(),
302294
inputs[3].data<real>(),
303-
samples,
304-
channels,
305-
height,
306-
width,
295+
batchSize,
296+
maps,
297+
rows,
298+
columns,
307299
size_,
308300
scale_,
309301
pow_);
310302
}
311303

304+
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
305+
CHECK_EQ(numInputs_, inputs.size());
306+
CHECK_EQ(numOutputs_, outputs.size());
307+
308+
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
309+
CHECK(inputs[0].shape() == inputs[1].shape());
310+
CHECK(inputs[0].shape() == inputs[2].shape());
311+
CHECK(inputs[0].shape() == inputs[3].shape());
312+
CHECK(inputs[0].shape() == outputs[0].shape());
313+
}
314+
315+
// Only need the shape of one input, can calculate the
316+
// floating-point operation.
317+
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
318+
CHECK_LT((size_t)1, inputs.size());
319+
size_t batchSize = inputs[0].shape()[0];
320+
size_t maps = inputs[0].shape()[1];
321+
size_t rows = inputs[0].shape()[2];
322+
size_t columns = inputs[0].shape()[3];
323+
324+
// number of floating-point operations
325+
// an approximate value
326+
size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2);
327+
328+
return ops;
329+
}
330+
312331
private:
313332
size_t size_;
314333
real scale_;

paddle/function/Function.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,15 @@ class FunctionBase {
156156
// This member function is used to check whether the BufferType and shape of
157157
// the inputs and outputs arguments of the Function are correct.
158158
// General calc function which will call this check to do arguments check.
159-
// Also before the call calc, the caller can also check their own arguments.
159+
// And before the calc called, the caller can also check their own arguments.
160160
virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {}
161161

162162
// Calculate the number of floating-point operations of this Function.
163163
// The inputs and outputs arguments do not need to contain the actual data,
164164
// only the shape.
165+
// And some Functions have the same input and output shapes,
166+
// so you may not need to enter the complete number of arguments.
167+
// But entering the full arguments is always correct for this interface.
165168
virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
166169
return 0;
167170
}

0 commit comments

Comments
 (0)