Skip to content

Commit d7319c2

Browse files
authored
Merge pull request #5165 from NHZlX/add_dilation
Add dilation for exconv layer
2 parents 3e6f768 + f3818bd commit d7319c2

File tree

17 files changed

+299
-157
lines changed

17 files changed

+299
-157
lines changed

paddle/function/ConvOp.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class ConvFunctionBase : public FunctionBase {
6161
// function arguments
6262
strides_ = config.get<std::vector<size_t>>("strides");
6363
paddings_ = config.get<std::vector<size_t>>("paddings");
64+
dilations_ = config.get<std::vector<size_t>>("dilations");
6465
groups_ = config.get<size_t>("groups");
6566

6667
// number of inputs and outputs
@@ -118,6 +119,7 @@ class ConvFunctionBase : public FunctionBase {
118119

119120
std::vector<size_t> strides_;
120121
std::vector<size_t> paddings_;
122+
std::vector<size_t> dilations_;
121123

122124
/// Group size, refer to grouped convolution in
123125
/// Alex Krizhevsky's paper: when group=2, the first half of the
@@ -133,6 +135,10 @@ class ConvFunctionBase : public FunctionBase {
133135

134136
inline int paddingW() const { return paddings_[1]; }
135137

138+
inline int dilationH() const { return dilations_[0]; }
139+
140+
inline int dilationW() const { return dilations_[1]; }
141+
136142
// A temporary memory in convolution calculation.
137143
MemoryHandlePtr memory_;
138144

paddle/function/ConvOpTest.h

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -79,45 +79,59 @@ void Convolution(const std::string& conv1,
7979
if (outputChannels < inputChannels) continue;
8080
for (size_t stride : {1, 2}) {
8181
for (size_t padding : {0, 1}) {
82-
if (padding >= filterSize) break;
82+
for (size_t dilation : {1, 3}) {
83+
if (padding >= filterSize) break;
84+
size_t filterS = (filterSize - 1) * dilation + 1;
8385

84-
// NNPACK only supports stride = 1 if batchSize > 1
85-
if ((conv1 == "NNPACKConv-CPU" || conv2 == "NNPACKConv-CPU") &&
86-
batchSize > 1 && stride > 1)
87-
break;
86+
if (inputSize + 2 * padding < filterS) break;
8887

89-
size_t outputSize =
90-
(inputSize - filterSize + 2 * padding + stride) / stride;
91-
VLOG(3) << " batchSize=" << batchSize
92-
<< " inputChannels=" << inputChannels
93-
<< " inputHeight=" << inputSize
94-
<< " inputWidth=" << inputSize
95-
<< " outputChannels=" << outputChannels
96-
<< " filterHeight=" << filterSize
97-
<< " filterWidth=" << filterSize
98-
<< " outputHeight=" << outputSize
99-
<< " outputWidth=" << outputSize << " stride=" << stride
100-
<< " padding=" << padding;
88+
if ((conv1 == "NaiveConv-CPU" || conv2 == "NaiveConv-CPU" ||
89+
conv1 == "NNPACKConv-CPU" ||
90+
conv2 == "NNPACKConv-CPU") &&
91+
dilation > 1)
92+
break;
10193

102-
std::vector<size_t> paddings = {padding, padding};
103-
std::vector<size_t> strides = {stride, stride};
104-
Compare2Function<DType1, DType2> test(
105-
conv1,
106-
conv2,
107-
FuncConfig()
108-
.set("paddings", paddings)
109-
.set("strides", strides)
110-
.set("groups", (size_t)1)
111-
.set("algo", (std::string) "auto"));
94+
// NNPACK only supports stride = 1 if batchSize > 1
95+
if ((conv1 == "NNPACKConv-CPU" ||
96+
conv2 == "NNPACKConv-CPU") &&
97+
batchSize > 1 && stride > 1)
98+
break;
11299

113-
TensorShape input{
114-
batchSize, inputChannels, inputSize, inputSize};
115-
TensorShape filter{
116-
outputChannels, inputChannels, filterSize, filterSize};
117-
TensorShape output{
118-
batchSize, outputChannels, outputSize, outputSize};
100+
size_t outputSize =
101+
(inputSize - filterS + 2 * padding + stride) / stride;
102+
VLOG(3) << " batchSize=" << batchSize
103+
<< " inputChannels=" << inputChannels
104+
<< " inputHeight=" << inputSize
105+
<< " inputWidth=" << inputSize
106+
<< " outputChannels=" << outputChannels
107+
<< " filterHeight=" << filterSize
108+
<< " filterWidth=" << filterSize
109+
<< " outputHeight=" << outputSize
110+
<< " outputWidth=" << outputSize
111+
<< " stride=" << stride << " padding=" << padding;
119112

120-
function(test, input, filter, output);
113+
std::vector<size_t> paddings = {padding, padding};
114+
std::vector<size_t> strides = {stride, stride};
115+
std::vector<size_t> dilations = {dilation, dilation};
116+
Compare2Function<DType1, DType2> test(
117+
conv1,
118+
conv2,
119+
FuncConfig()
120+
.set("paddings", paddings)
121+
.set("strides", strides)
122+
.set("dilations", dilations)
123+
.set("groups", (size_t)1)
124+
.set("algo", (std::string) "auto"));
125+
126+
TensorShape input{
127+
batchSize, inputChannels, inputSize, inputSize};
128+
TensorShape filter{
129+
outputChannels, inputChannels, filterSize, filterSize};
130+
TensorShape output{
131+
batchSize, outputChannels, outputSize, outputSize};
132+
133+
function(test, input, filter, output);
134+
}
121135
}
122136
}
123137
}
@@ -144,6 +158,7 @@ void Convolution2(const std::string& conv1,
144158
for (size_t outputChannels : {7}) {
145159
size_t stride = 1;
146160
size_t padding = 0;
161+
size_t dilation = 1;
147162
size_t outputHeight =
148163
(inputHeight - filterHeight + 2 * padding + stride) /
149164
stride;
@@ -162,13 +177,15 @@ void Convolution2(const std::string& conv1,
162177

163178
std::vector<size_t> paddings = {padding, padding};
164179
std::vector<size_t> strides = {stride, stride};
180+
std::vector<size_t> dilations = {dilation, dilation};
165181
Compare2Function<DType1, DType2> test(
166182
conv1,
167183
conv2,
168184
FuncConfig()
169185
.set("paddings", paddings)
170186
.set("strides", strides)
171187
.set("groups", (size_t)1)
188+
.set("dilations", dilations)
172189
.set("algo", (std::string) "auto"));
173190

174191
TensorShape input{
@@ -223,6 +240,7 @@ void DepthwiseConvolution(const std::string& conv1,
223240

224241
std::vector<size_t> paddings = {padding, padding};
225242
std::vector<size_t> strides = {stride, stride};
243+
std::vector<size_t> dilations = {1, 1};
226244
size_t groups = inputChannels;
227245
Compare2Function<DType1, DType2> test(
228246
conv1,
@@ -231,6 +249,7 @@ void DepthwiseConvolution(const std::string& conv1,
231249
.set("paddings", paddings)
232250
.set("strides", strides)
233251
.set("groups", groups)
252+
.set("dilations", dilations)
234253
.set("algo", (std::string) "auto"));
235254

236255
TensorShape input{

paddle/function/GemmConvOp.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ class GemmConvFunction : public ConvFunctionBase {
100100
strideH(),
101101
strideW(),
102102
paddingH(),
103-
paddingW());
103+
paddingW(),
104+
dilationH(),
105+
dilationW());
104106
} else {
105107
colData = inputData + g * inputOffset;
106108
}
@@ -223,7 +225,9 @@ class GemmConvGradInputFunction : public ConvFunctionBase {
223225
strideH(),
224226
strideW(),
225227
paddingH(),
226-
paddingW());
228+
paddingW(),
229+
dilationH(),
230+
dilationW());
227231
}
228232
}
229233
inputGrad += inputChannels * inputHeight * inputWidth;
@@ -310,7 +314,9 @@ class GemmConvGradFilterFunction : public ConvFunctionBase {
310314
strideH(),
311315
strideW(),
312316
paddingH(),
313-
paddingW());
317+
paddingW(),
318+
dilationH(),
319+
dilationW());
314320
} else {
315321
colData = inputData + g * inputOffset;
316322
}

paddle/function/Im2Col.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ class Im2ColFunctor {
7878
int strideHeight,
7979
int strideWidth,
8080
int paddingHeight,
81-
int paddingWidth);
81+
int paddingWidth,
82+
int dilationHeight = 1,
83+
int dilationWidth = 1);
8284
};
8385

8486
template <ColFormat Format, DeviceType Device, class T>
@@ -91,7 +93,9 @@ class Col2ImFunctor {
9193
int strideHeight,
9294
int strideWidth,
9395
int paddingHeight,
94-
int paddingWidth);
96+
int paddingWidth,
97+
int dilationHeight = 1,
98+
int dilationWidth = 1);
9599
};
96100

97101
} // namespace paddle

paddle/function/Im2ColOp.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> {
3131
int strideHeight,
3232
int strideWidth,
3333
int paddingHeight,
34-
int paddingWidth) {
34+
int paddingWidth,
35+
int dilationHeight,
36+
int dilationWidth) {
3537
int inputChannels = imShape[0];
3638
int inputHeight = imShape[1];
3739
int inputWidth = imShape[2];
@@ -47,8 +49,8 @@ class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> {
4749
int c_im = c / filterWidth / filterHeight;
4850
for (int h = 0; h < outputHeight; ++h) {
4951
for (int w = 0; w < outputWidth; ++w) {
50-
int imRowIdx = h * strideHeight + hOffset;
51-
int imColIdx = w * strideWidth + wOffset;
52+
int imRowIdx = h * strideHeight + hOffset * dilationHeight;
53+
int imColIdx = w * strideWidth + wOffset * dilationWidth;
5254
if ((imRowIdx - paddingHeight) < 0 ||
5355
(imRowIdx - paddingHeight) >= inputHeight ||
5456
(imColIdx - paddingWidth) < 0 ||
@@ -81,7 +83,9 @@ class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, T> {
8183
int strideHeight,
8284
int strideWidth,
8385
int paddingHeight,
84-
int paddingWidth) {
86+
int paddingWidth,
87+
int dilationHeight,
88+
int dilationWidth) {
8589
int inputChannels = imShape[0];
8690
int inputHeight = imShape[1];
8791
int inputWidth = imShape[2];
@@ -97,8 +101,8 @@ class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, T> {
97101
int c_im = c / filterWidth / filterHeight;
98102
for (int h = 0; h < outputHeight; ++h) {
99103
for (int w = 0; w < outputWidth; ++w) {
100-
int imRowIdx = h * strideHeight + hOffset;
101-
int imColIdx = w * strideWidth + wOffset;
104+
int imRowIdx = h * strideHeight + hOffset * dilationHeight;
105+
int imColIdx = w * strideWidth + wOffset * dilationWidth;
102106
if ((imRowIdx - paddingHeight) >= 0 &&
103107
(imRowIdx - paddingHeight) < inputHeight &&
104108
(imColIdx - paddingWidth) >= 0 &&
@@ -134,7 +138,9 @@ class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, T> {
134138
int strideHeight,
135139
int strideWidth,
136140
int paddingHeight,
137-
int paddingWidth) {
141+
int paddingWidth,
142+
int dilationHeight = 1,
143+
int dilationWidth = 1) {
138144
int inputChannels = imShape[0];
139145
int inputHeight = imShape[1];
140146
int inputWidth = imShape[2];
@@ -147,9 +153,10 @@ class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, T> {
147153
for (int channel = 0; channel < inputChannels; ++channel) {
148154
for (int filterH = 0; filterH < filterHeight; ++filterH) {
149155
for (int filterW = 0; filterW < filterWidth; ++filterW) {
150-
int imRowOffset =
151-
outputH * strideHeight + filterH - paddingHeight;
152-
int imColOffset = outputW * strideWidth + filterW - paddingWidth;
156+
int imRowOffset = outputH * strideHeight +
157+
filterH * dilationHeight - paddingHeight;
158+
int imColOffset = outputW * strideWidth +
159+
filterW * dilationWidth - paddingWidth;
153160
int colDataOffset =
154161
(((outputH * outputWidth + outputW) * inputChannels +
155162
channel) *
@@ -189,7 +196,9 @@ class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> {
189196
int strideHeight,
190197
int strideWidth,
191198
int paddingHeight,
192-
int paddingWidth) {
199+
int paddingWidth,
200+
int dilationHeight = 1,
201+
int dilationWidth = 1) {
193202
int inputChannels = imShape[0];
194203
int inputHeight = imShape[1];
195204
int inputWidth = imShape[2];
@@ -202,9 +211,10 @@ class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> {
202211
for (int channel = 0; channel < inputChannels; ++channel) {
203212
for (int filterH = 0; filterH < filterHeight; ++filterH) {
204213
for (int filterW = 0; filterW < filterWidth; ++filterW) {
205-
int imRowOffset =
206-
outputH * strideHeight + filterH - paddingHeight;
207-
int imColOffset = outputW * strideWidth + filterW - paddingWidth;
214+
int imRowOffset = outputH * strideHeight +
215+
filterH * dilationHeight - paddingHeight;
216+
int imColOffset = outputW * strideWidth +
217+
filterW * dilationWidth - paddingWidth;
208218
int colDataOffset =
209219
(((outputH * outputWidth + outputW) * inputChannels +
210220
channel) *

0 commit comments

Comments
 (0)