@@ -16,6 +16,7 @@ limitations under the License. */
16
16
#include < thrust/reduce.h>
17
17
#include " paddle/operators/accuracy_op.h"
18
18
#include " paddle/platform/cuda_helper.h"
19
+ #include " paddle/platform/gpu_info.h"
19
20
20
21
namespace paddle {
21
22
namespace operators {
@@ -73,26 +74,28 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
73
74
74
75
int num_samples = static_cast <int >(inference->dims ()[0 ]);
75
76
size_t infer_width = inference->dims ()[1 ];
76
- PADDLE_ENFORCE ( cudaMemset (accuracy_data, 0 , sizeof ( float )) );
77
- // cudaMemset((void**)&correct_data , 0, sizeof(float));
77
+ auto stream = ctx. cuda_device_context (). stream ( );
78
+ platform::GpuMemsetAsync (accuracy_data , 0 , sizeof (float ), stream );
78
79
79
80
if (num_samples == 0 ) {
80
81
return ;
81
82
}
82
- cudaMemcpy (total_data, &num_samples, sizeof (int ), cudaMemcpyHostToDevice);
83
+ platform::GpuMemcpyAsync (total_data, &num_samples, sizeof (int ),
84
+ cudaMemcpyHostToDevice, stream);
83
85
84
- AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS> <<<
85
- 1 , PADDLE_CUDA_NUM_THREADS, 0 , ctx.cuda_device_context(). stream() >>> (
86
+ AccuracyCudaKernel<
87
+ PADDLE_CUDA_NUM_THREADS> <<< 1 , PADDLE_CUDA_NUM_THREADS, 0 , stream>>> (
86
88
num_samples, infer_width, indices_data, label_data, correct_data,
87
89
accuracy_data);
88
90
89
91
int d_num_samples, d_num_correct;
90
92
float d_accuracy;
91
- cudaMemcpy (&d_num_correct, correct_data, sizeof (int ),
92
- cudaMemcpyDeviceToHost);
93
- cudaMemcpy (&d_num_samples, total_data, sizeof (int ), cudaMemcpyDeviceToHost);
94
- cudaMemcpy (&d_accuracy, accuracy_data, sizeof (float ),
95
- cudaMemcpyDeviceToHost);
93
+ platform::GpuMemcpyAsync (&d_num_correct, correct_data, sizeof (int ),
94
+ cudaMemcpyDeviceToHost, stream);
95
+ platform::GpuMemcpyAsync (&d_num_samples, total_data, sizeof (int ),
96
+ cudaMemcpyDeviceToHost, stream);
97
+ platform::GpuMemcpyAsync (&d_accuracy, accuracy_data, sizeof (float ),
98
+ cudaMemcpyDeviceToHost, stream);
96
99
}
97
100
};
98
101
0 commit comments