Skip to content

Commit 4c58da2

Browse files
author
chengduo
authored
Merge pull request #10367 from NHZlX/fix_maxpool_with_mask_layer
fix pool with mask layer bug
2 parents 2abcf37 + bc290b5 commit 4c58da2

File tree

1 file changed

+11
-17
lines changed

1 file changed

+11
-17
lines changed

paddle/math/Matrix.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,26 +2157,20 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
21572157
int wend = wstart + sizeX;
21582158
wstart = wstart < 0 ? 0 : wstart;
21592159
wend = wend < (int)imgSizeW ? wend : (int)imgSizeW;
2160-
if (maskData == NULL) {
2161-
real tmp = -(real)FLT_MAX;
2162-
for (int h = hstart; h < hend; ++h) {
2163-
for (int w = wstart; w < wend; ++w) {
2164-
tmp = tmp < inputData[h * imgSizeW + w]
2165-
? inputData[h * imgSizeW + w]
2166-
: tmp;
2167-
}
2168-
}
2169-
outData[ph * outputW + pw] = tmp;
2170-
} else {
2171-
for (int h = hstart; h < hend; ++h) {
2172-
for (int w = wstart; w < wend; ++w) {
2173-
if (outData[ph * outputW + pw] < inputData[h * imgSizeW + w]) {
2174-
outData[ph * outputW + pw] = inputData[h * imgSizeW + w];
2175-
maskData[ph * outputW + pw] = h * imgSizeW + w;
2176-
}
2160+
2161+
real maxval = -(real)FLT_MAX;
2162+
int max_index = -1;
2163+
for (int h = hstart; h < hend; ++h) {
2164+
for (int w = wstart; w < wend; ++w) {
2165+
if (maxval < inputData[h * imgSizeW + w]) {
2166+
maxval = inputData[h * imgSizeW + w];
2167+
max_index = h * imgSizeW + w;
21772168
}
21782169
}
21792170
}
2171+
2172+
outData[ph * outputW + pw] = maxval;
2173+
if (maskData != NULL) maskData[ph * outputW + pw] = max_index;
21802174
}
21812175
}
21822176
// compute offset

0 commit comments

Comments
 (0)