Commit cf95cd4

Create llama-hparams-ms.cpp
Signed-off-by: Brad Hutchings <[email protected]>
1 parent 4eb9006 commit cf95cd4

1 file changed: +157 −0

src/llama-hparams-ms.cpp

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
#include "llama-hparams.h"

#include "ggml.h"

void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
    for (uint32_t il = 0; il < n_layer; ++il) {
        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
    }
}

bool llama_hparams::is_swa_any() const {
    for (uint32_t il = 0; il < n_layer; ++il) {
        if (swa_layers[il]) {
            return true;
        }
    }

    return false;
}

uint32_t llama_hparams::n_head(uint32_t il) const {
    if (il < n_layer) {
        return n_head_arr[il];
    }

    GGML_ABORT("fatal error");
}

uint32_t llama_hparams::n_head_kv(uint32_t il) const {
    if (il < n_layer) {
        return n_head_kv_arr[il];
    }

    GGML_ABORT("fatal error");
}

uint32_t llama_hparams::n_ff(uint32_t il) const {
    if (il < n_layer) {
        return n_ff_arr[il];
    }

    GGML_ABORT("fatal error");
}

uint32_t llama_hparams::n_gqa(uint32_t il) const {
    const uint32_t n_head    = this->n_head(il);
    const uint32_t n_head_kv = this->n_head_kv(il);

    if (n_head_kv == 0) {
        return 0;
    }

    return n_head/n_head_kv;
}

uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
    const uint32_t n_head_kv = this->n_head_kv(il);

    return n_embd_head_k * n_head_kv;
}

uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
    const uint32_t n_head_kv = this->n_head_kv(il);

    return n_embd_head_v * n_head_kv;
}

bool llama_hparams::is_n_embd_k_gqa_variable() const {
    const uint32_t val = n_embd_k_gqa();
    for (uint32_t il = 0; il < n_layer; ++il) {
        if (val != n_embd_k_gqa(il)) {
            return true;
        }
    }

    return false;
}

bool llama_hparams::is_n_embd_v_gqa_variable() const {
    const uint32_t val = n_embd_v_gqa();
    for (uint32_t il = 0; il < n_layer; ++il) {
        if (val != n_embd_v_gqa(il)) {
            return true;
        }
    }

    return false;
}

uint32_t llama_hparams::n_embd_k_gqa_max() const {
    uint32_t val = n_embd_k_gqa();
    for (uint32_t il = 0; il < n_layer; ++il) {
#ifndef COSMOCC
        val = std::max(val, n_embd_k_gqa(il));
#else
        val = (val > n_embd_k_gqa(il)) ? val : n_embd_k_gqa(il);
#endif
    }

    return val;
}

uint32_t llama_hparams::n_embd_v_gqa_max() const {
    uint32_t val = n_embd_v_gqa();
    for (uint32_t il = 0; il < n_layer; ++il) {
#ifndef COSMOCC
        val = std::max(val, n_embd_v_gqa(il));
#else
        val = (val > n_embd_v_gqa(il)) ? val : n_embd_v_gqa(il);
#endif
    }

    return val;
}

uint32_t llama_hparams::n_embd_r() const {
    if (wkv_head_size != 0) {
        // for RWKV models
        return token_shift_count * n_embd;
    }

    if (n_shortconv_l_cache != 0) {
        // for LFM2 models
        return n_embd * (n_shortconv_l_cache - 1);
    }

    // TODO: maybe support other convolution strides than 1
    // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
    // Corresponds to Mamba's conv_states size
    return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
}

uint32_t llama_hparams::n_embd_s() const {
    if (wkv_head_size != 0) {
        // corresponds to RWKV's wkv_states size
        return n_embd * wkv_head_size;
    }

    // corresponds to Mamba's ssm_states size
    return ssm_d_state * ssm_d_inner;
}

bool llama_hparams::is_recurrent(uint32_t il) const {
    return recurrent_layer_arr[il];
}

uint32_t llama_hparams::n_pos_per_embd() const {
    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
}

bool llama_hparams::is_swa(uint32_t il) const {
    if (il < n_layer) {
        return swa_layers[il];
    }

    GGML_ABORT("fatal error");
}
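
Two notes on the file for readers of this commit. First, the #ifndef COSMOCC branches in n_embd_k_gqa_max and n_embd_v_gqa_max replace std::max with an equivalent ternary; judging by the macro name this likely targets builds with the Cosmopolitan toolchain, but that is an assumption, since the build configuration is not part of this commit. Second, the sketch below is illustration only: it reproduces the swa_layers expression from set_swa_pattern with plain standard-library containers, because the full llama_hparams struct definition is not shown here.

// Standalone illustration of the SWA layer pattern computed by set_swa_pattern.
// Not part of the commit; uses plain vectors instead of the real llama_hparams struct.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_layer   = 6;
    const uint32_t n_pattern = 3; // every n_pattern-th layer is dense (non-SWA)

    std::vector<bool> swa_layers(n_layer);
    for (uint32_t il = 0; il < n_layer; ++il) {
        // n_pattern == 0 marks every layer as SWA; otherwise the last layer of each
        // group of n_pattern layers is dense and the preceding layers use SWA
        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
    }

    for (uint32_t il = 0; il < n_layer; ++il) {
        std::printf("layer %u: %s\n", (unsigned) il, swa_layers[il] ? "swa" : "dense");
    }
    // expected output: layers 0, 1, 3, 4 -> swa; layers 2, 5 -> dense
    return 0;
}

The behavior shown simply follows the expression as written; the intended meaning of the n_pattern == 0 case is documented in the header, which is not included in this commit.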
