Skip to content

Commit 58db07b

Browse files
qingqing01 authored and QiJune committed
Check errors for the cuda kernel calls. (#5436)
1 parent 6f43c93 commit 58db07b

File tree

4 files changed

+13
-5
lines changed

4 files changed

+13
-5
lines changed

paddle/framework/operator.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,9 @@ void OperatorWithKernel::Run(const Scope& scope,
440440
}
441441

442442
kernel_iter->second->Compute(ctx);
443+
444+
// throws errors if have.
445+
dev_ctx.Finish();
443446
}
444447

445448
} // namespace framework

paddle/operators/math/detail/lstm_gpu_kernel.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,6 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
244244
op, value, grad, frameSize, batchSize, active_node, active_gate,
245245
active_state);
246246
}
247-
248-
cudaStreamSynchronize(stream);
249-
// TODO(qingqing): Add cuda error check for each kernel.
250-
cudaError_t err = cudaGetLastError();
251-
PADDLE_ENFORCE(err, cudaGetErrorString(err));
252247
}
253248

254249
} // namespace detail

paddle/platform/device_context.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ void CUDADeviceContext::Wait() const {
124124
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
125125
}
126126

127+
void CUDADeviceContext::Finish() const {
128+
Wait();
129+
PADDLE_ENFORCE(cudaGetLastError());
130+
}
131+
127132
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
128133
return eigen_device_.get();
129134
}

paddle/platform/device_context.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class DeviceContext {
4646
DeviceType* GetEigenDevice() const;
4747

4848
virtual void Wait() const {}
49+
50+
virtual void Finish() const {}
4951
};
5052

5153
class CPUDeviceContext : public DeviceContext {
@@ -77,6 +79,9 @@ class CUDADeviceContext : public DeviceContext {
7779
/*! \brief Wait for all operations completion in the stream. */
7880
void Wait() const override;
7981

82+
/*! \brief Check potential errors for the cuda kernel calls. */
83+
void Finish() const override;
84+
8085
/*! \brief Return place in the device context. */
8186
Place GetPlace() const override;
8287

0 commit comments

Comments
 (0)