@@ -24,7 +24,8 @@ using platform::PADDLE_CUDA_NUM_THREADS;
24
24
template <int BlockSize>
25
25
__global__ void AccuracyCudaKernel (const int N, const int D,
26
26
const int64_t * Xdata,
27
- const int64_t * labeldata, float * accuracy) {
27
+ const int64_t * labeldata, int * correct_data,
28
+ float * accuracy) {
28
29
int count = 0 ;
29
30
__shared__ int total[BlockSize];
30
31
@@ -43,6 +44,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
43
44
// reduce the count with init value 0, and output accuracy.
44
45
int result = thrust::reduce (thrust::device, total, total + BlockSize, 0 );
45
46
if (threadIdx .x == 0 ) {
47
+ *correct_data = result;
46
48
*accuracy = static_cast <float >(result) / static_cast <float >(N);
47
49
}
48
50
}
@@ -56,31 +58,48 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
56
58
auto * inference = ctx.Input <Tensor>(" Out" );
57
59
auto * indices = ctx.Input <Tensor>(" Indices" );
58
60
auto * label = ctx.Input <Tensor>(" Label" );
61
+
59
62
auto * accuracy = ctx.Output <Tensor>(" Accuracy" );
63
+ auto * correct = ctx.Output <Tensor>(" Correct" );
64
+ auto * total = ctx.Output <Tensor>(" Total" );
60
65
// FIXME(typhoonzero): only support indices currently
61
66
// if add support for output values, how to detect the data type?
62
67
const int64_t * indices_data = indices->data <int64_t >();
63
68
const int64_t * label_data = label->data <int64_t >();
69
+
70
+ int * correct_data = correct->mutable_data <int >(ctx.GetPlace ());
71
+ int * total_data = total->mutable_data <int >(ctx.GetPlace ());
64
72
float * accuracy_data = accuracy->mutable_data <float >(ctx.GetPlace ());
65
73
66
- size_t num_samples = inference->dims ()[0 ];
74
+ int num_samples = static_cast < int >( inference->dims ()[0 ]) ;
67
75
size_t infer_width = inference->dims ()[1 ];
68
76
PADDLE_ENFORCE (cudaMemset (accuracy_data, 0 , sizeof (float )));
77
+ // cudaMemset((void**)&correct_data, 0, sizeof(float));
69
78
70
79
if (num_samples == 0 ) {
71
80
return ;
72
81
}
82
+ cudaMemcpy (total_data, &num_samples, sizeof (int ), cudaMemcpyHostToDevice);
73
83
74
84
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
75
85
1 , PADDLE_CUDA_NUM_THREADS, 0 , ctx.cuda_device_context().stream()>>> (
76
- num_samples, infer_width, indices_data, label_data, accuracy_data);
86
+ num_samples, infer_width, indices_data, label_data, correct_data,
87
+ accuracy_data);
88
+
89
+ int d_num_samples, d_num_correct;
90
+ float d_accuracy;
91
+ cudaMemcpy (&d_num_correct, correct_data, sizeof (int ),
92
+ cudaMemcpyDeviceToHost);
93
+ cudaMemcpy (&d_num_samples, total_data, sizeof (int ), cudaMemcpyDeviceToHost);
94
+ cudaMemcpy (&d_accuracy, accuracy_data, sizeof (float ),
95
+ cudaMemcpyDeviceToHost);
77
96
}
78
97
};
79
98
80
99
} // namespace operators
81
100
} // namespace paddle
82
101
83
// FIXME(typhoonzero): types of T is for inference data.
// label data is always int64
85
104
// Register the GPU kernels for the "accuracy" operator. T parameterizes the
// inference-data type only (float/double); see the FIXME above — label data
// is always int64.
REGISTER_OP_GPU_KERNEL(accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
                       paddle::operators::AccuracyOpCUDAKernel<double>);
0 commit comments