Skip to content

Commit 954d680

Browse files
committed
fix test_parallel_do.py
1 parent 787bd1f commit 954d680

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

paddle/fluid/platform/profiler.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ Event::Event(EventType type, std::string name, uint32_t thread_id,
110110
has_cuda_ = dev_ctx ? platform::is_gpu_place(dev_ctx->GetPlace()) : false;
111111
if (has_cuda_) {
112112
auto* cuda_dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctx);
113+
PADDLE_ENFORCE(cudaSetDevice(
114+
boost::get<platform::CUDAPlace>(cuda_dev_ctx->GetPlace()).device));
113115
PADDLE_ENFORCE(cudaGetDevice(&device_));
114116
PADDLE_ENFORCE(cudaEventCreate(&event_));
115117
auto stream = cuda_dev_ctx->stream();

0 commit comments

Comments
 (0)