Skip to content

Commit 61cae0d

Browse files
authored
[cherry-pick] Fix a log_softmax bug: the op's input was modified to 'nan' (#32937) (#33436)
While running the op benchmark, we found that when the input data size is below a certain threshold, the input passed to the Python-side log_softmax API is overwritten with NaN after the computation, even though the output itself is correct. Cherry-picked from #32937.
1 parent 8461ab1 commit 61cae0d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddle/fluid/operators/log_softmax_op.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ __global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src,
104104
#pragma unroll
105105
for (int it = 0; it < warp_iter; ++it) {
106106
int element_index = thread_in_warp_idx + it * kernel_warp_size;
107-
if (element_index < element_count) {
107+
if (element_index < effective_element_count) {
108108
dst[batch_id * element_count + element_index] =
109109
static_cast<T>(elements[it] - max_value - sum);
110110
} else {
@@ -226,7 +226,7 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
226226
#pragma unroll
227227
for (int iter = 0; iter < warp_iter; ++iter) {
228228
int element_index = thread_in_warp_idx + iter * kernel_warp_size;
229-
if (element_index < element_count) {
229+
if (element_index < effective_element_count) {
230230
grad_input[batch_id * element_count + element_index] = static_cast<T>(
231231
(grad_output_register[iter] - std::exp(output_register[iter]) * sum));
232232
}

0 commit comments

Comments
 (0)