
Commit df99b16

Merge pull request #9167 from kexinzhao/pool2d_fp16

Add float16 support for pool 2d operator

2 parents: 3f5705c + dfec1df, commit df99b16

File tree: 6 files changed, +152 additions, -102 deletions

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 4 additions & 4 deletions

@@ -28,6 +28,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
 using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
 using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
 using DataLayout = platform::DataLayout;
+template <typename T>
+using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

 static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
     static_cast<size_t>(1024) * 1024 * 1024;

@@ -134,8 +136,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv forward ---------------------
-    typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
-                                                           beta = 0.0f;
+    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     for (int i = 0; i < groups; i++) {
       PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
           handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,

@@ -282,8 +283,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
     cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv backward data ---------------------
-    typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
-                                                           beta = 0.0f;
+    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset input_grad.

paddle/fluid/operators/pool_cudnn_op.cu.cc

Lines changed: 11 additions & 9 deletions

@@ -24,6 +24,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
 using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;
 using DataLayout = platform::DataLayout;
 using PoolingMode = platform::PoolingMode;
+template <typename T>
+using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

 template <typename T>
 class PoolCUDNNOpKernel : public framework::OpKernel<T> {

@@ -78,8 +80,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {

     // ------------------- cudnn pool algorithm ---------------------
     auto handle = ctx.cuda_device_context().cudnn_handle();
-    T alpha = 1.0f, beta = 0.0f;
-
+    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward(
         handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta,
         cudnn_output_desc, output_data));

@@ -144,8 +145,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {

     // ------------------- cudnn pool algorithm ---------------------
     auto handle = ctx.cuda_device_context().cudnn_handle();
-    T alpha = 1.0f, beta = 0.0f;
-
+    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     if (input_grad) {
       T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
       // Because beta is zero, it is unnecessary to reset input_grad.

@@ -162,17 +162,19 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
+namespace plat = paddle::platform;

-REGISTER_OP_KERNEL(pool2d, CUDNN, ::paddle::platform::CUDAPlace,
+REGISTER_OP_KERNEL(pool2d, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNOpKernel<float>,
-                   ops::PoolCUDNNOpKernel<double>);
-REGISTER_OP_KERNEL(pool2d_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::PoolCUDNNOpKernel<double>,
+                   ops::PoolCUDNNOpKernel<plat::float16>);
+REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNGradOpKernel<float>,
                    ops::PoolCUDNNGradOpKernel<double>);

-REGISTER_OP_KERNEL(pool3d, CUDNN, ::paddle::platform::CUDAPlace,
+REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNOpKernel<float>,
                    ops::PoolCUDNNOpKernel<double>);
-REGISTER_OP_KERNEL(pool3d_grad, CUDNN, ::paddle::platform::CUDAPlace,
+REGISTER_OP_KERNEL(pool3d_grad, CUDNN, plat::CUDAPlace,
                    ops::PoolCUDNNGradOpKernel<float>,
                    ops::PoolCUDNNGradOpKernel<double>);

paddle/fluid/operators/pool_op.cc

Lines changed: 7 additions & 3 deletions

@@ -124,11 +124,15 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
   }
 #endif

+  auto input_data_type = framework::ToDataType(ctx.Input<Tensor>("X")->type());
+  if (input_data_type == framework::proto::VarType::FP16) {
+    PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
+                      "float16 can only be used when CUDNN is used");
+  }
   std::string data_format = ctx.Attr<std::string>("data_format");
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-  return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
+                                 library_);
 }

 Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
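
The check added above encodes the dispatch rule of this PR: a float16 pool op is only valid when the cuDNN library type has already been selected. A minimal Python sketch of that rule (hypothetical helper, not Paddle API; strings stand in for the framework's dtype and library enums):

```python
def expected_pool_kernel(input_dtype, library):
    """Mirror of the GetExpectedKernelType rule above: fp16 requires cuDNN."""
    if input_dtype == "float16" and library != "cudnn":
        raise ValueError("float16 can only be used when CUDNN is used")
    return (input_dtype, library)


# Allowed: fp16 with the cuDNN kernel; fp32 with the plain kernel.
assert expected_pool_kernel("float16", "cudnn") == ("float16", "cudnn")
assert expected_pool_kernel("float32", "plain") == ("float32", "plain")

# Rejected: fp16 without cuDNN.
try:
    expected_pool_kernel("float16", "plain")
except ValueError as e:
    assert "CUDNN" in str(e)
```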

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 3 additions & 3 deletions

@@ -483,9 +483,9 @@ def np_dtype_to_fluid_dtype(input):
         input: input numpy array

     Returns:
-        input: if the dtype of input is np.float16, its dtype will be
-            changed to np.uint16 so that the internal memory will be
-            reinterpreted input as of dtype np.uint16.
+        input: The dtype of input will be changed to np.uint16 if
+            it is originally np.float16, such that the internal memory
+            of input will be reinterpreted as of dtype np.uint16.
     """
     if input.dtype == np.float16:
         input.dtype = np.uint16
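
For context on the reinterpretation the docstring describes: numpy allows the dtype of an array to be reassigned in place to another dtype of the same item size, which relabels the underlying buffer rather than converting values. A standalone numpy sketch (illustration only, not part of op_test.py):

```python
import numpy as np

# float16 and uint16 are both 2 bytes wide, so reassigning .dtype relabels the
# raw IEEE half-precision bits in place; nothing is converted or copied.
x = np.array([1.0, -2.5], dtype=np.float16)
x.dtype = np.uint16

assert x.dtype == np.uint16
assert x[0] == 0x3C00  # IEEE 754 binary16 bit pattern of 1.0

# The same buffer can be viewed back as float16 with no loss.
x.dtype = np.float16
assert x[0] == np.float16(1.0) and x[1] == np.float16(-2.5)
```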

python/paddle/fluid/tests/unittests/test_conv2d_op.py

Lines changed: 32 additions & 37 deletions

@@ -63,12 +63,13 @@ def conv2d_forward_naive(input, filter, group, conv_param):

 class TestConv2dOp(OpTest):
     def setUp(self):
+        self.op_type = "conv2d"
         self.use_cudnn = False
         self.use_mkldnn = False
-        self.init_op_type()
+        self.dtype = np.float32
+        self.init_kernel_type()
         self.init_group()
         self.init_dilation()
-        self.init_data_type()
         self.init_test_case()

         conv2d_param = {

@@ -159,17 +160,14 @@ def init_test_case(self):
         f_c = self.input_size[1] / self.groups
         self.filter_size = [6, f_c, 3, 3]

-    def init_data_type(self):
-        self.dtype = np.float32
-
     def init_dilation(self):
         self.dilations = [1, 1]

     def init_group(self):
         self.groups = 1

-    def init_op_type(self):
-        self.op_type = "conv2d"
+    def init_kernel_type(self):
+        pass


 class TestWithPad(TestConv2dOp):

@@ -241,13 +239,13 @@ def init_group(self):


 #----------------Conv2dCUDNN----------------
 class TestCUDNN(TestConv2dOp):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv2d"


-class TestFP16CUDNN(TestCUDNN):
-    def init_data_type(self):
+class TestFP16CUDNN(TestConv2dOp):
+    def init_kernel_type(self):
+        self.use_cudnn = True
         self.dtype = np.float16

     def test_check_output(self):

@@ -258,13 +256,13 @@ def test_check_output(self):


 class TestCUDNNWithPad(TestWithPad):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv2d"


-class TestFP16CUDNNWithPad(TestCUDNNWithPad):
-    def init_data_type(self):
+class TestFP16CUDNNWithPad(TestWithPad):
+    def init_kernel_type(self):
+        self.use_cudnn = True
         self.dtype = np.float16

     def test_check_output(self):

@@ -275,13 +273,13 @@ def test_check_output(self):


 class TestCUDNNWithStride(TestWithStride):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv2d"


-class TestFP16CUDNNWithStride(TestCUDNNWithStride):
-    def init_data_type(self):
+class TestFP16CUDNNWithStride(TestWithStride):
+    def init_kernel_type(self):
+        self.use_cudnn = True
         self.dtype = np.float16

     def test_check_output(self):

@@ -292,13 +290,13 @@ def test_check_output(self):


 class TestCUDNNWithGroup(TestWithGroup):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv2d"


-class TestFP16CUDNNWithGroup(TestCUDNNWithGroup):
-    def init_data_type(self):
+class TestFP16CUDNNWithGroup(TestWithGroup):
+    def init_kernel_type(self):
+        self.use_cudnn = True
         self.dtype = np.float16

     def test_check_output(self):

@@ -309,13 +307,13 @@ def test_check_output(self):


 class TestCUDNNWith1x1(TestWith1x1):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv2d"


-class TestFP16CUDNNWith1x1(TestCUDNNWith1x1):
-    def init_data_type(self):
+class TestFP16CUDNNWith1x1(TestWith1x1):
+    def init_kernel_type(self):
+        self.use_cudnn = True
         self.dtype = np.float16

     def test_check_output(self):

@@ -326,13 +324,13 @@ def test_check_output(self):


 class TestCUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
-        self.op_type = "conv2d"


-class TestFP16CUDNNWithInput1x1Filter1x1(TestCUDNNWithInput1x1Filter1x1):
-    def init_data_type(self):
+class TestFP16CUDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
+    def init_kernel_type(self):
+        self.use_cudnn = True
         self.dtype = np.float16

     def test_check_output(self):

@@ -375,21 +373,18 @@ def init_test_case(self):

 #----------------Conv2dMKLDNN----------------
 class TestMKLDNN(TestConv2dOp):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_mkldnn = True
-        self.op_type = "conv2d"


 class TestMKLDNNWithPad(TestWithPad):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_mkldnn = True
-        self.op_type = "conv2d"


 class TestMKLDNNWithStride(TestWithStride):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_mkldnn = True
-        self.op_type = "conv2d"


 if __name__ == '__main__':
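
The refactor above collapses the per-subclass init_op_type/init_data_type overrides into a single init_kernel_type hook: setUp now fixes op_type = "conv2d" and a float32 default, and each CUDNN/MKLDNN/FP16 variant only toggles the kernel knobs. A stripped-down sketch of the pattern outside the OpTest framework (class and attribute names follow the diff; the surrounding scaffolding is illustrative only):

```python
import numpy as np


class TestConv2dOpSketch(object):
    """Stand-in for TestConv2dOp: the base class owns the defaults."""

    def setUp(self):
        self.op_type = "conv2d"   # fixed once, no longer set per subclass
        self.use_cudnn = False
        self.use_mkldnn = False
        self.dtype = np.float32   # FP16 variants override via the hook
        self.init_kernel_type()

    def init_kernel_type(self):
        pass                      # base class keeps the defaults


class TestFP16CUDNNSketch(TestConv2dOpSketch):
    def init_kernel_type(self):
        self.use_cudnn = True
        self.dtype = np.float16   # only the kernel knobs change


t = TestFP16CUDNNSketch()
t.setUp()
assert (t.op_type, t.use_cudnn, t.dtype) == ("conv2d", True, np.float16)
```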
