
Commit 9907350

Author: Yibing Liu (committed)

Merge branch 'develop' of upstream into refine_rank_loss_op

2 parents da62d6c + 990818f, commit 9907350

File tree: 13 files changed (+699, -56 lines)

doc/design/float16.md

Lines changed: 45 additions & 0 deletions

@@ -28,6 +28,51 @@ The goal of float16 is to serve as a key for the executor to find and run the co
- [Eigen](https://github.com/RLovelett/eigen) >= 3.3 supports float16 calculation on both GPU and CPU using the `Eigen::half` class. It is mostly useful for Nvidia GPUs because of the overloaded arithmetic operators using cuda intrinsics. It falls back to software emulation on CPU for calculation, and there is no special treatment for ARM processors.
- [ARM compute library](https://github.com/ARM-software/ComputeLibrary) >= 17.02.01 supports NEON FP16 kernels (requires ARMv8.2-A CPU).

### CUDA version issue

There are currently three versions of CUDA that support the `__half` data type, namely CUDA 7.5, 8.0, and 9.0.

CUDA 7.5 and 8.0 define `__half` as a simple struct that holds a `uint16_t` value (see [`cuda_fp16.h`](https://github.com/ptillet/isaac/blob/9212ab5a3ddbe48f30ef373f9c1fb546804c7a8c/include/isaac/external/CUDA/cuda_fp16.h)) as follows:
```
typedef struct __align__(2) {
  unsigned short x;
} __half;

typedef __half half;
```
This struct does not define any overloaded arithmetic operators, so you have to call `__hadd` directly instead of using `+` to correctly add two half values:
```
__global__ void Add() {
  half a, b, c;
  c = __hadd(a, b); // correct
  c = a + b;        // compiler error: no operator "+" matches these operands
}
```
CUDA 9.0 provides a major update to the half data type. The related code can be found in the updated [`cuda_fp16.h`](https://github.com/ptillet/isaac/blob/master/include/isaac/external/CUDA/cuda_fp16.h) and the newly added [`cuda_fp16.hpp`](https://github.com/ptillet/isaac/blob/master/include/isaac/external/CUDA/cuda_fp16.hpp).

Essentially, CUDA 9.0 renames the original `__half` type of CUDA 7.5 and 8.0 to `__half_raw`, and defines a new `__half` class type that has constructors, conversion operators, and overloaded arithmetic operators, as follows:
```
typedef struct __CUDA_ALIGN__(2) {
  unsigned short x;
} __half_raw;

struct __CUDA_ALIGN__(2) __half {
protected:
  unsigned short __x;
public:
  // constructors and conversion operators from/to
  // __half_raw and other built-in data types
};

typedef __half half;

__device__ __forceinline__
__half operator+(const __half &lh, const __half &rh) {
  return __hadd(lh, rh);
}

// Other overloaded operators
```
This new design makes `c = a + b` work correctly for the CUDA half data type.

## Implementation
The float16 class holds a 16-bit `uint16_t` data internally.
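
For reference, here is a minimal illustrative kernel (the kernel name and signature are made up for this sketch, not taken from the document), assuming CUDA 9.0 and a GPU of compute capability 5.3 or higher, showing the overloaded operator in use:
```
#include <cuda_fp16.h>

// Illustrative sketch: element-wise addition of two half arrays. With the
// CUDA 9.0 operator overloads described above, `x[i] + y[i]` compiles and
// resolves to __hadd; on CUDA 7.5/8.0 one would have to write
// `__hadd(x[i], y[i])` explicitly.
__global__ void AddHalf(const half* x, const half* y, half* z, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    z[i] = x[i] + y[i];
  }
}
```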

doc/howto/optimization/cpu_profiling.md

Lines changed: 2 additions & 2 deletions

@@ -71,7 +71,7 @@ cprofilev -a 0.0.0.0 -p 3214 -f profile.out main.py
 ```

- We can see that the most time-consuming function is the C++-side `run` function. Tuning it requires the combined analysis in our second section, `Profiling of mixed Python and C++ code`. The `sync_with_cpp` function also has a long total time and a long time per call, so we can click into the details of `sync_with_cpp` to inspect its call relationships.
+ We can see that the most time-consuming function is the C++-side `run` function. Tuning it requires the combined profiling of mixed `Python` and `C++` code in our second section. The `sync_with_cpp` function also has a long total time and a long time per call, so we can click into the details of `sync_with_cpp` to inspect its call relationships.

 ```text
 Called By:
@@ -121,7 +121,7 @@ python -m yep -v main.py

 1. Specify `-g` at compile time to generate debug information. With cmake, set CMAKE_BUILD_TYPE to `RelWithDebInfo`.
 2. Always enable optimization when compiling. A plain `Debug` build performs very differently from `-O2` or `-O3`, so performance tests in `Debug` mode are meaningless.
- 3. When profiling, start with a single thread, then move to multiple threads, and then to multiple machines. After all, if single-threaded debugging is easier. You can set the environment variable `OMP_NUM_THREADS=1` to turn off openmp optimization.
+ 3. When profiling, start with a single thread, then move to multiple threads, and then to multiple machines. After all, single-threaded debugging is easier. You can set the environment variable `OMP_NUM_THREADS=1` to turn off openmp optimization.

 ### Viewing the profile file

paddle/gserver/tests/CMakeLists.txt

Lines changed: 25 additions & 26 deletions

@@ -62,11 +62,11 @@ if(NOT WITH_DOUBLE AND NOT MOBILE_INFERENCE)
 endif()

 if(NOT MOBILE_INFERENCE)
-  ################## test_Evaluator #######################
+  ################## test_Evaluator #######################
   add_unittest(test_Evaluator
     test_Evaluator.cpp)

-  ############### test_RecurrentGradientMachine ###############
+  ############### test_RecurrentGradientMachine ###############
   # TODO(yuyang18): There is some bug in test_RecurrentGradientMachine
   # I will fix it.
   add_unittest_without_exec(test_RecurrentGradientMachine
@@ -77,7 +77,7 @@ if(NOT MOBILE_INFERENCE)
     ${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine
     WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle)

-  ############### test_NetworkCompare ###############
+  ############### test_NetworkCompare ###############
   add_unittest_without_exec(test_NetworkCompare
     test_NetworkCompare.cpp)
   if(WITH_GPU)
@@ -89,34 +89,33 @@ if(NOT MOBILE_INFERENCE)
     COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_NetworkCompare --use_gpu=false
     WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle)
   endif()
-endif()

+  ################# test_CompareSparse ##################
+  add_unittest_without_exec(test_CompareSparse
+    test_CompareSparse.cpp)
+  if(NOT ON_TRAVIS)
+    add_test(NAME test_CompareSparse
+      COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d
+        ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests
+        ./.set_port.sh -p port -n 6
+        ${CMAKE_CURRENT_BINARY_DIR}/test_CompareSparse
+      WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)
+  endif()
+
+  ################ test_CompareTwoNets ######################
+  add_unittest_without_exec(test_CompareTwoNets
+    test_CompareTwoNets.cpp)
+  add_test(NAME test_CompareTwoNets
+    COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d
+      ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests
+      ${CMAKE_CURRENT_BINARY_DIR}/test_CompareTwoNets
+    WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)
+endif()

+################ test_PyDataProvider2 ######################
 add_unittest_without_exec(test_PyDataProvider2
   test_PyDataProvider2.cpp)
-
 add_test(NAME test_PyDataProvider2
   COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/paddle/gserver/tests:${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_PyDataProvider2
   WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle
 )
-
-################# test_CompareSparse ##################
-add_unittest_without_exec(test_CompareSparse
-  test_CompareSparse.cpp)
-if(NOT ON_TRAVIS)
-  add_test(NAME test_CompareSparse
-    COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d
-      ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests
-      ./.set_port.sh -p port -n 6
-      ${CMAKE_CURRENT_BINARY_DIR}/test_CompareSparse
-    WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)
-endif()
-
-################ test_CompareTwoNets ######################
-add_unittest_without_exec(test_CompareTwoNets
-  test_CompareTwoNets.cpp)
-add_test(NAME test_CompareTwoNets
-  COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d
-    ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests
-    ${CMAKE_CURRENT_BINARY_DIR}/test_CompareTwoNets
-    WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/)

paddle/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -191,6 +191,7 @@ set(DEPS_OPS
 sum_op
 pool_op
 maxout_op
+unpool_op
 pool_with_index_op
 conv_op
 conv_transpose_op
@@ -235,6 +236,7 @@ op_library(adagrad_op DEPS selected_rows_functor)
 op_library(conv_op DEPS vol2col)
 op_library(pool_op DEPS pooling)
 op_library(maxout_op DEPS maxouting)
+op_library(unpool_op DEPS unpooling)
 op_library(pool_with_index_op DEPS pooling)
 op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
 op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op)

paddle/operators/math/CMakeLists.txt

Lines changed: 4 additions & 2 deletions

@@ -13,8 +13,9 @@ if(WITH_GPU)
   nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
   nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
   nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
-  nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
   nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
+  nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
+  nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
 else()
   cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
   cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
@@ -26,8 +27,9 @@ else()
   cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
   cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
   cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
-  cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
   cc_library(maxouting SRCS maxouting.cc DEPS device_context)
+  cc_library(unpooling SRCS unpooling.cc DEPS device_context)
+  cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
 endif()

 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)

paddle/operators/math/unpooling.cc

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
class Unpool2dMaxFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
          output_data[index] = input_data[i];
        }
        input_data += input_feasize;
        indices_data += input_feasize;
        output_data += output_feasize;
      }
    }
  }
};

template <class T>
class Unpool2dMaxGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    int input_feasize = input_height * input_width;
    int output_feasize = output_height * output_width;
    const int* indices_data = indices.data<int>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    for (int b = 0; b < batch_size; ++b) {
      for (int c = 0; c < output_channels; ++c) {
        for (int i = 0; i < input_feasize; ++i) {
          int index = indices_data[i];
          PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!");
          input_grad_data[i] = output_grad_data[index];
        }
        input_grad_data += input_feasize;
        indices_data += input_feasize;
        output_grad_data += output_feasize;
      }
    }
  }
};

template class Unpool2dMaxGradFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::CPUPlace, double>;
template class Unpool2dMaxFunctor<platform::CPUPlace, float>;
template class Unpool2dMaxFunctor<platform::CPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
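
To make the scatter that `Unpool2dMaxFunctor` performs explicit, here is a small standalone sketch (an illustration with made-up names and data, not code from this commit) of the forward pass for a single (batch, channel) slice; the backward functor gathers in the opposite direction, copying `output_grad[index]` back into `input_grad[i]`:
```
#include <cassert>
#include <iostream>
#include <vector>

// Illustrative max-unpooling for one feature-map slice: each pooled value
// input[i] is scattered to position indices[i] of the larger output map;
// every other output entry stays zero.
std::vector<float> Unpool2dMaxSlice(const std::vector<float>& input,
                                    const std::vector<int>& indices,
                                    int output_feasize) {
  std::vector<float> output(output_feasize, 0.0f);
  for (size_t i = 0; i < input.size(); ++i) {
    assert(indices[i] < output_feasize && "err index in unpooling!");
    output[indices[i]] = input[i];
  }
  return output;
}

int main() {
  // A 2x2 pooled map unpooled back to a 4x4 map (output_feasize = 16).
  std::vector<float> pooled = {5.f, 7.f, 3.f, 9.f};
  std::vector<int> idx = {0, 6, 9, 15};  // flattened argmax positions
  std::vector<float> unpooled = Unpool2dMaxSlice(pooled, idx, 16);
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 4; ++c) std::cout << unpooled[r * 4 + c] << " ";
    std::cout << "\n";
  }
  return 0;
}
```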

0 commit comments
