Skip to content

Commit a13ec34

Browse files
committed
fix test error
1 parent e4de5dc commit a13ec34

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
134134
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
135135
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
136136
// ------------------- cudnn conv forward ---------------------
137-
T alpha = static_cast<T>(1.0f);
138-
T beta = static_cast<T>(0.0f);
137+
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f,
138+
beta = 0.0f;
139139
for (int i = 0; i < groups; i++) {
140140
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
141141
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
@@ -321,7 +321,7 @@ namespace plat = paddle::platform;
321321
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
322322
paddle::operators::CUDNNConvOpKernel<float>,
323323
paddle::operators::CUDNNConvOpKernel<double>,
324-
paddle::operators::CUDNNConvOpKernel < plat::float16);
324+
paddle::operators::CUDNNConvOpKernel<plat::float16>);
325325
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
326326
paddle::operators::CUDNNConvGradOpKernel<float>,
327327
paddle::operators::CUDNNConvGradOpKernel<double>);

paddle/fluid/platform/cudnn_helper.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,14 @@ template <>
8585
class CudnnDataType<float16> {
8686
public:
8787
static const cudnnDataType_t type = CUDNN_DATA_HALF;
88-
typedef const float16 ScalingParamType;
88+
// The scaling param type is float for HALF and FLOAT tensors
89+
typedef const float ScalingParamType;
8990
static ScalingParamType* kOne() {
90-
static ScalingParamType v = static_cast<float16>(1.0);
91+
static ScalingParamType v = 1.0;
9192
return &v;
9293
}
9394
static ScalingParamType* kZero() {
94-
static ScalingParamType v = static_cast<float16>(0.0);
95+
static ScalingParamType v = 0.0;
9596
return &v;
9697
}
9798
};

python/paddle/fluid/tests/unittests/test_conv2d_op.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def setUp(self):
7979

8080
input = np.random.random(self.input_size).astype(self.dtype)
8181
filter = np.random.random(self.filter_size).astype(self.dtype)
82-
output = conv2d_forward_naive(self.input, self.filter, self.groups,
82+
output = conv2d_forward_naive(input, filter, self.groups,
8383
conv2d_param).astype(self.dtype)
8484

8585
# numpy float16 is binded to paddle::platform::float16
@@ -88,9 +88,12 @@ def setUp(self):
8888
# uint16_t in paddle or np.uint16 in numpy, which are
8989
# themselves binded together.
9090
self.inputs = {
91-
'Input': input.view(np.uint16)
92-
if self.dtype == np.float16 else input,
93-
'Filter': create_view(filter)
91+
#'Input': (input.view(np.uint16)
92+
# if self.dtype == np.float16 else input),
93+
#'Filter': (filter.view(np.uint16)
94+
# if self.dtype == np.float16 else filter)
95+
'Input': OpTest.create_view(input),
96+
'Filter': OpTest.create_view(filter)
9497
}
9598
self.attrs = {
9699
'strides': self.stride,
@@ -254,7 +257,7 @@ def test_check_output(self):
254257
if core.is_compiled_with_cuda():
255258
place = core.CUDAPlace(0)
256259
if core.is_float16_supported(place):
257-
self.check_output_with_place(place, atol=1e-1)
260+
self.check_output_with_place(place, atol=2e-2)
258261

259262
def test_check_grad(self):
260263
pass

0 commit comments

Comments (0)