Commit 9e87786

Move shared memory inside code executed for >1 warp
1 parent 7c3c454 commit 9e87786


ggml/src/ggml-cuda/reduce_rows.cuh

Lines changed: 8 additions & 8 deletions
@@ -30,15 +30,15 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r

     // sum up partial sums
     sum = warp_reduce_sum(sum);
-    __shared__ float s_sum[32];
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-    if (lane_id == 0) {
-        s_sum[warp_id] = sum;
-    }
-    __syncthreads();
-    sum = 0.0f;
     if constexpr (width > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = sum;
+        }
+        __syncthreads();
+        sum = 0.0f;
         static_assert((width <= 1024) && (width % WARP_SIZE) == 0, "unexpected block_size");
         if (lane_id < (blockDim.x / WARP_SIZE)) {
             sum = s_sum[lane_id];
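For context, below is a minimal, self-contained sketch of the reduction pattern this commit touches. It is not the ggml source: the kernel reduce_rows_demo, the re-implemented warp_reduce_sum, and the host setup are illustrative assumptions (compile with -std=c++17 for if constexpr). The point it shows is that the shared-memory stage and its __syncthreads barrier are only needed when the block holds more than one warp, which is why the patch moves that code inside the if constexpr (width > WARP_SIZE) branch.

// sketch only: names and launch parameters are hypothetical, not taken from ggml
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Butterfly reduction: every lane of the warp ends up holding the warp-wide sum.
static __device__ float warp_reduce_sum(float x) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);
    }
    return x;
}

// One block per row; width is the block size, each thread accumulates a strided partial sum.
template <int width>
static __global__ void reduce_rows_demo(const float * __restrict__ x, float * __restrict__ dst, int ncols) {
    const int row = blockIdx.x;
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += width) {
        sum += x[row * ncols + col];
    }

    // The intra-warp reduction is always needed.
    sum = warp_reduce_sum(sum);

    // The shared-memory stage only matters when the block spans more than one warp,
    // so it is compiled (and the barrier executed) only in that case.
    if constexpr (width > WARP_SIZE) {
        __shared__ float s_sum[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = sum;
        }
        __syncthreads();
        sum = 0.0f;
        if (lane_id < (blockDim.x / WARP_SIZE)) {
            sum = s_sum[lane_id];
        }
        sum = warp_reduce_sum(sum);
    }

    if (threadIdx.x == 0) {
        dst[row] = sum;
    }
}

int main() {
    const int nrows = 4, ncols = 1000;
    std::vector<float> h_x(nrows * ncols, 1.0f);
    float *d_x, *d_dst;
    cudaMalloc(&d_x, nrows * ncols * sizeof(float));
    cudaMalloc(&d_dst, nrows * sizeof(float));
    cudaMemcpy(d_x, h_x.data(), nrows * ncols * sizeof(float), cudaMemcpyHostToDevice);

    constexpr int width = 256;  // block size > WARP_SIZE, so the shared-memory path is compiled in
    reduce_rows_demo<width><<<nrows, width>>>(d_x, d_dst, ncols);

    float h_dst[nrows];
    cudaMemcpy(h_dst, d_dst, nrows * sizeof(float), cudaMemcpyDeviceToHost);
    printf("row 0 sum = %g (expected %d)\n", h_dst[0], ncols);
    cudaFree(d_x);
    cudaFree(d_dst);
    return 0;
}

Launching with width <= WARP_SIZE (e.g. reduce_rows_demo<32>) skips the shared-memory branch entirely, which is the single-warp case the commit message refers to.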
