
Commit 8b11dee

Hide latency of bias and gate-loading (#16847)
This is realised by loading the bias and gate values into registers before the dot product is computed, effectively batching those loads together with the dot product. Since many threads are still alive at this point, the warp scheduler has enough work available to hide the cost of the two additional float loads.
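To make the mechanism concrete, below is a minimal, self-contained CUDA sketch of the same pattern applied to a plain matrix-vector product. It is not the `mul_mat_vec_q` kernel from this commit; the kernel name, sizes, and host code are invented for illustration. The only point it demonstrates is the ordering: the bias load is issued into a register before the dot-product loop, so its memory latency overlaps with the arithmetic instead of stalling the final write after the reduction.

```cuda
// Hypothetical sketch (not part of this commit): prefetch a per-row bias into
// a register before the dot-product loop so the load latency is hidden.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

static __global__ void matvec_bias_prefetch(
        const float * __restrict__ x,    // matrix, one row per output element
        const float * __restrict__ v,    // input vector
        const float * __restrict__ bias, // per-row bias
        float * __restrict__ dst, const int ncols) {
    const int row = blockIdx.x;

    // Issue the bias load early, into a register. Nothing depends on it until
    // after the reduction, so the load is in flight while the warp executes
    // the dot product below instead of stalling the final store.
    float row_bias = 0.0f;
    if (threadIdx.x == 0) {
        row_bias = bias[row];
    }

    // Partial dot product per thread.
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum += x[row*ncols + col] * v[col];
    }

    // Warp-level reduction (assumes blockDim.x == 32 for brevity).
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }

    if (threadIdx.x == 0) {
        dst[row] = sum + row_bias; // bias is already resident in a register
    }
}

int main() {
    const int nrows = 4, ncols = 256;
    std::vector<float> hx(nrows*ncols, 1.0f), hv(ncols, 2.0f), hb(nrows, 0.5f), hy(nrows);
    float *dx, *dv, *db, *dy;
    cudaMalloc(&dx, hx.size()*sizeof(float));
    cudaMalloc(&dv, hv.size()*sizeof(float));
    cudaMalloc(&db, hb.size()*sizeof(float));
    cudaMalloc(&dy, hy.size()*sizeof(float));
    cudaMemcpy(dx, hx.data(), hx.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dv, hv.data(), hv.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb.data(), hb.size()*sizeof(float), cudaMemcpyHostToDevice);
    matvec_bias_prefetch<<<nrows, 32>>>(dx, dv, db, dy, ncols);
    cudaMemcpy(hy.data(), dy, hy.size()*sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < nrows; ++i) {
        printf("row %d: %.1f\n", i, hy[i]); // expect 256*1*2 + 0.5 = 512.5
    }
    cudaFree(dx); cudaFree(dv); cudaFree(db); cudaFree(dy);
    return 0;
}
```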
1 parent: b9ce940 · commit: 8b11dee

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 18 additions, 2 deletions
@@ -190,12 +190,28 @@ static __global__ void mul_mat_vec_q(
 
     const uint32_t channel_bias = ids ? channel_x : channel_dst;
 
+    float x_biases[ncols_dst][rows_per_cuda_block]    = { { 0.0f } };
+    float gate_biases[ncols_dst][rows_per_cuda_block] = { { 0.0f } };
     if constexpr (has_fusion) {
         if (use_bias) {
             x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+            // 1. Hide latency by prefetching bias and gate here
+            // 2. load only on threads that won't die after partial sum calculation
+            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
+                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+                for (int j = 0; j < ncols_dst; ++j) {
+                    x_biases[j][threadIdx.x] = x_bias[j * stride_col_dst + threadIdx.x];
+                }
+            }
         }
         if (use_gate_bias) {
             gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
+                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+                for (int j = 0; j < ncols_dst; ++j) {
+                    gate_biases[j][threadIdx.x] = gate_bias[j * stride_col_dst + threadIdx.x];
+                }
+            }
         }
     }
 
@@ -283,12 +299,12 @@ static __global__ void mul_mat_vec_q(
         float result = tmp[j][threadIdx.x];
         if constexpr (has_fusion) {
             if (use_bias) {
-                result += x_bias[j*stride_col_dst + threadIdx.x];
+                result += x_biases[j][threadIdx.x];
             }
             if (use_gate) {
                 float gate_value = tmp_gate[j][threadIdx.x];
                 if (use_gate_bias) {
-                    gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
+                    gate_value += gate_biases[j][threadIdx.x];
                 }
                 switch (active_glu) {
                     case GGML_GLU_OP_SWIGLU:
