Skip to content

Commit 9896f15

Browse files
committed
Add FunctionBase::ops()
1 parent 225a8fa commit 9896f15

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

paddle/function/CrossMapNormalOp.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,23 +182,37 @@ class CrossMapNormalFunc : public FunctionBase {
182182

183183
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
184184
CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
185-
size_t samples = inputs[0].shape()[0];
186-
size_t channels = inputs[0].shape()[1];
187-
size_t height = inputs[0].shape()[2];
188-
size_t width = inputs[0].shape()[3];
185+
size_t batchSize = inputs[0].shape()[0];
186+
size_t maps = inputs[0].shape()[1];
187+
size_t rows = inputs[0].shape()[2];
188+
size_t columns = inputs[0].shape()[3];
189189

190190
CrossMapNormal<Device>(outputs[0].data<real>(),
191191
outputs[1].data<real>(),
192192
inputs[0].data<real>(),
193-
samples,
194-
channels,
195-
height,
196-
width,
193+
batchSize,
194+
maps,
195+
rows,
196+
columns,
197197
size_,
198198
scale_,
199199
pow_);
200200
}
201201

202+
// Only need the shape of the input, can calculate the
203+
// floating-point operation.
204+
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
205+
CHECK_EQ((size_t)numInputs_, inputs.size());
206+
size_t batchSize = inputs[0].shape()[0];
207+
size_t maps = inputs[0].shape()[1];
208+
size_t rows = inputs[0].shape()[2];
209+
size_t columns = inputs[0].shape()[3];
210+
211+
// number of floating-point operations
212+
// an approximate value
213+
size_t ops = batchSize * maps * ((rows * columns) * size_);
214+
}
215+
202216
private:
203217
size_t size_;
204218
real scale_;

paddle/function/Function.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ class FunctionBase {
153153

154154
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
155155

156+
// Calculate the number of floating-point operations of this Function.
157+
// The inputs and outputs arguments do not need to contain the actual data,
158+
// only the shape.
159+
virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
160+
return 0;
161+
}
162+
156163
int getNumInputs() const { return numInputs_; }
157164

158165
int getNumOutputs() const { return numOutputs_; }

0 commit comments

Comments
 (0)