Skip to content

Commit 12858ba

Browse files
committed
"relauch ci"
1 parent fc117ec commit 12858ba

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

paddle/operators/accuracy_op.cu

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ using platform::PADDLE_CUDA_NUM_THREADS;
2424
template <int BlockSize>
2525
__global__ void AccuracyCudaKernel(const int N, const int D,
2626
const int64_t* Xdata,
27-
const int64_t* labeldata, float* accuracy) {
27+
const int64_t* labeldata, int* correct_data,
28+
float* accuracy) {
2829
int count = 0;
2930
__shared__ int total[BlockSize];
3031

@@ -43,6 +44,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
4344
// reduce the count with init value 0, and output accuracy.
4445
int result = thrust::reduce(thrust::device, total, total + BlockSize, 0);
4546
if (threadIdx.x == 0) {
47+
*correct_data = result;
4648
*accuracy = static_cast<float>(result) / static_cast<float>(N);
4749
}
4850
}
@@ -56,31 +58,48 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
5658
auto* inference = ctx.Input<Tensor>("Out");
5759
auto* indices = ctx.Input<Tensor>("Indices");
5860
auto* label = ctx.Input<Tensor>("Label");
61+
5962
auto* accuracy = ctx.Output<Tensor>("Accuracy");
63+
auto* correct = ctx.Output<Tensor>("Correct");
64+
auto* total = ctx.Output<Tensor>("Total");
6065
// FIXME(typhoonzero): only indices are supported currently
6166
// if support for output values is added, how should the data type be detected?
6267
const int64_t* indices_data = indices->data<int64_t>();
6368
const int64_t* label_data = label->data<int64_t>();
69+
70+
int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
71+
int* total_data = total->mutable_data<int>(ctx.GetPlace());
6472
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
6573

66-
size_t num_samples = inference->dims()[0];
74+
int num_samples = static_cast<int>(inference->dims()[0]);
6775
size_t infer_width = inference->dims()[1];
6876
PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
77+
// cudaMemset((void**)&correct_data, 0, sizeof(float));
6978

7079
if (num_samples == 0) {
7180
return;
7281
}
82+
cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice);
7383

7484
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
7585
1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
76-
num_samples, infer_width, indices_data, label_data, accuracy_data);
86+
num_samples, infer_width, indices_data, label_data, correct_data,
87+
accuracy_data);
88+
89+
int d_num_samples, d_num_correct;
90+
float d_accuracy;
91+
cudaMemcpy(&d_num_correct, correct_data, sizeof(int),
92+
cudaMemcpyDeviceToHost);
93+
cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost);
94+
cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float),
95+
cudaMemcpyDeviceToHost);
7796
}
7897
};
7998

8099
} // namespace operators
81100
} // namespace paddle
82101

83-
// FIXME(typhoonzero): types of T is for infernece data.
84-
// label data is always int
102+
// FIXME(typhoonzero): the type T is for the inference data.
103+
// label data is always int64
85104
REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
86105
paddle::operators::AccuracyOpCUDAKernel<double>);

python/paddle/v2/framework/evaluator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def reset(self, executor, reset_program=None):
4343
"""
4444
Clear metric states at the beginning of each pass or user-specified batch
4545
"""
46-
if program == None:
46+
if reset_program == None:
4747
reset_program = Program()
4848
else:
4949
reset_program = program
@@ -147,9 +147,9 @@ def _update_ops(self, input, label, k=1, **kwargs):
147147

148148
return acc_out
149149

150-
def eval(self, executor, program=None):
151-
if program != None:
152-
eval_program = program
150+
def eval(self, executor, eval_program=None):
151+
if eval_program != None:
152+
eval_program = eval_program
153153
else:
154154
eval_program = Program()
155155
block = eval_program.global_block()

0 commit comments

Comments
 (0)