
Commit bdf4f0d

Authored by ikawrakow (Iwan Kawrakow)
Even more fused ops (ikawrakow#868)
* Fuse Q, K, V gemv+add
* More gemv+add fusing
* Faster copy when tensors are contiguous

  Relevant for storing data into the KV cache. I see ~1% speedup for fast models (Ling-mini-2.0, gpt-oss-20b, etc.)

* Cleanup
* Make sure the bias really is 1 row to use fusion

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent d894998 commit bdf4f0d
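
For orientation, here is a minimal CPU sketch of the operation being fused, written for this summary rather than taken from the patch (the real work happens in the quantized mmvq kernels changed below; the function and variable names here are illustrative). Unfused, the graph runs a matrix-vector product and then a separate ADD node for the bias; fused, the bias is folded into the gemv so the intermediate result never makes an extra round trip through global memory.

    // Illustrative reference only; names and layout are assumptions, not code from this commit.
    static void gemv_bias_ref(int nrows, int ncols, const float * W, const float * x, const float * b, float * y) {
        for (int r = 0; r < nrows; ++r) {
            float acc = b ? b[r] : 0.0f;              // one bias value per output element
            for (int c = 0; c < ncols; ++c) {
                acc += W[r*ncols + c] * x[c];         // plain row-major gemv
            }
            y[r] = acc;                               // unfused code would write y, then re-read it to add b
        }
    }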

File tree

6 files changed: +159 lines, -15 lines


ggml/src/ggml-cuda.cu

Lines changed: 52 additions & 5 deletions

@@ -2078,9 +2078,43 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
             src0->type, stream);
         CUDA_CHECK(cudaGetLastError());

-        ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
-                0, src0->ne[1], src1->ne[1], ne10_padded, stream);
-        CUDA_CHECK(cudaGetLastError());
+        // The code below handles the case when Q, K, V have a bias applied after the respective matrix multiplication.
+        // In that case the graph contains mul_mat(Q) -> mul_mat(K) -> mul_mat(V) -> add(Q) -> add(K) -> add(V)
+        if (cgraph && node_n + 5 < cgraph->n_nodes &&
+            cgraph->nodes[node_n+1]->op == GGML_OP_MUL_MAT &&
+            cgraph->nodes[node_n+2]->op == GGML_OP_MUL_MAT &&
+            ggml_is_quantized(cgraph->nodes[node_n+1]->src[0]->type) &&
+            ggml_is_quantized(cgraph->nodes[node_n+2]->src[0]->type) &&
+            cgraph->nodes[node_n+3]->op == GGML_OP_ADD &&
+            cgraph->nodes[node_n+4]->op == GGML_OP_ADD &&
+            cgraph->nodes[node_n+5]->op == GGML_OP_ADD &&
+            cgraph->nodes[node_n+0] == cgraph->nodes[node_n+3]->src[0] &&
+            cgraph->nodes[node_n+1] == cgraph->nodes[node_n+4]->src[0] &&
+            cgraph->nodes[node_n+2] == cgraph->nodes[node_n+5]->src[0]) {
+            for (int i = 0; i < 3; ++i) {
+                auto src0_i = cgraph->nodes[node_n+i]->src[0];
+                ggml_cuda_op_mul_mat_vec_q_biased(ctx, src0_i, src1, cgraph->nodes[node_n+i], cgraph->nodes[node_n+i+3]->src[1],
+                        (const char *)src0_i->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+i]->data,
+                        0, src0_i->ne[1], src1->ne[1], ne10_padded, stream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+            node_n += 5;
+        } else if (cgraph && node_n + 1 < cgraph->n_nodes &&
+                   cgraph->nodes[node_n+1]->op == GGML_OP_ADD &&
+                   dst == cgraph->nodes[node_n+1]->src[0] &&
+                   dst->ne[0] == cgraph->nodes[node_n+1]->src[1]->ne[0] &&
+                   cgraph->nodes[node_n+1]->src[1]->type == GGML_TYPE_F32 &&
+                   ggml_nrows(cgraph->nodes[node_n+1]->src[1]) == 1) {
+            // We have a bias applied after the matrix multiplication and we can fuse it
+            ggml_cuda_op_mul_mat_vec_q_biased(ctx, dst->src[0], src1, cgraph->nodes[node_n+1], cgraph->nodes[node_n+1]->src[1],
+                    (const char *)dst->src[0]->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+1]->data,
+                    0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+            ++node_n;
+        } else {
+            ggml_cuda_op_mul_mat_vec_q(ctx, src0, src1, dst, (const char *)src0->data, nullptr, src1_quantized.get(), (float *)dst->data,
+                    0, src0->ne[1], src1->ne[1], ne10_padded, stream);
+            CUDA_CHECK(cudaGetLastError());
+        }
     } else {
         quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1, ne10_padded, src0->type, stream);
         CUDA_CHECK(cudaGetLastError());

@@ -2101,8 +2135,21 @@ static int ggml_cuda_mul_mat_q(ggml_backend_cuda_context & ctx, const ggml_tenso
         if (dst->op != GGML_OP_MUL_MAT || dst->src[1] != src1 || !ggml_is_quantized(dst->src[0]->type)) break;
         if (!is_gemv && mmq_get_q8_1_ds_layout(src0->type) != mmq_get_q8_1_ds_layout(dst->src[0]->type)) break;
         if (is_gemv) {
-            ggml_cuda_op_mul_mat_vec_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
-                    (float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+            if (node_n + 1 < cgraph->n_nodes &&
+                cgraph->nodes[node_n+1]->op == GGML_OP_ADD &&
+                dst == cgraph->nodes[node_n+1]->src[0] &&
+                dst->ne[0] == cgraph->nodes[node_n+1]->src[1]->ne[0] &&
+                cgraph->nodes[node_n+1]->src[1]->type == GGML_TYPE_F32 &&
+                ggml_nrows(cgraph->nodes[node_n+1]->src[1]) == 1) {
+                // We have a bias applied after the matrix multiplication and we can fuse it
+                ggml_cuda_op_mul_mat_vec_q_biased(ctx, dst->src[0], src1, cgraph->nodes[node_n+1], cgraph->nodes[node_n+1]->src[1],
+                        (const char *)dst->src[0]->data, nullptr, src1_quantized.get(), (float *)cgraph->nodes[node_n+1]->data,
+                        0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+                ++node_n;
+            } else {
+                ggml_cuda_op_mul_mat_vec_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
+                        (float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
+            }
         } else {
             ggml_cuda_op_mul_mat_q(ctx, dst->src[0], src1, dst, (const char *)dst->src[0]->data, nullptr, src1_quantized.get(),
                     (float *)dst->data, 0, dst->src[0]->ne[1], src1->ne[1], ne10_padded, stream);
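
In short: when the backend sees three consecutive quantized MUL_MAT nodes (the Q, K and V projections) each consumed by one of the three ADD nodes that follow, it runs all three through ggml_cuda_op_mul_mat_vec_q_biased with the corresponding bias tensor and advances node_n past the ADD nodes, so the separate bias kernels are never launched. A single MUL_MAT followed by a 1-row F32 ADD of matching width is fused the same way. Anything else falls back to the unfused ggml_cuda_op_mul_mat_vec_q path.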

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 18 additions & 0 deletions

@@ -313,7 +313,25 @@ void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(&aux_dst, &aux_src, &aux_dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
 }

+static __global__ void k_fast_add(int64_t ne0, int64_t nelem, const float * x, const float * y, float * z) {
+    int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= nelem) {
+        return;
+    }
+    z[i] = x[i] + y[i % ne0];
+}
+
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    if (ggml_nrows(dst->src[1]) == 1 && dst->src[0]->ne[0] == dst->src[1]->ne[0] &&
+        dst->type == GGML_TYPE_F32 && dst->src[0]->type == GGML_TYPE_F32 && dst->src[1]->type == GGML_TYPE_F32 &&
+        ggml_are_same_shape(dst, dst->src[0]) && ggml_is_contiguous(dst)) {
+        constexpr int kBlockSize = 256;
+        auto nelem = ggml_nelements(dst);
+        int nblocks = (nelem + kBlockSize - 1)/kBlockSize;
+        k_fast_add<<<nblocks, kBlockSize, 0, ctx.stream()>>>(dst->ne[0], nelem,
+                (const float *)dst->src[0]->data, (const float *)dst->src[1]->data, (float *)dst->data);
+        return;
+    }
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
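
The fast path above is taken only when src1 is a single row with the same width as src0 and everything involved is contiguous F32; in that case the generic broadcast machinery is skipped. A CPU sketch of what k_fast_add computes, written for this summary rather than taken from the patch:

    #include <cstdint>

    // The one-row tensor y is broadcast across every row of x (hypothetical reference, not part of the commit).
    static void fast_add_ref(int64_t ne0, int64_t nrows, const float * x, const float * y, float * z) {
        for (int64_t r = 0; r < nrows; ++r) {
            for (int64_t c = 0; c < ne0; ++c) {
                z[r*ne0 + c] = x[r*ne0 + c] + y[c];   // equivalent to z[i] = x[i] + y[i % ne0] with i = r*ne0 + c
            }
        }
    }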

ggml/src/ggml-cuda/cpy.cu

Lines changed: 50 additions & 6 deletions

@@ -38,6 +38,25 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne
     cpy_1(cx + x_offset, cdst + dst_offset);
 }

+template <typename src_t, typename dst_t>
+static __global__ void cpy_flt_contiguous(const char * cx, char * cdst_direct, const int ne,
+        char ** cdst_indirect, int graph_cpynode_index) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    auto dst = (cdst_indirect != nullptr) ? (dst_t *)cdst_indirect[graph_cpynode_index] : (dst_t *)cdst_direct;
+    auto src = (const src_t *)cx;
+
+    if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+        dst[i] = __float2bfloat16(src[i]);
+    } else {
+        dst[i] = (dst_t)src[i];
+    }
+}
+
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *)(cdsti);

@@ -163,6 +182,16 @@ static void ggml_cpy_flt_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }

+template<typename src_t, typename dst_t>
+static void ggml_cpy_flt_contiguous_cuda(
+        const char * cx, char * cdst, const int ne,
+        cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,

@@ -404,6 +433,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;

+    bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
+
     char ** dest_ptrs_d = nullptr;
     int graph_cpynode_index = -1;
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)

@@ -429,11 +460,23 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             }
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, float>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, half>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {

@@ -505,6 +548,7 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }

 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+    bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         // Prioritize CUDA graph compatibility over direct memory copy optimization.
         // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.

@@ -514,11 +558,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
             return nullptr;
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<float, float>>;
+        return fast_cpy ? (void *)cpy_flt_contiguous<float, float> : (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
+        return fast_cpy ? (void *)cpy_flt_contiguous<float, nv_bfloat16> : (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<float, half>>;
+        return fast_cpy ? (void *)cpy_flt_contiguous<float, half> : (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
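
The contiguous path avoids the per-element stride arithmetic of cpy_flt: when source and destination are contiguous and have the same shape, the copy collapses to a flat loop over elements with at most a type conversion, which is the common case when storing K/V data into the cache. A CPU sketch of the idea, written for this summary (the real kernel additionally resolves the CUDA-graph indirection pointer):

    #include <cstdint>

    // Hypothetical reference, not code from the commit.
    template <typename src_t, typename dst_t>
    static void cpy_contiguous_ref(const src_t * src, dst_t * dst, int64_t ne) {
        for (int64_t i = 0; i < ne; ++i) {
            dst[i] = (dst_t) src[i];   // e.g. float -> half when the KV cache is stored as F16
        }
    }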

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 27 additions & 3 deletions

@@ -168,9 +168,10 @@ void ggml_cuda_op_mul_mat_vec_q_3D(
     GGML_UNUSED(src1_ddf_i);
 }

-void ggml_cuda_op_mul_mat_vec_q(
+void ggml_cuda_op_mul_mat_vec_q_biased(
     ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_tensor * bias,
+    const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream) {

@@ -180,14 +181,37 @@ void ggml_cuda_op_mul_mat_vec_q(

     const int64_t ne0 = dst->ne[0];

+    if (bias) {
+        if (bias->ne[0] != ne0) {
+            printf("Oops: bias %s is %ld x %ld x %ld x %ld, dst %s is %ld x %ld x %ld x %ld\n",
+                    bias->name, bias->ne[0], bias->ne[1], bias->ne[2], bias->ne[3],
+                    dst->name, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
+        }
+        GGML_ASSERT(bias->ne[0] == ne0);
+        GGML_ASSERT(bias->type == GGML_TYPE_F32);
+        if (ggml_nrows(bias) != 1) {
+            printf("Oops: bias %s is %ld x %ld x %ld x %ld\n", bias->name, bias->ne[0], bias->ne[1], bias->ne[2], bias->ne[3]);
+        }
+        GGML_ASSERT(ggml_nrows(bias) == 1);
+    }
+
     ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
             ne00, ne0, 1, 0, 0, 0, 0, 0,
-            src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr, nullptr, nullptr,
+            src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr, bias ? bias->data : nullptr, nullptr,
             row_low, row_high, src1_ncols,
             src1_padded_row_size, GGML_UNARY_OP_COUNT, stream);

     GGML_UNUSED(src1_ddf_i);
 }
+void ggml_cuda_op_mul_mat_vec_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+    ggml_cuda_op_mul_mat_vec_q_biased(ctx, src0, src1, dst, nullptr, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, row_low, row_high, src1_ncols,
+            src1_padded_row_size, stream);
+}

 void ggml_cuda_op_mul_mat_vec_q_id(
     ggml_backend_cuda_context & ctx,
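
Design note: the original ggml_cuda_op_mul_mat_vec_q entry point is kept as a thin wrapper that forwards to the new _biased variant with a null bias, so existing call sites keep their signature and behavior; when a bias is supplied, its data pointer is simply passed through to ggml_cuda_op_mul_mat_vec_q_impl.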

ggml/src/ggml-cuda/mmvq.cuh

Lines changed: 9 additions & 1 deletion

@@ -9,12 +9,20 @@

 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.

+void ggml_cuda_op_mul_mat_vec_q_biased(ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_tensor * bias,
+    const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
 void ggml_cuda_op_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
-    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);

 bool ggml_cuda_mmvq_type_supported(ggml_type src0_type);
+
 void ggml_cuda_op_mul_mat_vec_q_3D(ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,

src/llama-build-context.cpp

Lines changed: 3 additions & 0 deletions

@@ -1240,14 +1240,17 @@ std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> llm_build_context::llm_buil
     if (bq) {
         Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
         cb(Qcur, "Qcur", il);
+        ggml_build_forward_expand(gf, Qcur);
     }
     if (bk) {
         Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
         cb(Kcur, "Kcur", il);
+        ggml_build_forward_expand(gf, Kcur);
     }
     if (bv) {
         Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
         cb(Vcur, "Vcur", il);
+        ggml_build_forward_expand(gf, Vcur);
     }
     return {Qcur, Kcur, Vcur};
 }
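
Calling ggml_build_forward_expand right after each bias add pins the add node into the graph at the point where it is created, keeping the Q/K/V bias adds at predictable positions relative to their matrix multiplications; this is presumably what the fusion checks in ggml-cuda.cu above rely on to detect the mul_mat/add pattern.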
