Skip to content

Commit 9f3eb91

Browse files
committed
fix neon depthwise conv bug
1 parent 580340e commit 9f3eb91

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

paddle/function/neon/NeonDepthwiseConv.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,16 @@ class NeonDepthwiseConvFunction : public ConvFunctionBase {
6666
float* inputPadding = inputData;
6767
int padInputHeight = inputHeight + 2 * paddingH();
6868
int padInputWidth = inputWidth + 2 * paddingW();
69-
if (paddingH() > 0 || paddingW() > 0) {
70-
int newSize = batchSize * inputChannels * padInputHeight * padInputWidth;
71-
resizeBuffer<Device>(newSize);
72-
inputPadding = reinterpret_cast<float*>(memory_->getBuf());
73-
neon::Padding<float>::run(inputData,
69+
int newSize = batchSize * (inputChannels + 1) * padInputHeight * padInputWidth;
70+
resizeBuffer<Device>(newSize);
71+
inputPadding = reinterpret_cast<float*>(memory_->getBuf());
72+
neon::Padding<float>::run(inputData,
7473
inputPadding,
7574
batchSize * inputChannels,
7675
inputHeight,
7776
inputWidth,
7877
padInputHeight,
7978
padInputWidth);
80-
}
8179

8280
std::function<void(
8381
const float*, const float*, int, int, int, int, int, int, float*)>

0 commit comments

Comments
 (0)