Skip to content

Commit e97b898

Browse files
authored
"fix accuracy kernel bug" (#5673)
* "fix accuracy kernel bug" * "relaunch ci"
1 parent f95c291 commit e97b898

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

paddle/operators/accuracy_op.cu

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616
#include <thrust/reduce.h>
1717
#include "paddle/operators/accuracy_op.h"
1818
#include "paddle/platform/cuda_helper.h"
19+
#include "paddle/platform/gpu_info.h"
1920

2021
namespace paddle {
2122
namespace operators {
@@ -73,26 +74,28 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
7374

7475
int num_samples = static_cast<int>(inference->dims()[0]);
7576
size_t infer_width = inference->dims()[1];
76-
PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
77-
// cudaMemset((void**)&correct_data, 0, sizeof(float));
77+
auto stream = ctx.cuda_device_context().stream();
78+
platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream);
7879

7980
if (num_samples == 0) {
8081
return;
8182
}
82-
cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice);
83+
platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
84+
cudaMemcpyHostToDevice, stream);
8385

84-
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
85-
1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
86+
AccuracyCudaKernel<
87+
PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
8688
num_samples, infer_width, indices_data, label_data, correct_data,
8789
accuracy_data);
8890

8991
int d_num_samples, d_num_correct;
9092
float d_accuracy;
91-
cudaMemcpy(&d_num_correct, correct_data, sizeof(int),
92-
cudaMemcpyDeviceToHost);
93-
cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost);
94-
cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float),
95-
cudaMemcpyDeviceToHost);
93+
platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int),
94+
cudaMemcpyDeviceToHost, stream);
95+
platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int),
96+
cudaMemcpyDeviceToHost, stream);
97+
platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float),
98+
cudaMemcpyDeviceToHost, stream);
9699
}
97100
};
98101

paddle/platform/gpu_info.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,5 +109,10 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
109109
cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream),
110110
"cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer");
111111
}
112+
113+
void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
114+
PADDLE_ENFORCE(cudaMemsetAsync(dst, value, count, stream),
115+
"cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync");
116+
}
112117
} // namespace platform
113118
} // namespace paddle

paddle/platform/gpu_info.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ void GpuMemcpySync(void *dst, const void *src, size_t count,
6060
void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
6161
size_t count, cudaStream_t stream);
6262

63+
//! Set memory dst with value count size asynchronously
64+
void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);
65+
6366
} // namespace platform
6467
} // namespace paddle
6568

0 commit comments

Comments (0)