Skip to content

Commit af6465a

Browse files
Apply suggestions from code review
Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 44987f7 commit af6465a

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ static __global__ void mul_mat_vec_q(
199199
// 2. load only on threads that won't die after partial sum calculation
200200
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
201201
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
202+
#pragma unroll
202203
for (int j = 0; j < ncols_dst; ++j) {
203204
x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
204205
}
@@ -208,6 +209,7 @@ static __global__ void mul_mat_vec_q(
208209
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
209210
if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
210211
(rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
212+
#pragma unroll
211213
for (int j = 0; j < ncols_dst; ++j) {
212214
gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
213215
}

0 commit comments

Comments
 (0)