Skip to content

Commit d1d2100

Browse files
authored
Merge pull request #5331 from dzhwinter/feature/evaluator
Feature/evaluator
2 parents cf07f3e + b32faa0 commit d1d2100

File tree

14 files changed

+334
-141
lines changed

14 files changed

+334
-141
lines changed

doc/design/evaluator.md

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
## Evaluator Design
2+
3+
### The Problem
4+
5+
During training or serving, we provide the evaluation function to measure the model performance, e.g., accuracy, precision. In the operator based framework design, the data go through the network pipeline batch by batch. As a result, inside the operator, we only can calculate one minibatch metrics. We need to provide a mechanism to calculate the metrics for each N pass/batch the user wanted.
6+
7+
### Evaluator Design
8+
Currently, every operation is expressed in the graph. we divide the evaluator process into three steps.
9+
10+
1. Initialize the metric state and add it into the block.
11+
12+
2. Calculate the statistic of the metric state in every mini-batch. The single operator is only responsible for calculating necessary statistics for one mini-batch. For example, accuracy operator only calculate a minibatch data if run once.
13+
14+
15+
3. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. When it comes to distributed training/Multi-GPU training, aggregate the value from different devices.
16+
17+
### Implementation
18+
This design is shown in python API.
19+
Each metric operator need to caculate the metric statistic and return the batch aware states, Python side responsible for accumulate the states for each pass.
20+
21+
22+
```python
23+
class Evaluator(object):
24+
"""
25+
Evaluator Base class.
26+
"""
27+
def __init__(self, name, **kwargs):
28+
"""
29+
Different evaluator may has different metric states. E.g, Accuracy need two variables, total and right sample counts.
30+
Auc need four variables, `true_positives`,
31+
`true_negatives`, `false_positives` and `false_negatives`. So every evaluator should create its needed variables and append to main_program
32+
33+
The initialization of Evaluator should be responsible for:
34+
create metric states and append to the main_program
35+
"""
36+
pass
37+
38+
def _update_ops(self, input, label, **kwargs)
39+
"""
40+
Add mini-batch evaluator caculate operators to the main_program.
41+
Add increment operator to accumulate the metric states.
42+
"""
43+
44+
45+
def reset(self, executor, reset_program=None):
46+
"""
47+
Reset metric states at the begin of each pass/user specified batch number.
48+
Execute the reset_program to reset the states.
49+
"""
50+
51+
52+
def eval(self, executor, eval_program=None):
53+
"""
54+
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
55+
Execute the eval_program and return the result.
56+
"""
57+
return eval_result
58+
```

paddle/operators/accuracy_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class AccuracyOp : public framework::OperatorWithKernel {
3030
"Input (Label) of accuracy op should not be null.");
3131
PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
3232
"Output (Accuracy) of AccuracyOp should not be null.");
33+
PADDLE_ENFORCE(ctx->HasOutput("Correct"),
34+
"Output (Correct) of AccuracyOp should not be null.");
35+
PADDLE_ENFORCE(ctx->HasOutput("Total"),
36+
"Output (Total) of AccuracyOp should not be null.");
3337

3438
auto inference_dim = ctx->GetInputDim("Out");
3539
auto label_dim = ctx->GetInputDim("Label");
@@ -43,6 +47,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
4347
" the same as label.");
4448

4549
ctx->SetOutputDim("Accuracy", {1});
50+
ctx->SetOutputDim("Correct", {1});
51+
ctx->SetOutputDim("Total", {1});
4652
ctx->ShareLoD("Out", /*->*/ "Accuracy");
4753
}
4854

@@ -66,6 +72,8 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
6672
AddInput("Label", "Label of the training data");
6773
// TODO(typhoonzero): AddInput("Weight", ...
6874
AddOutput("Accuracy", "The accuracy of current batch");
75+
AddOutput("Correct", "The correct samples count of current batch");
76+
AddOutput("Total", "The samples count of current batch");
6977

7078
AddComment(R"DOC(
7179
Accuracy Operator.

paddle/operators/accuracy_op.cu

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ using platform::PADDLE_CUDA_NUM_THREADS;
2424
template <int BlockSize>
2525
__global__ void AccuracyCudaKernel(const int N, const int D,
2626
const int64_t* Xdata,
27-
const int64_t* labeldata, float* accuracy) {
27+
const int64_t* labeldata, int* correct_data,
28+
float* accuracy) {
2829
int count = 0;
2930
__shared__ int total[BlockSize];
3031

@@ -43,6 +44,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
4344
// reduce the count with init value 0, and output accuracy.
4445
int result = thrust::reduce(thrust::device, total, total + BlockSize, 0);
4546
if (threadIdx.x == 0) {
47+
*correct_data = result;
4648
*accuracy = static_cast<float>(result) / static_cast<float>(N);
4749
}
4850
}
@@ -56,31 +58,48 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
5658
auto* inference = ctx.Input<Tensor>("Out");
5759
auto* indices = ctx.Input<Tensor>("Indices");
5860
auto* label = ctx.Input<Tensor>("Label");
61+
5962
auto* accuracy = ctx.Output<Tensor>("Accuracy");
63+
auto* correct = ctx.Output<Tensor>("Correct");
64+
auto* total = ctx.Output<Tensor>("Total");
6065
// FIXME(typhoonzero): only support indices currently
6166
// if add support for output values, how to detect the data type?
6267
const int64_t* indices_data = indices->data<int64_t>();
6368
const int64_t* label_data = label->data<int64_t>();
69+
70+
int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
71+
int* total_data = total->mutable_data<int>(ctx.GetPlace());
6472
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
6573

66-
size_t num_samples = inference->dims()[0];
74+
int num_samples = static_cast<int>(inference->dims()[0]);
6775
size_t infer_width = inference->dims()[1];
6876
PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
77+
// cudaMemset((void**)&correct_data, 0, sizeof(float));
6978

7079
if (num_samples == 0) {
7180
return;
7281
}
82+
cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice);
7383

7484
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
7585
1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
76-
num_samples, infer_width, indices_data, label_data, accuracy_data);
86+
num_samples, infer_width, indices_data, label_data, correct_data,
87+
accuracy_data);
88+
89+
int d_num_samples, d_num_correct;
90+
float d_accuracy;
91+
cudaMemcpy(&d_num_correct, correct_data, sizeof(int),
92+
cudaMemcpyDeviceToHost);
93+
cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost);
94+
cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float),
95+
cudaMemcpyDeviceToHost);
7796
}
7897
};
7998

8099
} // namespace operators
81100
} // namespace paddle
82101

83-
// FIXME(typhoonzero): types of T is for infernece data.
84-
// label data is always int
102+
// FIXME(typhoonzero): types of T is for inference data.
103+
// label data is always int64
85104
REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
86105
paddle::operators::AccuracyOpCUDAKernel<double>);

paddle/operators/accuracy_op.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ class AccuracyKernel : public framework::OpKernel<T> {
2929
auto* indices = ctx.Input<Tensor>("Indices");
3030
auto* label = ctx.Input<Tensor>("Label");
3131
auto* accuracy = ctx.Output<Tensor>("Accuracy");
32+
auto* correct = ctx.Output<Tensor>("Correct");
33+
auto* total = ctx.Output<Tensor>("Total");
3234

35+
int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
36+
int* total_data = total->mutable_data<int>(ctx.GetPlace());
3337
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
3438

3539
const int64_t* indices_data = indices->data<int64_t>();
@@ -55,7 +59,8 @@ class AccuracyKernel : public framework::OpKernel<T> {
5559
}
5660
}
5761

58-
// FIXME(typhoonzero): we don't accumulate the accuracy for now.
62+
*correct_data = num_correct;
63+
*total_data = num_samples;
5964
*accuracy_data =
6065
static_cast<float>(num_correct) / static_cast<float>(num_samples);
6166
}

paddle/operators/elementwise_add_op.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
3434
elementwise_add_grad, ops::ElementwiseOpGrad);
3535
REGISTER_OP_CPU_KERNEL(
3636
elementwise_add,
37-
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, float>);
37+
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, float>,
38+
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, double>,
39+
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int>,
40+
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int64_t>);
3841
REGISTER_OP_CPU_KERNEL(
3942
elementwise_add_grad,
40-
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, float>);
43+
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, float>,
44+
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, double>,
45+
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int>,
46+
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int64_t>);

paddle/operators/elementwise_div_op.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ REGISTER_OP(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker,
3535
elementwise_div_grad, ops::ElementwiseOpGrad);
3636
REGISTER_OP_CPU_KERNEL(
3737
elementwise_div,
38-
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, float>);
38+
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, float>,
39+
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, double>,
40+
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int>,
41+
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int64_t>);
3942
REGISTER_OP_CPU_KERNEL(
4043
elementwise_div_grad,
41-
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, float>);
44+
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, float>,
45+
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, double>,
46+
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int>,
47+
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int64_t>);

paddle/operators/elementwise_mul_op.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
3737
REGISTER_OP_CPU_KERNEL(
3838
elementwise_mul,
3939
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>,
40-
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>);
40+
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>,
41+
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int>,
42+
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int64_t>);
4143
REGISTER_OP_CPU_KERNEL(
4244
elementwise_mul_grad,
4345
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>,
44-
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>);
46+
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>,
47+
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int>,
48+
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int64_t>);

paddle/operators/elementwise_sub_op.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ REGISTER_OP(elementwise_sub, ops::ElementwiseOp, ops::ElementwiseSubOpMaker,
3434
elementwise_sub_grad, ops::ElementwiseOpGrad);
3535
REGISTER_OP_CPU_KERNEL(
3636
elementwise_sub,
37-
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, float>);
37+
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, float>,
38+
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, double>,
39+
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int>,
40+
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int64_t>);
3841
REGISTER_OP_CPU_KERNEL(
3942
elementwise_sub_grad,
40-
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, float>);
43+
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, float>,
44+
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, double>,
45+
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int>,
46+
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int64_t>);

0 commit comments

Comments
 (0)