@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
-
16
15
#include " paddle/math/Matrix.h"
16
+ #include " paddle/math/MathUtils.h"
17
17
#include " Operator.h"
18
18
19
19
namespace paddle {
@@ -35,8 +35,8 @@ class ConvOperator : public Operator {
35
35
*/
36
36
virtual ~ConvOperator () {
37
37
if (workSpaceInBytes_ != 0 ) {
38
- hl_free_mem_device (workSpace_);
39
- workSpaceInBytes_ = 0 ;
38
+ hl_free_mem_device (workSpace_);
39
+ workSpaceInBytes_ = 0 ;
40
40
}
41
41
42
42
hl_destroy_tensor_descriptor (inputDesc_);
@@ -83,33 +83,6 @@ class ConvOperator : public Operator {
83
83
filterSize_ * filterSizeY_ * channels_ * numFilters_);
84
84
}
85
85
86
- /* *
87
- * Calculate output size.
88
- */
89
- int outputSize (int imageSize, int filterSize, int padding, int stride) {
90
- int outputSize;
91
- if (!caffeMode_) {
92
- /* input(+padding): 0123456789
93
- * imageSize(+padding) = 10;
94
- * filterSize = 3;
95
- * stride = 2;
96
- * output: (012), (234), (456), (678), (9)
97
- * outputSize = 5;
98
- */
99
- outputSize =
100
- (imageSize - filterSize + 2 * padding + stride - 1 ) / stride + 1 ;
101
- } else {
102
- /* input(+padding): 0123456789
103
- * imageSize(+padding) = 10;
104
- * filterSize = 3;
105
- * stride = 2;
106
- * output: (012), (234), (456), (678)
107
- * outputSize = 4;
108
- */
109
- outputSize = (imageSize - filterSize + 2 * padding) / stride + 1 ;
110
- }
111
- return outputSize;
112
- }
113
86
// / Most of member variables are same with CudnnConvLayer.
114
87
// / There is no explanation here.
115
88
int imageH_, imageW_, outputH_, outputW_;
@@ -129,7 +102,7 @@ class ConvOperator : public Operator {
129
102
int fwdAlgo_, bwdFilterAlgo_, bwdDataAlgo_;
130
103
size_t fwdLimitBytes_, bwdDataLimitBytes_, bwdFilterLimitBytes_;
131
104
size_t workSpaceInBytes_;
132
- void * workSpace_;
105
+ void * workSpace_;
133
106
bool isSelectAlgo_;
134
107
};
135
108
@@ -160,33 +133,32 @@ ConvOperator::ConvOperator(const OperatorConfig &config, bool useGpu)
160
133
void ConvOperator::allocConvWorkSpace (size_t maxWorkSpace) {
161
134
if (maxWorkSpace > workSpaceInBytes_) {
162
135
if (workSpaceInBytes_ != 0 ) {
163
- hl_free_mem_device (workSpace_);
136
+ hl_free_mem_device (workSpace_);
164
137
}
165
138
// total amount of storage needed
166
139
workSpace_ = hl_malloc_device (maxWorkSpace);
167
140
workSpaceInBytes_ = maxWorkSpace;
168
141
}
169
142
}
170
143
171
-
172
144
void ConvOperator::reshape (int batchSize) {
173
145
imageH_ = ins_[0 ]->getFrameHeight ();
174
146
imageW_ = ins_[0 ]->getFrameWidth ();
175
147
if (imageH_ == 0 ) imageH_ = imgSize_;
176
148
if (imageW_ == 0 ) imageW_ = imgSize_;
177
- outputH_ = outputSize (imageH_, filterSizeY_, paddingY_, strideY_);
178
- outputW_ = outputSize (imageW_, filterSize_, padding_, stride_);
149
+ outputH_ = outputSize (imageH_, filterSizeY_, paddingY_, strideY_, caffeMode_ );
150
+ outputW_ = outputSize (imageW_, filterSize_, padding_, stride_, caffeMode_ );
179
151
180
152
out_->setFrameHeight (outputH_);
181
153
out_->setFrameWidth (outputW_);
182
154
183
155
reshapeImageDescriptors ();
184
156
185
157
if (!isSelectAlgo_) {
186
- hl_conv_workspace (inputDesc_, outputDesc_, filterDesc_,
187
- convDesc_, &fwdAlgo_, &fwdLimitBytes_,
188
- &bwdDataAlgo_ , &bwdDataLimitBytes_ ,
189
- &bwdFilterAlgo_, &bwdFilterLimitBytes_);
158
+ hl_conv_workspace (inputDesc_, outputDesc_, filterDesc_, convDesc_,
159
+ &fwdAlgo_, &fwdLimitBytes_, &bwdDataAlgo_ ,
160
+ &bwdDataLimitBytes_ , &bwdFilterAlgo_ ,
161
+ &bwdFilterLimitBytes_);
190
162
191
163
size_t maxWorkSpace = 0 ;
192
164
maxWorkSpace = std::max (fwdLimitBytes_, bwdDataLimitBytes_);
@@ -202,7 +174,8 @@ void ConvOperator::computeConvSizes() {
202
174
hl_create_filter_descriptor (&filterDesc_, channels_, numFilters_,
203
175
filterSizeY_, filterSize_);
204
176
hl_create_tensor_descriptor (&inputDesc_);
205
- int outputX = outputSize (imgSize_, filterSize_, padding_, stride_);
177
+ int outputX =
178
+ outputSize (imgSize_, filterSize_, padding_, stride_, caffeMode_);
206
179
CHECK_EQ (outputX, outputX_);
207
180
hl_create_tensor_descriptor (&outputDesc_);
208
181
hl_create_convolution_descriptor (&convDesc_, inputDesc_, filterDesc_,
@@ -211,13 +184,13 @@ void ConvOperator::computeConvSizes() {
211
184
212
185
void ConvOperator::reshapeImageDescriptors () {
213
186
hl_tensor_reshape (inputDesc_, 1 , channels_, imageH_, imageW_,
214
- channels_ * imageH_ * imageW_, imageH_ * imageW_,
215
- imageW_, 1 );
187
+ channels_ * imageH_ * imageW_, imageH_ * imageW_, imageW_,
188
+ 1 );
216
189
hl_tensor_reshape (outputDesc_, 1 , numFilters_, outputH_, outputW_,
217
190
numFilters_ * outputH_ * outputW_, outputH_ * outputW_,
218
191
outputW_, 1 );
219
- hl_reset_convolution_descriptor (convDesc_, inputDesc_, filterDesc_,
220
- paddingY_, padding_, strideY_, stride_);
192
+ hl_reset_convolution_descriptor (convDesc_, inputDesc_, filterDesc_, paddingY_,
193
+ padding_, strideY_, stride_);
221
194
inputOffset_ = channels_ * imageH_ * imageW_;
222
195
outputOffset_ = numFilters_ * outputH_ * outputW_;
223
196
weightOffset_ = numFilters_ * channels_ * filterSize_ * filterSize_;
@@ -273,18 +246,17 @@ void ConvOperator::backward() {
273
246
real *weightGrad = ins_[1 ]->grad ->getData () + weightOffset_ * batchId;
274
247
hl_convolution_backward_filter (inputDesc_, inputData, outputDesc_,
275
248
outGrad, filterDesc_, weightGrad,
276
- convDesc_, workSpace_,
277
- workSpaceInBytes_, bwdFilterAlgo_);
249
+ convDesc_, workSpace_, workSpaceInBytes_,
250
+ bwdFilterAlgo_);
278
251
}
279
252
280
253
MatrixPtr preGrad = ins_[0 ]->grad ;
281
254
if (NULL != preGrad) {
282
255
real *inputGrad = preGrad->getData () + inputOffset_ * batchId;
283
256
real *wgtData = ins_[1 ]->value ->getData () + weightOffset_ * batchId;
284
- hl_convolution_backward_data (inputDesc_, inputGrad, outputDesc_,
285
- outGrad, filterDesc_, wgtData,
286
- convDesc_, workSpace_,
287
- workSpaceInBytes_, bwdDataAlgo_);
257
+ hl_convolution_backward_data (
258
+ inputDesc_, inputGrad, outputDesc_, outGrad, filterDesc_, wgtData,
259
+ convDesc_, workSpace_, workSpaceInBytes_, bwdDataAlgo_);
288
260
}
289
261
}
290
262
}
0 commit comments