Skip to content

Commit 254f873

Browse files
authored
fix WhereGradCUDAKernel (#74332)
1 parent 4cebc8c commit 254f873

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

paddle/phi/kernels/gpu/where_grad_kernel.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ namespace phi {
2222
template <typename T, typename IndexT>
2323
__global__ void WhereGradCUDAKernel(
2424
const IndexT N, const T* dout, const bool* cond, T* dx, T* dy) {
25-
IndexT idx = blockDim.x * blockIdx.x + threadIdx.x;
26-
for (; idx < N; idx += blockDim.x * gridDim.x) {
25+
CUDA_KERNEL_LOOP_TYPE(idx, N, IndexT) {
2726
if (dx != nullptr) {
2827
dx[idx] = cond[idx] ? dout[idx] : static_cast<T>(0.);
2928
}

0 commit comments

Comments
 (0)