Skip to content

Commit 3133c09

Browse files
authored
Merge pull request #1768 from luotao1/pad
fix PadOp bug on Gpu
2 parents 87afc6d + a782759 commit 3133c09

File tree

3 files changed

+23
-49
lines changed

3 files changed

+23
-49
lines changed

paddle/function/Function.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class FuncConfig {
3838
if (err) {
3939
*err = Error(e.what());
4040
} else {
41-
LOG(FATAL) << "Cannot get key " << key << "with error " << e.what();
41+
LOG(FATAL) << "Cannot get key " << key << " with error " << e.what();
4242
}
4343
return T();
4444
}

paddle/function/PadOpGpu.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ void Pad<DEVICE_TYPE_GPU>(real* outputs,
4444
size_t nth = num * inC * inH * inW;
4545
int blockSize = 1024;
4646
int gridSize = (nth + 1024 - 1) / 1024;
47-
int cstart = pad.channelStart, cend = pad.channelEnd;
48-
int hstart = pad.heightStart, hend = pad.heightEnd;
49-
int wstart = pad.widthStart, wend = pad.widthEnd;
47+
int cstart = pad.channel[0], cend = pad.channel[1];
48+
int hstart = pad.height[0], hend = pad.height[1];
49+
int wstart = pad.width[0], wend = pad.width[1];
5050
int outC = inC + cstart + cend;
5151
int outH = inH + hstart + hend;
5252
int outW = inW + wstart + wend;
@@ -83,9 +83,9 @@ void PadGrad<DEVICE_TYPE_GPU>(real* inGrad,
8383
int nth = num * inC * inH * inW;
8484
int blockSize = 1024;
8585
int gridSize = (nth + 1024 - 1) / 1024;
86-
int cstart = pad.channelStart, cend = pad.channelEnd;
87-
int hstart = pad.heightStart, hend = pad.heightEnd;
88-
int wstart = pad.widthStart, wend = pad.widthEnd;
86+
int cstart = pad.channel[0], cend = pad.channel[1];
87+
int hstart = pad.height[0], hend = pad.height[1];
88+
int wstart = pad.width[0], wend = pad.width[1];
8989
int outC = inC + cstart + cend;
9090
int outH = inH + hstart + hend;
9191
int outW = inW + wstart + wend;

paddle/function/PadOpTest.cpp

Lines changed: 16 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,48 +24,22 @@ TEST(Pad, real) {
2424
for (size_t imgSizeW : {5, 32, 96}) {
2525
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
2626
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
27-
28-
FunctionCompare compare("Pad",
29-
FuncConfig()
30-
.set("cstart", 2)
31-
.set("cend", 3)
32-
.set("hstart", 1)
33-
.set("hend", 2)
34-
.set("wstart", 3)
35-
.set("wend", 2));
36-
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
37-
TensorShape outDims{
38-
numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5};
39-
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, inDims));
40-
compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, outDims, ASSIGN_TO));
41-
compare.run();
42-
}
43-
}
44-
}
45-
}
46-
}
47-
48-
TEST(PadGrad, real) {
49-
for (size_t numSamples : {5, 32}) {
50-
for (size_t channels : {1, 5, 32}) {
51-
for (size_t imgSizeH : {5, 33, 100}) {
52-
for (size_t imgSizeW : {5, 32, 96}) {
53-
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
54-
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
55-
FunctionCompare compare("PadGrad",
56-
FuncConfig()
57-
.set("cstart", 2)
58-
.set("cend", 3)
59-
.set("hstart", 1)
60-
.set("hend", 2)
61-
.set("wstart", 3)
62-
.set("wend", 2));
63-
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
64-
TensorShape outDims{
65-
numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5};
66-
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, outDims));
67-
compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, inDims, ASSIGN_TO));
68-
compare.run();
27+
for (bool test_grad : {false, true}) {
28+
FunctionCompare compare(
29+
test_grad ? "PadGrad" : "Pad",
30+
FuncConfig()
31+
.set<std::vector<uint32_t>>("channel", {2, 3})
32+
.set<std::vector<uint32_t>>("height", {1, 2})
33+
.set<std::vector<uint32_t>>("width", {3, 2}));
34+
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
35+
TensorShape outDims{
36+
numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5};
37+
compare.addInputs(
38+
BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
39+
compare.addOutputs(BufferArg(
40+
VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
41+
compare.run();
42+
}
6943
}
7044
}
7145
}

0 commit comments

Comments (0)