1 file changed: +2 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -118,7 +118,7 @@ kernel void kernel_soft_max(
118118 device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
119119
120120 // parallel max
121- float lmax = psrc0[tpitg[0 ]];
121+ float lmax = tpitg[ 0 ] < ne00 ? psrc0[tpitg[0 ]] : -INFINITY ;
122122 for (int i00 = tpitg[0 ] + ntg[0 ]; i00 < ne00; i00 += ntg[0 ]) {
123123 lmax = MAX (lmax, psrc0[i00]);
124124 }
@@ -158,7 +158,7 @@ kernel void kernel_soft_max_4(
158158 device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
159159
160160 // parallel max
161- float4 lmax4 = psrc4[tpitg[0 ]];
161+ float4 lmax4 = tpitg[ 0 ] < ne00/ 4 ? psrc4[tpitg[0 ]] : -INFINITY ;
162162 for (int i00 = tpitg[0 ] + ntg[0 ]; i00 < ne00/4 ; i00 += ntg[0 ]) {
163163 lmax4 = fmax (lmax4, psrc4[i00]);
164164 }
You can’t perform that action at this time.
0 commit comments