Skip to content

Commit 13ec6f9

Browse files
committed
Merge remote-tracking branch 'upstream/develop' into factorization_machine_layer
2 parents 6fed6f2 + d5be1d4 commit 13ec6f9

33 files changed

+872
-392
lines changed

benchmark/paddle/image/googlenet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
width = 224
66
num_class = 1000
77
batch_size = get_config_arg('batch_size', int, 128)
8+
use_gpu = get_config_arg('use_gpu', bool, True)
89

910
args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
1011
define_py_data_sources2(
@@ -16,6 +17,8 @@
1617
learning_method=MomentumOptimizer(0.9),
1718
regularization=L2Regularization(0.0005 * batch_size))
1819

20+
conv_projection = conv_projection if use_gpu else img_conv_layer
21+
1922
def inception2(name, input, channels, \
2023
filter1,
2124
filter3R, filter3,
@@ -138,7 +141,7 @@ def inception(name, input, channels, \
138141
cat = concat_layer(
139142
name=name,
140143
input=[cov1, cov3, cov5, covprj],
141-
bias_attr=True,
144+
bias_attr=True if use_gpu else False,
142145
act=ReluActivation())
143146
return cat
144147

benchmark/paddle/image/run_mkldnn.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ fi
4040
for use_mkldnn in True False; do
4141
for batchsize in 64 128 256; do
4242
train vgg 19 $batchsize $use_mkldnn
43-
train resnet 50 $batchsize $use_mkldnn
43+
train resnet 50 $batchsize $use_mkldnn
44+
train googlenet v1 $batchsize $use_mkldnn
4445
done
4546
done

paddle/gserver/activations/ActivationFunction.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,37 @@ Error __must_check backward(Argument& act) {
212212
}
213213
END_DEFINE_ACTIVATION(sequence_softmax)
214214

215+
/*
216+
* @brief SoftSign Activation.
217+
* \f[
218+
* f(z) = \frac{z}{1 + |z|}
219+
* \f]
220+
*/
221+
BEGIN_DEFINE_ACTIVATION(softsign)
222+
private:
223+
MatrixPtr denominator_;
224+
225+
Error __must_check forward(Argument& act) {
226+
size_t height = act.value->getHeight();
227+
size_t width = act.value->getWidth();
228+
Matrix::resizeOrCreate(
229+
denominator_, height, width, false, useGpu(act.deviceId));
230+
denominator_->assign(*act.value);
231+
denominator_->abs2();
232+
denominator_->add(1.);
233+
234+
act.value->dotDiv(*act.value, *denominator_);
235+
return Error();
236+
}
237+
238+
Error __must_check backward(Argument& act) {
239+
denominator_->square2();
240+
denominator_->scalarDiv(*denominator_, 1.);
241+
act.grad->dotMul(*act.grad, *denominator_);
242+
return Error();
243+
}
244+
END_DEFINE_ACTIVATION(softsign)
245+
215246
/**
216247
* @brief Relu Activation.
217248
* forward. y = max(0, z)

paddle/operators/conv_cudnn_op.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ REGISTER_OP(conv_cudnn, ops::ConvOp, ops::CudnnConvOpMaker, conv_cudnn_grad,
4040
ops::ConvOpGrad);
4141

4242
REGISTER_OP_CPU_KERNEL(conv_cudnn,
43-
ops::GemmConvKernel<paddle::platform::CPUPlace, float>);
43+
ops::GemmConvKernel<paddle::platform::CPUPlace, float>,
44+
ops::GemmConvKernel<paddle::platform::CPUPlace, double>);
4445
REGISTER_OP_CPU_KERNEL(
45-
conv_cudnn_grad,
46-
ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>);
46+
conv_cudnn_grad, ops::GemmConvGradKernel<paddle::platform::CPUPlace, float>,
47+
ops::GemmConvGradKernel<paddle::platform::CPUPlace, double>);

paddle/operators/conv_cudnn_op.cu.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
259259
} // namespace operators
260260
} // namespace paddle
261261

262-
REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>);
262+
REGISTER_OP_GPU_KERNEL(conv_cudnn, paddle::operators::CudnnConvOpKernel<float>,
263+
paddle::operators::CudnnConvOpKernel<double>);
263264
REGISTER_OP_GPU_KERNEL(conv_cudnn_grad,
264-
paddle::operators::CudnnConvGradOpKernel<float>);
265+
paddle::operators::CudnnConvGradOpKernel<float>,
266+
paddle::operators::CudnnConvGradOpKernel<double>);

paddle/operators/conv_transpose_cudnn_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,22 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
6161

6262
REGISTER_OP_CPU_KERNEL(
6363
conv2d_transpose_cudnn,
64-
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
64+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
65+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
6566
REGISTER_OP_CPU_KERNEL(
6667
conv2d_transpose_cudnn_grad,
67-
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
68+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
69+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);
6870

6971
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
7072
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
7173
ops::ConvTransposeOpGrad);
7274

7375
REGISTER_OP_CPU_KERNEL(
7476
conv3d_transpose_cudnn,
75-
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
77+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>,
78+
ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, double>);
7679
REGISTER_OP_CPU_KERNEL(
7780
conv3d_transpose_cudnn_grad,
78-
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
81+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>,
82+
ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, double>);

paddle/operators/conv_transpose_cudnn_op.cu.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,15 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
235235
namespace ops = paddle::operators;
236236

237237
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn,
238-
ops::CudnnConvTransposeOpKernel<float>);
238+
ops::CudnnConvTransposeOpKernel<float>,
239+
ops::CudnnConvTransposeOpKernel<double>);
239240
REGISTER_OP_GPU_KERNEL(conv2d_transpose_cudnn_grad,
240-
ops::CudnnConvTransposeGradOpKernel<float>);
241+
ops::CudnnConvTransposeGradOpKernel<float>,
242+
ops::CudnnConvTransposeGradOpKernel<double>);
241243

242244
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn,
243-
ops::CudnnConvTransposeOpKernel<float>);
245+
ops::CudnnConvTransposeOpKernel<float>,
246+
ops::CudnnConvTransposeOpKernel<double>);
244247
REGISTER_OP_GPU_KERNEL(conv3d_transpose_cudnn_grad,
245-
ops::CudnnConvTransposeGradOpKernel<float>);
248+
ops::CudnnConvTransposeGradOpKernel<float>,
249+
ops::CudnnConvTransposeGradOpKernel<double>);

paddle/operators/math/pooling.cc

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,8 @@ template class Pool3dGradFunctor<
498498
* Ksize, strides, paddings are two elements. These two elements represent
499499
* height and width, respectively.
500500
*/
501-
template <typename T>
502-
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
501+
template <typename T1, typename T2>
502+
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T1, T2> {
503503
public:
504504
void operator()(const platform::DeviceContext& context,
505505
const framework::Tensor& input, std::vector<int>& ksize,
@@ -520,9 +520,9 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
520520
const int input_stride = input_height * input_width;
521521
const int output_stride = output_height * output_width;
522522

523-
const T* input_data = input.data<T>();
524-
T* output_data = output->mutable_data<T>(context.GetPlace());
525-
T* mask_data = mask->mutable_data<T>(context.GetPlace());
523+
const T1* input_data = input.data<T1>();
524+
T1* output_data = output->mutable_data<T1>(context.GetPlace());
525+
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
526526

527527
for (int i = 0; i < batch_size; i++) {
528528
for (int c = 0; c < output_channels; ++c) {
@@ -535,7 +535,7 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
535535
int wend = std::min(wstart + ksize_width, input_width);
536536
wstart = std::max(wstart, 0);
537537

538-
T ele = static_cast<T>(-FLT_MAX);
538+
T1 ele = static_cast<T1>(-FLT_MAX);
539539
int index = -1;
540540
for (int h = hstart; h < hend; ++h) {
541541
for (int w = wstart; w < wend; ++w) {
@@ -563,8 +563,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
563563
* Ksize, strides, paddings are two elements. These two elements represent
564564
* height and width, respectively.
565565
*/
566-
template <typename T>
567-
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
566+
template <typename T1, typename T2>
567+
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
568568
public:
569569
void operator()(const platform::DeviceContext& context,
570570
const framework::Tensor& output_grad,
@@ -580,9 +580,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
580580
const int input_stride = input_height * input_width;
581581
const int output_stride = output_height * output_width;
582582

583-
const T* mask_data = mask.data<T>();
584-
const T* output_grad_data = output_grad.data<T>();
585-
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
583+
const T2* mask_data = mask.data<T2>();
584+
const T1* output_grad_data = output_grad.data<T1>();
585+
T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
586586

587587
for (int n = 0; n < batch_size; ++n) {
588588
for (int c = 0; c < output_channels; ++c) {
@@ -602,18 +602,18 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
602602
}
603603
};
604604

605-
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
606-
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
607-
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
608-
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
605+
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float, int>;
606+
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float, int>;
607+
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double, int>;
608+
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double, int>;
609609

610610
/*
611611
* All tensors are in NCDHW format.
612612
* Ksize, strides, paddings are three elements. These three elements represent
613613
* depth, height and width, respectively.
614614
*/
615-
template <typename T>
616-
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
615+
template <typename T1, typename T2>
616+
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T1, T2> {
617617
public:
618618
void operator()(const platform::DeviceContext& context,
619619
const framework::Tensor& input, std::vector<int>& ksize,
@@ -639,9 +639,9 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
639639
const int input_stride = input_depth * input_height * input_width;
640640
const int output_stride = output_depth * output_height * output_width;
641641

642-
const T* input_data = input.data<T>();
643-
T* output_data = output->mutable_data<T>(context.GetPlace());
644-
T* mask_data = mask->mutable_data<T>(context.GetPlace());
642+
const T1* input_data = input.data<T1>();
643+
T1* output_data = output->mutable_data<T1>(context.GetPlace());
644+
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
645645

646646
for (int i = 0; i < batch_size; i++) {
647647
for (int c = 0; c < output_channels; ++c) {
@@ -659,7 +659,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
659659
wstart = std::max(wstart, 0);
660660

661661
int output_idx = (pd * output_height + ph) * output_width + pw;
662-
T ele = static_cast<T>(-FLT_MAX);
662+
T1 ele = static_cast<T1>(-FLT_MAX);
663663
int index = -1;
664664
for (int d = dstart; d < dend; ++d) {
665665
for (int h = hstart; h < hend; ++h) {
@@ -691,8 +691,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
691691
* Ksize, strides, paddings are three elements. These three elements represent
692692
* depth, height and width, respectively.
693693
*/
694-
template <typename T>
695-
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
694+
template <typename T1, typename T2>
695+
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
696696
public:
697697
void operator()(const platform::DeviceContext& context,
698698
const framework::Tensor& output_grad,
@@ -710,9 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
710710
const int input_stride = input_depth * input_height * input_width;
711711
const int output_stride = output_depth * output_height * output_width;
712712

713-
const T* mask_data = mask.data<T>();
714-
const T* output_grad_data = output_grad.data<T>();
715-
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
713+
const T2* mask_data = mask.data<T2>();
714+
const T1* output_grad_data = output_grad.data<T1>();
715+
T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
716716

717717
for (int n = 0; n < batch_size; ++n) {
718718
for (int c = 0; c < output_channels; ++c) {
@@ -735,10 +735,10 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
735735
}
736736
};
737737

738-
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
739-
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
740-
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
741-
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
738+
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float, int>;
739+
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float, int>;
740+
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double, int>;
741+
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double, int>;
742742
} // namespace math
743743
} // namespace operators
744744
} // namespace paddle

0 commit comments

Comments
 (0)