
Commit 5f19c5b

Merge pull request #7 from ubergarm/ug/port-sweep-bench
Ug/port sweep bench

2 parents 5477621 + 045e213, commit 5f19c5b
20 files changed: +645 additions, -483 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 361 additions & 305 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@ layout (constant_id = 4) const uint32_t HSV = 32;
 layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
 
+// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
+const uint32_t HSK_pad = (HSK + 15) & ~15;
+const uint32_t HSV_pad = (HSV + 15) & ~15;
+
 layout (push_constant) uniform parameter {
     uint32_t N;
     uint32_t KV;
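
A note on the padding arithmetic: (x + 15) & ~15 rounds an unsigned value up to the next multiple of 16 (adding 15 and clearing the low four bits snaps the value to the next 16-aligned boundary). A minimal standalone C++ sketch of that expression, separate from the shader and using illustrative head sizes only:

#include <cstdint>
#include <cstdio>

// Round x up to the next multiple of 16, mirroring the HSK_pad/HSV_pad expressions above.
static uint32_t round_up_16(uint32_t x) {
    return (x + 15) & ~15u;
}

int main() {
    // Head sizes that are already multiples of 16 pass through unchanged (64, 80, 112, 256);
    // the others (40, 72) are padded up to the next multiple.
    const uint32_t sizes[] = {40, 64, 72, 80, 112, 256};
    for (uint32_t hs : sizes) {
        std::printf("HS %3u -> padded %3u\n", hs, round_up_16(hs));
    }
    return 0;
}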

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

Lines changed: 19 additions & 4 deletions
@@ -46,14 +46,14 @@ const uint32_t MatBc = 16;
 shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
 shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
 
-const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
+const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
 // Avoid padding for hsk==256 to make it fit in 48KB shmem.
 const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
 shared ACC_TYPE sfsh[Bc * sfshstride];
 
-const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
+const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 ksh[Bc * kshstride];
 
 shared float slope[Br];
@@ -74,6 +74,21 @@ void main() {
 
 #define tile_row(r) (row_tid * rows_per_thread + (r))
 
+    // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
+    if ((HSK % 16) != 0) {
+        [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
+            if (i + tid < Br * qstride) {
+                Qf[i + tid] = f16vec4(0);
+            }
+        }
+        [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
+            if (i + tid < Bc * kshstride) {
+                ksh[i + tid] = f16vec4(0);
+            }
+        }
+        barrier();
+    }
+
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
@@ -151,14 +166,14 @@ void main() {
        }
        barrier();
 
-        // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
+        // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
         // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
         // This is written transposed in order to allow for N being 8 if implementations need it
         coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
         coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
         coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
 
-        for (uint32_t d = 0; d < HSK / 16; ++d) {
+        for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
 
            uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
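
Why the zero-initialization matters: the K * Q^T accumulation now walks HSK_pad / 16 chunks instead of HSK / 16, so the tail of each Qf/ksh row beyond HSK is read as part of the products. Filling that tail with zeros leaves every attention score unchanged, since the extra columns contribute only zero terms. A small standalone C++ sketch of that argument (illustrative only, not shader code):

#include <cassert>
#include <cstdint>
#include <vector>

// Plain dot product, standing in for one row-column accumulation of K * Q^T.
static float dot(const std::vector<float>& a, const std::vector<float>& b) {
    float s = 0.0f;
    for (size_t i = 0; i < a.size() && i < b.size(); ++i) {
        s += a[i] * b[i];
    }
    return s;
}

int main() {
    const uint32_t HSK     = 72;                // head size that is not a multiple of 16
    const uint32_t HSK_pad = (HSK + 15) & ~15u; // 80

    std::vector<float> k(HSK, 0.5f), q(HSK, 2.0f);

    // Extend both rows with zeros up to HSK_pad, as the zero-filled shared tiles do.
    std::vector<float> k_pad(k), q_pad(q);
    k_pad.resize(HSK_pad, 0.0f);
    q_pad.resize(HSK_pad, 0.0f);

    // The padded columns contribute 0 * 0 terms, so the score is identical.
    assert(dot(k, q) == dot(k_pad, q_pad));
    return 0;
}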

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 18 additions & 18 deletions
@@ -104,16 +104,16 @@ void main() {
     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
     tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
 
-    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
-    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
+    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
+    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
 
     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
-    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
+    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
 
-    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
+    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
     Qf16 *= float16_t(p.scale);
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
 
     coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
 
@@ -140,10 +140,10 @@ void main() {
 
         coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
 
-        coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
+        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
 
         uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
-        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
+        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
         S = coopMatMulAdd(Qf16, K_T, S);
 
         if (p.logit_softcap != 0.0f) {
@@ -208,31 +208,31 @@ void main() {
         rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
         rowsum = coopMatMulAdd(P_A, One, rowsum);
 
-        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
+        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
         uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
-        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
+        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
 
         L = eM*L + rowsum;
 
         // This is the "diagonal" matrix in the paper, but since we do componentwise
         // multiply rather than matrix multiply it has the diagonal element smeared
         // across the row
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
 
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
         // multiply with fp16 accumulation, then add to O.
-        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
         PV = coopMatMulAdd(P_A, V, PV);
 
-        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
+        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
+        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
 
         uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
@@ -243,16 +243,16 @@ void main() {
         return;
     }
 
-    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
 
     // resize L by using smear/reduce
     coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
     if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
         coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
 
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
 
         // resize M by using smear/reduce
         coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -285,7 +285,7 @@ void main() {
 
     uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
-    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
+    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
     if (p.gqa_ratio > 1) {
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
     } else {
@@ -295,6 +295,6 @@ void main() {
         // permute dimensions
         tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
 
-        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
+        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
     }
 }

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 8 additions & 3 deletions
@@ -17,6 +17,9 @@
 #ifdef COOPMAT
 #extension GL_KHR_cooperative_matrix : enable
 #extension GL_KHR_memory_scope_semantics : enable
+#endif
+
+#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
 #extension GL_KHR_shader_subgroup_basic : enable
 #extension GL_KHR_shader_subgroup_ballot : enable
 #endif
@@ -108,8 +111,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
 #ifdef MUL_MAT_ID
 shared u16vec2 row_ids[4096];
 uint _ne1;
-#ifdef COOPMAT
+
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
 shared uvec4 ballots_sh[NUM_WARPS];
+
 void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
     _ne1 = 0;
     uint num_elements = p.nei1 * p.nei0;
@@ -168,7 +173,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
     }
     barrier();
 }
-#endif
+#endif // MUL_MAT_ID_USE_SUBGROUPS
 #endif // MUL_MAT_ID
 
 #ifdef COOPMAT
@@ -235,7 +240,7 @@ void main() {
     const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
 
 #ifdef MUL_MAT_ID
-#ifdef COOPMAT
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
     if (bitCount(p.nei0) == 1) {
         load_row_ids(expert_idx, true);
     } else {

ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp

Lines changed: 5 additions & 2 deletions
@@ -23,8 +23,11 @@ layout (push_constant) uniform parameter2
     uint rms_partials;
 } p;
 
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
-layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
+// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
+// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
+// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
+layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
+layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
 
 layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
 

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 25 additions & 13 deletions
@@ -68,6 +68,12 @@ const std::vector<std::string> type_names = {
     "bf16",
 };
 
+enum MatMulIdType {
+    NONE,
+    DEFAULT,
+    SUBGROUP,
+};
+
 namespace {
 void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
 #ifdef _WIN32
@@ -293,7 +299,7 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
     compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
 }
 
-void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
+void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
     std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
     std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
@@ -303,9 +309,13 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     };
     std::string shader_name = "matmul";
 
-    if (matmul_id) {
+    if (matmul_id_type == MatMulIdType::DEFAULT) {
         base_dict["MUL_MAT_ID"] = "1";
         shader_name = "matmul_id";
+    } else if (matmul_id_type == MatMulIdType::SUBGROUP) {
+        base_dict["MUL_MAT_ID"] = "1";
+        base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
+        shader_name = "matmul_id_subgroup";
     }
 
     if (fp16) {
@@ -389,7 +399,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         }
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
-        if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
+        if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
         }
 #endif
@@ -401,26 +411,28 @@ void process_shaders() {
     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
 
     // matmul
-    for (const auto& matmul_id : {false, true}) {
+    for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
         // No coopmats
         // fp32
-        matmul_shaders(false, matmul_id, false, false, false);
+        matmul_shaders(false, matmul_id_type, false, false, false);
 
         // fp16, fp32acc and fp16acc
-        matmul_shaders(true, matmul_id, false, false, false);
-        matmul_shaders(true, matmul_id, false, false, true);
+        matmul_shaders(true, matmul_id_type, false, false, false);
+        matmul_shaders(true, matmul_id_type, false, false, true);
 
+        if (matmul_id_type != MatMulIdType::DEFAULT) {
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
-        // Coopmat, fp32acc and fp16acc
-        matmul_shaders(true, matmul_id, true, false, false);
-        matmul_shaders(true, matmul_id, true, false, true);
+            // Coopmat, fp32acc and fp16acc
+            matmul_shaders(true, matmul_id_type, true, false, false);
+            matmul_shaders(true, matmul_id_type, true, false, true);
 #endif
 
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
-        // Coopmat2, fp32acc and fp16acc
-        matmul_shaders(true, matmul_id, false, true, false);
-        matmul_shaders(true, matmul_id, false, true, true);
+            // Coopmat2, fp32acc and fp16acc
+            matmul_shaders(true, matmul_id_type, false, true, false);
+            matmul_shaders(true, matmul_id_type, false, true, true);
 #endif
+        }
     }
 
     // flash attention
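
For reference, the three MatMulIdType values map onto shader defines as follows. This is a small illustrative C++ helper written for this note (not part of vulkan-shaders-gen.cpp), mirroring the branches in matmul_shaders() above:

#include <map>
#include <string>

enum MatMulIdType { NONE, DEFAULT, SUBGROUP };

// NONE:     plain matmul, no expert-id indirection.
// DEFAULT:  MUL_MAT_ID without the subgroup-based row-id loader.
// SUBGROUP: MUL_MAT_ID plus MUL_MAT_ID_USE_SUBGROUPS, which gates load_row_ids() in mul_mm.comp.
static std::map<std::string, std::string> matmul_id_defines(MatMulIdType t) {
    std::map<std::string, std::string> defs;
    if (t == DEFAULT || t == SUBGROUP) {
        defs["MUL_MAT_ID"] = "1";
    }
    if (t == SUBGROUP) {
        defs["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
    }
    return defs;
}

Note that in process_shaders() the coopmat and coopmat2 variants are only generated when matmul_id_type != MatMulIdType::DEFAULT, i.e. for the NONE and SUBGROUP cases.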

src/llama-hparams.cpp

Lines changed: 25 additions & 0 deletions
@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {
 
     GGML_ABORT("fatal error");
 }
+
+bool llama_hparams::has_kv(uint32_t il) const {
+    if (n_layer_kv_from_start >= 0) {
+        if (il < (uint32_t) n_layer_kv_from_start) {
+            return true;
+        }
+
+        return false;
+    }
+
+    // by default, all layers have kv
+    return true;
+}
+
+uint32_t llama_hparams::n_layer_kv() const {
+    uint32_t res = 0;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (has_kv(il)) {
+            res++;
+        }
+    }
+
+    return res;
+}

src/llama-hparams.h

Lines changed: 6 additions & 0 deletions
@@ -41,6 +41,7 @@ struct llama_hparams {
     uint32_t n_embd;
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
+    int32_t  n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
     uint32_t n_rot;
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
     uint32_t n_pos_per_embd() const;
 
     bool is_swa(uint32_t il) const;
+
+    bool has_kv(uint32_t il) const;
+
+    // number of layers for which has_kv() returns true
+    uint32_t n_layer_kv() const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
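
To make the new hparams semantics concrete: when n_layer_kv_from_start is non-negative, only layers 0 .. n_layer_kv_from_start - 1 have a KV cache; when it keeps its default of -1, every layer does, and n_layer_kv() simply counts the layers for which has_kv() returns true. A minimal standalone C++ sketch (a stripped-down stand-in, not the real llama_hparams struct):

#include <cassert>
#include <cstdint>

// Stripped-down stand-in for the two new accessors above.
struct hparams_sketch {
    uint32_t n_layer               = 4;
    int32_t  n_layer_kv_from_start = -1; // -1: all layers have a KV cache

    bool has_kv(uint32_t il) const {
        if (n_layer_kv_from_start >= 0) {
            return il < (uint32_t) n_layer_kv_from_start;
        }
        return true; // by default, all layers have kv
    }

    uint32_t n_layer_kv() const {
        uint32_t res = 0;
        for (uint32_t il = 0; il < n_layer; ++il) {
            res += has_kv(il) ? 1 : 0;
        }
        return res;
    }
};

int main() {
    hparams_sketch hp;
    assert(hp.n_layer_kv() == 4);   // default: every layer has a KV cache

    hp.n_layer_kv_from_start = 2;   // e.g. only the first two layers keep KV
    assert( hp.has_kv(0) &&  hp.has_kv(1));
    assert(!hp.has_kv(2) && !hp.has_kv(3));
    assert(hp.n_layer_kv() == 2);
    return 0;
}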
