Commit d3dc9dd

CUDA: Remove unneeded bias/gate dims in fused mmvq (#16858)

* CUDA: Remove unneeded bias/gate dims in fused mmvq

  Pointed out [here](#16847 (comment)) that only a single value is needed per target column per thread.

* Apply suggestions from code review

  Co-authored-by: Johannes Gäßler <[email protected]>

* Fix "Error 991-D: extra braces are nonstandard" during compilation

---------

Co-authored-by: Johannes Gäßler <[email protected]>
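The reasoning in the first bullet, as a minimal CUDA sketch (kernel name, parameters, and launch shape here are illustrative, not the actual mul_mat_vec_q signature): each thread with a given threadIdx.x owns exactly one output row of the block, so the only bias values it ever reads belong to that one row; one slot per destination column per thread is therefore enough.

// Hypothetical standalone sketch of the per-thread indexing argument.
// Thread t (t = threadIdx.x) handles output row row0 + t, so the row
// dimension of a per-thread bias buffer is redundant.
template <int ncols_dst, int rows_per_cuda_block>
__global__ void bias_per_column_sketch(const float * __restrict__ bias,
                                       float       * __restrict__ dst,
                                       const int stride_col_dst,
                                       const int row0) {
    // One value per destination column suffices; the row is fixed per thread.
    float biases[ncols_dst] = { 0.0f };

    if (threadIdx.x < rows_per_cuda_block) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            biases[j] = bias[j * stride_col_dst + row0 + threadIdx.x];
        }
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            dst[j * stride_col_dst + row0 + threadIdx.x] += biases[j];
        }
    }
}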
1 parent bea0452 commit d3dc9dd

File tree

1 file changed: +8, -6 lines

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 8 additions & 6 deletions
@@ -190,26 +190,28 @@ static __global__ void mul_mat_vec_q(
 
     const uint32_t channel_bias = ids ? channel_x : channel_dst;
 
-    float x_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
-    float gate_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
+    float x_biases[ncols_dst] = { 0.0f };
+    float gate_biases[ncols_dst] = { 0.0f };
     if constexpr (has_fusion) {
         if (use_bias) {
             x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
             // 1. Hide latency by prefetching bias and gate here
             // 2. load only on threads that won't die after partial sum calculation
             if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                 (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+#pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
-                    x_biases[j][threadIdx.x] = x_bias[j * stride_col_dst + threadIdx.x];
+                    x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
                 }
             }
         }
         if (use_gate_bias) {
             gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
             if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                 (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+#pragma unroll
                 for (int j = 0; j < ncols_dst; ++j) {
-                    gate_biases[j][threadIdx.x] = gate_bias[j * stride_col_dst + threadIdx.x];
+                    gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
                 }
             }
         }
@@ -299,12 +301,12 @@ static __global__ void mul_mat_vec_q(
             float result = tmp[j][threadIdx.x];
             if constexpr (has_fusion) {
                 if (use_bias) {
-                    result += x_biases[j][threadIdx.x];
+                    result += x_biases[j];
                 }
                 if (use_gate) {
                     float gate_value = tmp_gate[j][threadIdx.x];
                     if (use_gate_bias) {
-                        gate_value += gate_biases[j][threadIdx.x];
+                        gate_value += gate_biases[j];
                     }
                     switch (active_glu) {
                         case GGML_GLU_OP_SWIGLU:
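The third bullet of the commit message concerns the initializer form: once the temporaries became one-dimensional, keeping the old doubly braced initializer would trip nvcc's front-end brace diagnostic. A minimal illustration, assuming a hypothetical kernel that is not part of the patch:

// Hypothetical snippet showing the diagnostic fixed in the last commit step.
__global__ void brace_sketch(float * out) {
    // float v[4] = { { 0.0f } };  // nested braces on a 1-D array: nvcc reports,
    //                             // as quoted in the commit message,
    //                             // "Error 991-D: extra braces are nonstandard"
    float v[4] = { 0.0f };         // standard form: element 0 is 0.0f, the rest are zero-initialized
    out[threadIdx.x & 3] = v[threadIdx.x & 3];
}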
