@@ -55,10 +55,10 @@ __global__ void __launch_bounds__(splitD, 2)
5555 const int stride_s0 = src0_nb1 / sizeof (float );
5656 const int stride_A = src3_nb1 / sizeof (float );
5757#pragma unroll
58- for (int j = 0 ; j < N; ++j )
58+ for (size_t n = 0 ; n < N; ++n )
5959 {
60- regA[j ] = A_block[threadIdx .x * stride_A + j ];
61- regs0[j ] = s0_block[threadIdx .x * stride_s0 + j ];
60+ regA[n ] = A_block[threadIdx .x * stride_A + n ];
61+ regs0[n ] = s0_block[threadIdx .x * stride_s0 + n ];
6262 }
6363#endif
6464
@@ -80,11 +80,11 @@ __global__ void __launch_bounds__(splitD, 2)
8080
8181 float sumf = 0 .0f ;
8282#pragma unroll
83- for (int j = 0 ; j < N; j ++)
83+ for (size_t n = 0 ; n < N; n ++)
8484 {
85- float state = regs0[j ] * expf (dt_soft_plus * regA[j ]) + smemB[j ] * x_dt;
86- sumf += state * smemC[j ];
87- regs0[j ] = state;
85+ float state = regs0[n ] * expf (dt_soft_plus * regA[n ]) + smemB[n ] * x_dt;
86+ sumf += state * smemC[n ];
87+ regs0[n ] = state;
8888 }
8989 y_block[i * stride_y + threadIdx .x ] = sumf;
9090 }
@@ -94,9 +94,9 @@ __global__ void __launch_bounds__(splitD, 2)
9494#else
9595 const int stride_s = stride_s0;
9696#pragma unroll
97- for (int j = 0 ; j < N; ++j )
97+ for (size_t n = 0 ; n < N; ++n )
9898 {
99- s_block[threadIdx .x * stride_s + j ] = regs0[j ];
99+ s_block[threadIdx .x * stride_s + n ] = regs0[n ];
100100 }
101101#endif
102102}
@@ -140,10 +140,10 @@ __global__ void __launch_bounds__(splitD, 2)
140140 const int stride_s0 = src0_nb1 / sizeof (float );
141141 const int stride_A = src3_nb1 / sizeof (float );
142142#pragma unroll
143- for (int j = 0 ; j < N; ++j )
143+ for (size_t n = 0 ; n < N; ++n )
144144 {
145- regA[j ] = A_block[threadIdx .x * stride_A + j ];
146- regs0[j ] = s0_block[threadIdx .x * stride_s0 + j ];
145+ regA[n ] = A_block[threadIdx .x * stride_A + n ];
146+ regs0[n ] = s0_block[threadIdx .x * stride_s0 + n ];
147147 }
148148#endif
149149
@@ -163,23 +163,23 @@ __global__ void __launch_bounds__(splitD, 2)
163163 float x_dt = x_block[threadIdx .x ] * dt_soft_plus;
164164 float sumf = 0 .0f ;
165165#pragma unroll
166- for (int j = 0 ; j < N; j ++)
166+ for (size_t n = 0 ; n < N; n ++)
167167 {
168- float state = regs0[j ] * expf (dt_soft_plus * regA[j ]) + smemB[j ] * x_dt;
169- sumf += state * smemC[j ];
170- regs0[j ] = state;
168+ float state = regs0[n ] * expf (dt_soft_plus * regA[n ]) + smemB[n ] * x_dt;
169+ sumf += state * smemC[n ];
170+ regs0[n ] = state;
171171 }
172172 y_block[threadIdx .x ] = sumf;
173173 }
174174
175175#ifdef USE_CUB
176176 BlockStoreS (block_store_tempS).Store (s_block, regs0);
177177#else
178- const int stride_s = s0 ;
178+ const int stride_s = stride_s0 ;
179179#pragma unroll
180- for (int j = 0 ; j < N; ++j )
180+ for (size_t n = 0 ; n < N; ++n )
181181 {
182- s_block[threadIdx .x * stride_s + j ] = regs0[j ];
182+ s_block[threadIdx .x * stride_s + n ] = regs0[n ];
183183 }
184184#endif
185185}
0 commit comments