Skip to content

Commit 5edbe32

Browse files
authored
Merge pull request #1216 from hedaoyuan/cmrnorm
Function Adds some properties
2 parents db0df8f + a9228e2 commit 5edbe32

File tree

2 files changed

+111
-32
lines changed

2 files changed

+111
-32
lines changed

paddle/function/CrossMapNormalOp.cpp

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -162,38 +162,64 @@ template <DeviceType Device>
162162
class CrossMapNormalFunc : public FunctionBase {
163163
public:
164164
void init(const FuncConfig& config) override {
165+
// function arguments
165166
size_ = config.get<size_t>("size");
166167
scale_ = config.get<real>("scale");
167168
pow_ = config.get<real>("pow");
169+
170+
// number of inputs and outputs
171+
numInputs_ = 1;
172+
numOutputs_ = 2;
168173
}
169174

170175
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
171-
CHECK_EQ((size_t)1, inputs.size());
172-
CHECK_EQ((size_t)2, outputs.size());
173-
174-
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
175-
CHECK(inputs[0].shape() == outputs[0].shape());
176-
CHECK(inputs[0].shape() == outputs[1].shape());
177-
176+
check(inputs, outputs);
177+
// ArgType check still on here,
178+
// not sure whether it is better to put inside the check.
178179
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
179180
CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
180-
size_t samples = inputs[0].shape()[0];
181-
size_t channels = inputs[0].shape()[1];
182-
size_t height = inputs[0].shape()[2];
183-
size_t width = inputs[0].shape()[3];
181+
size_t batchSize = inputs[0].shape()[0];
182+
size_t maps = inputs[0].shape()[1];
183+
size_t rows = inputs[0].shape()[2];
184+
size_t columns = inputs[0].shape()[3];
184185

185186
CrossMapNormal<Device>(outputs[0].data<real>(),
186187
outputs[1].data<real>(),
187188
inputs[0].data<real>(),
188-
samples,
189-
channels,
190-
height,
191-
width,
189+
batchSize,
190+
maps,
191+
rows,
192+
columns,
192193
size_,
193194
scale_,
194195
pow_);
195196
}
196197

198+
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
199+
CHECK_EQ(numInputs_, inputs.size());
200+
CHECK_EQ(numOutputs_, outputs.size());
201+
202+
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
203+
CHECK(inputs[0].shape() == outputs[0].shape());
204+
CHECK(inputs[0].shape() == outputs[1].shape());
205+
}
206+
207+
// Only need the shape of the input, can calculate the
208+
// floating-point operation.
209+
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
210+
CHECK_EQ((size_t)numInputs_, inputs.size());
211+
size_t batchSize = inputs[0].shape()[0];
212+
size_t maps = inputs[0].shape()[1];
213+
size_t rows = inputs[0].shape()[2];
214+
size_t columns = inputs[0].shape()[3];
215+
216+
// number of floating-point operations
217+
// an approximate value
218+
size_t ops = batchSize * maps * rows * columns * (size_ * 2 + 3);
219+
220+
return ops;
221+
}
222+
197223
private:
198224
size_t size_;
199225
real scale_;
@@ -236,21 +262,18 @@ template <DeviceType Device>
236262
class CrossMapNormalGradFunc : public FunctionBase {
237263
public:
238264
void init(const FuncConfig& config) override {
265+
// function arguments
239266
size_ = config.get<size_t>("size");
240267
scale_ = config.get<real>("scale");
241268
pow_ = config.get<real>("pow");
269+
270+
// number of inputs and outputs
271+
numInputs_ = 4;
272+
numOutputs_ = 1;
242273
}
243274

244275
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
245-
CHECK_EQ((size_t)4, inputs.size());
246-
CHECK_EQ((size_t)1, outputs.size());
247-
248-
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
249-
CHECK(inputs[0].shape() == inputs[1].shape());
250-
CHECK(inputs[0].shape() == inputs[2].shape());
251-
CHECK(inputs[0].shape() == inputs[3].shape());
252-
CHECK(inputs[0].shape() == outputs[0].shape());
253-
276+
check(inputs, outputs);
254277
if (outputs[0].getArgType() != ADD_TO) {
255278
// Currently, some algorithm implementations are ASSIGN_TO mode,
256279
// if need to support the ADD_TO calculation, need to clear the output.
@@ -259,25 +282,52 @@ class CrossMapNormalGradFunc : public FunctionBase {
259282
tmp.zero();
260283
}
261284

262-
size_t samples = inputs[0].shape()[0];
263-
size_t channels = inputs[0].shape()[1];
264-
size_t height = inputs[0].shape()[2];
265-
size_t width = inputs[0].shape()[3];
285+
size_t batchSize = inputs[0].shape()[0];
286+
size_t maps = inputs[0].shape()[1];
287+
size_t rows = inputs[0].shape()[2];
288+
size_t columns = inputs[0].shape()[3];
266289

267290
CrossMapNormalGrad<Device>(outputs[0].data<real>(),
268291
inputs[0].data<real>(),
269292
inputs[1].data<real>(),
270293
inputs[2].data<real>(),
271294
inputs[3].data<real>(),
272-
samples,
273-
channels,
274-
height,
275-
width,
295+
batchSize,
296+
maps,
297+
rows,
298+
columns,
276299
size_,
277300
scale_,
278301
pow_);
279302
}
280303

304+
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
305+
CHECK_EQ(numInputs_, inputs.size());
306+
CHECK_EQ(numOutputs_, outputs.size());
307+
308+
CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
309+
CHECK(inputs[0].shape() == inputs[1].shape());
310+
CHECK(inputs[0].shape() == inputs[2].shape());
311+
CHECK(inputs[0].shape() == inputs[3].shape());
312+
CHECK(inputs[0].shape() == outputs[0].shape());
313+
}
314+
315+
// Only need the shape of one input, can calculate the
316+
// floating-point operation.
317+
size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
318+
CHECK_LT((size_t)1, inputs.size());
319+
size_t batchSize = inputs[0].shape()[0];
320+
size_t maps = inputs[0].shape()[1];
321+
size_t rows = inputs[0].shape()[2];
322+
size_t columns = inputs[0].shape()[3];
323+
324+
// number of floating-point operations
325+
// an approximate value
326+
size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2);
327+
328+
return ops;
329+
}
330+
281331
private:
282332
size_t size_;
283333
real scale_;

paddle/function/Function.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,36 @@ class FunctionBase {
153153

154154
virtual void calc(const BufferArgs& inputs, const BufferArgs& outputs) {}
155155

156+
// This member function is used to check whether the BufferType and shape of
157+
// the inputs and outputs arguments of the Function are correct.
158+
// General calc function which will call this check to do arguments check.
159+
// And before the calc called, the caller can also check their own arguments.
160+
virtual void check(const BufferArgs& inputs, const BufferArgs& outputs) {}
161+
162+
// Calculate the number of floating-point operations of this Function.
163+
// The inputs and outputs arguments do not need to contain the actual data,
164+
// only the shape.
165+
// And some Functions have the same input and output shapes,
166+
// so you may not need to enter the complete number of arguments.
167+
// But entering the full arguments is always correct for this interface.
168+
virtual size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) {
169+
return 0;
170+
}
171+
172+
int getNumInputs() const { return numInputs_; }
173+
174+
int getNumOutputs() const { return numOutputs_; }
175+
156176
static ClassRegistrar<FunctionBase> funcRegistrar_;
177+
178+
protected:
179+
// numInputs_ and numOutputs_ represents the maximum
180+
// input and output supported by Function.
181+
// Some functions are optimized for input and output,
182+
// so when comparing the number of arguments, for these functions
183+
// inputs.size() <= numInputs_ or outputs.size() <= numOutputs_
184+
size_t numInputs_;
185+
size_t numOutputs_;
157186
};
158187

159188
#define FUNC_NAME(typeName, deviceName) #typeName "-" #deviceName

0 commit comments

Comments
 (0)