Skip to content

Commit c5330fa

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into Add_conv3d_transpose_cudnn_op
2 parents 6fb4bb8 + 313f845 commit c5330fa

File tree

246 files changed

+4593
-1657
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

246 files changed

+4593
-1657
lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,10 @@ third_party/
2121
cmake-build-*
2222

2323
# generated while compiling
24-
python/paddle/v2/framework/core.so
24+
python/paddle/v2/fluid/core.so
2525
paddle/pybind/pybind.h
2626
CMakeFiles
2727
cmake_install.cmake
2828
paddle/.timestamp
2929
python/paddlepaddle.egg-info/
3030
paddle/pybind/pybind.h
31-
python/paddle/v2/framework/tests/tmp/*

doc/design/evaluator.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
## Evaluator Design
2+
3+
### The Problem
4+
5+
During training or serving, we provide the evaluation function to measure the model performance, e.g., accuracy, precision. In the operator based framework design, the data go through the network pipeline batch by batch. As a result, inside the operator, we only can calculate one minibatch metrics. We need to provide a mechanism to calculate the metrics for each N pass/batch the user wanted.
6+
7+
### Evaluator Design
8+
Currently, every operation is expressed in the graph. we divide the evaluator process into three steps.
9+
10+
1. Initialize the metric state and add it into the block.
11+
12+
2. Calculate the statistic of the metric state in every mini-batch. The single operator is only responsible for calculating necessary statistics for one mini-batch. For example, accuracy operator only calculate a minibatch data if run once.
13+
14+
15+
3. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. When it comes to distributed training/Multi-GPU training, aggregate the value from different devices.
16+
17+
### Implementation
18+
This design is shown in python API.
19+
Each metric operator need to caculate the metric statistic and return the batch aware states, Python side responsible for accumulate the states for each pass.
20+
21+
22+
```python
23+
class Evaluator(object):
24+
"""
25+
Evaluator Base class.
26+
"""
27+
def __init__(self, name, **kwargs):
28+
"""
29+
Different evaluator may has different metric states. E.g, Accuracy need two variables, total and right sample counts.
30+
Auc need four variables, `true_positives`,
31+
`true_negatives`, `false_positives` and `false_negatives`. So every evaluator should create its needed variables and append to main_program
32+
33+
The initialization of Evaluator should be responsible for:
34+
create metric states and append to the main_program
35+
"""
36+
pass
37+
38+
def _update_ops(self, input, label, **kwargs)
39+
"""
40+
Add mini-batch evaluator caculate operators to the main_program.
41+
Add increment operator to accumulate the metric states.
42+
"""
43+
44+
45+
def reset(self, executor, reset_program=None):
46+
"""
47+
Reset metric states at the begin of each pass/user specified batch number.
48+
Execute the reset_program to reset the states.
49+
"""
50+
51+
52+
def eval(self, executor, eval_program=None):
53+
"""
54+
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
55+
Execute the eval_program and return the result.
56+
"""
57+
return eval_result
58+
```

paddle/capi/Matrix.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ paddle_error paddle_matrix_get_shape(paddle_matrix mat,
121121

122122
paddle_matrix paddle_matrix_create_sparse(
123123
uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu) {
124+
#ifndef PADDLE_MOBILE_INFERENCE
124125
auto ptr = new paddle::capi::CMatrix();
125126
ptr->mat = paddle::Matrix::createSparseMatrix(
126127
height,
@@ -131,6 +132,9 @@ paddle_matrix paddle_matrix_create_sparse(
131132
false,
132133
useGpu);
133134
return ptr;
135+
#else
136+
return nullptr;
137+
#endif
134138
}
135139

136140
paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
@@ -140,6 +144,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
140144
uint64_t colSize,
141145
float* valueArray,
142146
uint64_t valueSize) {
147+
#ifndef PADDLE_MOBILE_INFERENCE
143148
if (mat == nullptr) return kPD_NULLPTR;
144149
auto ptr = cast(mat);
145150
if (rowArray == nullptr || colArray == nullptr ||
@@ -160,4 +165,7 @@ paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
160165
} else {
161166
return kPD_NOT_SUPPORTED;
162167
}
168+
#else
169+
return kPD_NOT_SUPPORTED;
170+
#endif
163171
}

paddle/capi/matrix.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ PD_API paddle_matrix paddle_matrix_create(uint64_t height,
4848
* @param isBinary is binary (either 1 or 0 in matrix) or not.
4949
* @param useGpu is using GPU or not.
5050
* @return paddle_matrix.
51+
* @note Mobile inference does not support this interface.
5152
*/
5253
PD_API paddle_matrix paddle_matrix_create_sparse(
5354
uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu);
@@ -129,6 +130,7 @@ PD_API paddle_error paddle_matrix_get_shape(paddle_matrix mat,
129130
* NULL if the matrix is binary.
130131
* @param [in] valueSize length of value array. Zero if the matrix is binary.
131132
* @return paddle_error
133+
* @note Mobile inference does not support this interface.
132134
*/
133135
PD_API paddle_error paddle_matrix_sparse_copy_from(paddle_matrix mat,
134136
int* rowArray,

paddle/cuda/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ if(WITH_GPU)
2727
set_source_files_properties(${CUDA_CXX_SOURCES}
2828
PROPERTIES COMPILE_FLAGS "-D__NVCC__")
2929
else()
30+
if (NOT MOBILE_INFERENCE)
3031
set(CUDA_CXX_SOURCES src/hl_warpctc_wrap.cc)
32+
endif()
3133
endif()
3234

3335
set(CUDA_CU_SOURCES

paddle/cuda/include/hl_cnn.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License. */
1818
#include "hl_base.h"
1919

2020
/**
21-
* @brief Maximum pool forward.
21+
* @brief Maximum pool forward with Mask output.
2222
*
2323
* @param[in] frameCnt batch size of input image.
2424
* @param[in] inputData input data.
@@ -35,7 +35,7 @@ limitations under the License. */
3535
* @param[in] paddingW padding width.
3636
* @param[out] tgtData output data.
3737
* @param[in] tgtStride stride between output data samples.
38-
*
38+
* @param[out] maskData the location indices of select max data.
3939
*/
4040
extern void hl_maxpool_forward(const int frameCnt,
4141
const real* inputData,
@@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt,
5151
const int paddingH,
5252
const int paddingW,
5353
real* tgtData,
54-
const int tgtStride);
54+
const int tgtStride,
55+
real* maskData = NULL);
5556

5657
/**
5758
* @brief Maximum pool backward.

paddle/cuda/include/stub/hl_cnn_stub.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt,
3131
const int paddingH,
3232
const int paddingW,
3333
real* tgtData,
34-
const int tgtStride) {}
34+
const int tgtStride,
35+
real* MaskData) {}
3536

3637
inline void hl_maxpool_backward(const int frameCnt,
3738
const real* inputData,

paddle/cuda/src/hl_cuda_cnn.cu

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
3131
const int offsetH,
3232
const int offsetW,
3333
real* tgtData,
34-
const int tgtStride) {
34+
const int tgtStride,
35+
real* maskData) {
3536
int index = blockIdx.x * blockDim.x + threadIdx.x;
3637
if (index < nthreads) {
3738
int pw = index % pooledW;
@@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
4546
hstart = max(hstart, 0);
4647
wstart = max(wstart, 0);
4748
real maxval = -FLT_MAX;
49+
int max_index = -1;
4850
inputData += (frameNum * channels + c) * height * width;
4951
for (int h = hstart; h < hend; ++h) {
5052
for (int w = wstart; w < wend; ++w) {
51-
if (maxval < inputData[h * width + w])
52-
maxval = inputData[h * width + w];
53+
if (maxval < inputData[h * width + w]) {
54+
max_index = h * width + w;
55+
maxval = inputData[max_index];
56+
}
5357
}
5458
}
5559
int tgtIndex =
5660
index % (pooledW * pooledH * channels) + frameNum * tgtStride;
5761
tgtData[tgtIndex] = maxval;
62+
if (maskData != NULL) {
63+
maskData[tgtIndex] = max_index;
64+
}
5865
}
5966
}
6067

@@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt,
7279
const int paddingH,
7380
const int paddingW,
7481
real* tgtData,
75-
const int tgtStride) {
82+
const int tgtStride,
83+
real* maskData) {
7684
int num_kernels = pooledH * pooledW * channels * frameCnt;
7785
int blocks = (num_kernels + 1024 - 1) / 1024;
7886
dim3 threads(1024, 1);
@@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt,
92100
paddingH,
93101
paddingW,
94102
tgtData,
95-
tgtStride);
103+
tgtStride,
104+
maskData);
96105
CHECK_SYNC("hl_maxpool_forward failed");
97106
}
98107

paddle/framework/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ py_proto_compile(framework_py_proto SRCS framework.proto)
3838
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
3939
add_dependencies(framework_py_proto framework_py_proto_init)
4040
add_custom_command(TARGET framework_py_proto POST_BUILD
41-
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/proto
42-
COMMAND cp *.py ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/proto/
43-
COMMENT "Copy generated python proto into directory paddle/v2/framework/proto."
41+
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/proto
42+
COMMAND cp *.py ${PADDLE_SOURCE_DIR}/python/paddle/v2/fluid/proto/
43+
COMMENT "Copy generated python proto into directory paddle/v2/fluid/proto."
4444
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
4545

4646
cc_library(backward SRCS backward.cc DEPS net_op)

paddle/framework/backward.cc

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,12 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
377377
return grad_op_descs;
378378
}
379379

380+
static BlockDescBind* CreateStepBlock(
381+
ProgramDescBind& program_desc,
382+
std::unordered_set<std::string>* no_grad_vars,
383+
std::unordered_map<std::string, std::string>* grad_to_var,
384+
int step_block_idx);
385+
380386
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
381387
ProgramDescBind& program_desc, int block_idx,
382388
std::unordered_set<std::string>* no_grad_vars,
@@ -392,13 +398,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
392398

393399
if ((*it)->Type() == "recurrent") {
394400
int step_block_idx = (*it)->GetBlockAttr("step_block");
395-
auto backward_block_op_descs = MakeBlockBackward(
396-
program_desc, step_block_idx, no_grad_vars, grad_to_var);
401+
BlockDescBind* backward_block = CreateStepBlock(
402+
program_desc, no_grad_vars, grad_to_var, step_block_idx);
403+
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
404+
} else if ((*it)->Type() == "conditional_block") {
397405
BlockDescBind* backward_block =
398-
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
399-
for (auto& ptr : backward_block_op_descs) {
400-
backward_block->AppendAllocatedOp(std::move(ptr));
401-
}
406+
CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
407+
(*it)->GetBlockAttr("block"));
402408
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
403409
} else {
404410
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
@@ -449,6 +455,21 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
449455
return backward_descs;
450456
}
451457

458+
static BlockDescBind* CreateStepBlock(
459+
ProgramDescBind& program_desc,
460+
std::unordered_set<std::string>* no_grad_vars,
461+
std::unordered_map<std::string, std::string>* grad_to_var,
462+
int step_block_idx) {
463+
auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
464+
no_grad_vars, grad_to_var);
465+
BlockDescBind* backward_block =
466+
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
467+
for (auto& ptr : backward_block_op_descs) {
468+
backward_block->AppendAllocatedOp(move(ptr));
469+
}
470+
return backward_block;
471+
}
472+
452473
ParamGradInfoMap AppendBackward(
453474
ProgramDescBind& program_desc, const VarDescBind& target,
454475
const std::unordered_set<std::string>& no_grad_vars) {

0 commit comments

Comments
 (0)