Commit 5eae8e5

context : move build_rope_factors to base class
ggml-ci
1 parent d146a14 commit 5eae8e5

File tree

3 files changed: +104, -101 lines

src/llama-context.cpp

Lines changed: 88 additions & 84 deletions
@@ -57,6 +57,10 @@ uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
 
+uint32_t llama_context::n_ctx_per_seq() const {
+    return cparams.n_ctx / cparams.n_seq_max;
+}
+
 uint32_t llama_context::n_batch() const {
     return cparams.n_batch;
 }
@@ -122,8 +126,8 @@ void llama_context::synchronize() {
 }
 
 void llama_context::attach_threadpool(
-        ggml_threadpool_t threadpool,
-        ggml_threadpool_t threadpool_batch) {
+        ggml_threadpool_t   threadpool,
+        ggml_threadpool_t   threadpool_batch) {
     this->threadpool       = threadpool;
     this->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
 }
@@ -202,6 +206,86 @@ llama_perf_context_data llama_context::perf_get_data() const {
     return data;
 }
 
+ggml_tensor * llama_context::build_cvec(
+        ggml_context * ctx0,
+         ggml_tensor * cur,
+                 int   il) {
+    return cvec.apply_to(ctx0, cur, il);
+}
+
+ggml_tensor * llama_context::build_lora_mm(
+        ggml_context * ctx0,
+         ggml_tensor * w,
+         ggml_tensor * cur) {
+    struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
+
+    for (const auto & lora : loras) {
+        struct llama_adapter_lora_weight * lw = lora.first->get_weight(w);
+        if (lw == nullptr) {
+            continue;
+        }
+
+        const float adapter_scale = lora.second;
+        const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
+
+        struct ggml_tensor * ab_cur = ggml_mul_mat(
+                ctx0, lw->b,
+                ggml_mul_mat(ctx0, lw->a, cur)
+                );
+
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+
+    return res;
+}
+
+ggml_tensor * llama_context::build_lora_mm_id(
+        ggml_context * ctx0,
+         ggml_tensor * w,
+         ggml_tensor * cur,
+         ggml_tensor * ids) {
+    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    for (const auto & lora : loras) {
+        struct llama_adapter_lora_weight * lw = lora.first->get_weight(w);
+        if (lw == nullptr) {
+            continue;
+        }
+
+        const float alpha = lora.first->alpha;
+        const float rank  = (float) lw->b->ne[0];
+        const float scale = alpha ? lora.second * alpha / rank : lora.second;
+
+        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
+                ctx0, lw->b,
+                ggml_mul_mat_id(ctx0, lw->a, cur, ids),
+                ids
+                );
+
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+
+    return res;
+}
+
+ggml_tensor * llama_context::build_rope_factors(int il) {
+    const auto & hparams = model.hparams;
+
+    // choose long/short freq factors based on the context size
+    const auto n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+
+    if (model.layers[il].rope_freqs != nullptr) {
+        return model.layers[il].rope_freqs;
+    }
+
+    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+        return model.layers[il].rope_long;
+    }
+
+    return model.layers[il].rope_short;
+}
+
 void llama_context::perf_reset() {
     t_start_us = ggml_time_us();
     t_eval_us  = n_eval = 0;
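For reference, the LoRA-aware multiply moved into the base class above composes res = W*x + scale * (B*(A*x)). build_lora_mm obtains its scale from lw->get_scale(), while build_lora_mm_id spells the same formula out inline: lora.second * alpha / rank when alpha is non-zero, plain lora.second otherwise. The sketch below redoes that arithmetic with ordinary std::vector matrices instead of ggml graph nodes; every name and size in it is hypothetical, chosen only to make the scaling visible.

// Illustrative only: the y = W*x + scale * B*(A*x) composition that
// build_lora_mm() expresses as ggml graph nodes, done here with plain
// row-major std::vector matrices.
#include <cstdio>
#include <vector>

using Mat = std::vector<std::vector<float>>; // [rows][cols]

static std::vector<float> matvec(const Mat & m, const std::vector<float> & x) {
    std::vector<float> y(m.size(), 0.0f);
    for (size_t r = 0; r < m.size(); ++r) {
        for (size_t c = 0; c < x.size(); ++c) {
            y[r] += m[r][c] * x[c];
        }
    }
    return y;
}

int main() {
    // base weight W: 2x3, LoRA factors A: 1x3 (rank 1) and B: 2x1
    const Mat W = {{1, 0, 0}, {0, 1, 0}};
    const Mat A = {{1, 1, 1}};
    const Mat B = {{0.5f}, {0.25f}};

    const float alpha         = 8.0f; // lora.first->alpha in the diff
    const float adapter_scale = 1.0f; // lora.second in the diff
    const float rank          = 1.0f; // LoRA rank
    const float scale = alpha != 0.0f ? adapter_scale * alpha / rank : adapter_scale;

    const std::vector<float> x = {1, 2, 3};

    std::vector<float> y  = matvec(W, x);            // res = W*x
    std::vector<float> ab = matvec(B, matvec(A, x)); // B*(A*x)
    for (size_t i = 0; i < y.size(); ++i) {
        y[i] += scale * ab[i];                       // res += scale * (B*(A*x))
    }

    for (float v : y) {
        printf("%.2f ", v);
    }
    printf("\n");
    return 0;
}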
@@ -217,7 +301,7 @@ llama_context_unified::llama_context_unified(
         const llama_context_params & params,
                 build_graph_callback && cb_build_graph) :
     llama_context(model),
-    cb_build_graph(std::move(cb_build_graph)){
+    cb_build_graph(std::move(cb_build_graph)) {
 
     const auto & hparams = model.hparams;
 
@@ -1825,69 +1909,6 @@ size_t llama_context_unified::reserve_outputs(size_t n_outputs) {
     return n_outputs_max;
 }
 
-ggml_tensor * llama_context::build_cvec(
-        ggml_context * ctx0,
-         ggml_tensor * cur,
-                 int   il) {
-    return cvec.apply_to(ctx0, cur, il);
-}
-
-ggml_tensor * llama_context::build_lora_mm(
-        ggml_context * ctx0,
-         ggml_tensor * w,
-         ggml_tensor * cur) {
-    struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
-
-    for (const auto & lora : loras) {
-        struct llama_adapter_lora_weight * lw = lora.first->get_weight(w);
-        if (lw == nullptr) {
-            continue;
-        }
-
-        const float adapter_scale = lora.second;
-        const float scale = lw->get_scale(lora.first->alpha, adapter_scale);
-
-        struct ggml_tensor * ab_cur = ggml_mul_mat(
-                ctx0, lw->b,
-                ggml_mul_mat(ctx0, lw->a, cur)
-                );
-
-        ab_cur = ggml_scale(ctx0, ab_cur, scale);
-        res = ggml_add(ctx0, res, ab_cur);
-    }
-
-    return res;
-}
-
-ggml_tensor * llama_context::build_lora_mm_id(
-        ggml_context * ctx0,
-         ggml_tensor * w,
-         ggml_tensor * cur,
-         ggml_tensor * ids) {
-    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
-    for (const auto & lora : loras) {
-        struct llama_adapter_lora_weight * lw = lora.first->get_weight(w);
-        if (lw == nullptr) {
-            continue;
-        }
-
-        const float alpha = lora.first->alpha;
-        const float rank  = (float) lw->b->ne[0];
-        const float scale = alpha ? lora.second * alpha / rank : lora.second;
-
-        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
-                ctx0, lw->b,
-                ggml_mul_mat_id(ctx0, lw->a, cur, ids),
-                ids
-                );
-
-        ab_cur = ggml_scale(ctx0, ab_cur, scale);
-        res = ggml_add(ctx0, res, ab_cur);
-    }
-
-    return res;
-}
-
 void llama_context_unified::kv_self_update() {
     auto & kv = kv_self;
 
@@ -2189,23 +2210,6 @@ ggml_tensor * llama_context_unified::build_soft_max_ext(
     return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias);
 }
 
-ggml_tensor * llama_context_unified::get_rope_factors(int il) {
-    const auto & hparams = model.hparams;
-
-    // choose long/short freq factors based on the context size
-    const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
-
-    if (model.layers[il].rope_freqs != nullptr) {
-        return model.layers[il].rope_freqs;
-    }
-
-    if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
-        return model.layers[il].rope_long;
-    }
-
-    return model.layers[il].rope_short;
-}
-
 ggml_tensor * llama_context_unified::build_inp_embd(
         ggml_context * ctx0,
          ggml_tensor * tok_embd,
@@ -2327,7 +2331,7 @@ void llama_context_unified::build_k_shift(
     const int64_t n_head_kv    = hparams.n_head_kv(il);
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
-    struct ggml_tensor * rope_factors = get_rope_factors(il);
+    struct ggml_tensor * rope_factors = build_rope_factors(il);
 
     struct ggml_tensor * k =
         ggml_view_3d(ctx0, kv_self.k_l[il],
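build_k_shift above now resolves the per-layer frequency factors through the shared base-class helper. The selection order is visible in the new build_rope_factors() earlier in this file: an explicit rope_freqs tensor wins outright, otherwise the long or short table is chosen by comparing the per-sequence context size (n_ctx / n_seq_max) against hparams.n_ctx_orig_yarn. Below is a standalone sketch of that decision using hypothetical stand-in types rather than the real llama_model layer tensors.

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the per-layer tensors; the real code returns
// ggml_tensor pointers from model.layers[il].
struct layer_rope_factors {
    const char * rope_freqs; // explicit per-model factors (nullptr when absent)
    const char * rope_long;  // factors for extended contexts
    const char * rope_short; // factors for the original training context
};

// Mirrors the selection order in llama_context::build_rope_factors().
static const char * pick_rope_factors(
        const layer_rope_factors & layer,
        uint32_t n_ctx,
        uint32_t n_seq_max,
        uint32_t n_ctx_orig_yarn) {
    const uint32_t n_ctx_per_seq = n_ctx / n_seq_max;

    if (layer.rope_freqs != nullptr) {
        return layer.rope_freqs;
    }
    if (n_ctx_per_seq > n_ctx_orig_yarn) {
        return layer.rope_long;
    }
    return layer.rope_short;
}

int main() {
    const layer_rope_factors layer = { nullptr, "rope_long", "rope_short" };

    // 64k total context shared by 4 sequences -> 16k per sequence, within a 32k original context
    printf("%s\n", pick_rope_factors(layer, 65536, 4, 32768)); // prints rope_short
    // the same total context used by a single sequence exceeds it
    printf("%s\n", pick_rope_factors(layer, 65536, 1, 32768)); // prints rope_long
    return 0;
}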

src/llama-context.h

Lines changed: 9 additions & 10 deletions
@@ -23,10 +23,11 @@ struct llama_context {
     const llama_model   & get_model()   const;
     const llama_cparams & get_cparams() const;
 
-    virtual uint32_t n_ctx()     const;
-    virtual uint32_t n_batch()   const;
-    virtual uint32_t n_ubatch()  const;
-    virtual uint32_t n_seq_max() const = 0;
+    virtual uint32_t n_ctx()         const;
+    virtual uint32_t n_ctx_per_seq() const;
+    virtual uint32_t n_batch()       const;
+    virtual uint32_t n_ubatch()      const;
+    virtual uint32_t n_seq_max()     const = 0;
 
     virtual uint32_t n_threads()       const;
     virtual uint32_t n_threads_batch() const;
@@ -126,6 +127,8 @@ struct llama_context {
             ggml_tensor * cur, // struct ggml_tensor * b
             ggml_tensor * ids);
 
+    virtual ggml_tensor * build_rope_factors(int il);
+
     // graph build API (context-specific)
 
     virtual ggml_tensor * build_inp_embd(
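Because build_rope_factors is declared here with an implementation in llama-context.cpp, rather than as the pure virtual get_rope_factors it replaces (removed in the next hunk), derived contexts inherit a working default and only override it when they need something different. A minimal sketch of that pattern with hypothetical class names, not the actual llama headers:

#include <cstdio>

// Hypothetical mirror of the before/after: the base class supplies a usable
// default instead of forcing every derived class to implement the method.
struct context_base {
    virtual ~context_base() = default;

    // default behaviour, analogous to the base-class build_rope_factors()
    virtual const char * rope_factors(int /*il*/) const {
        return "base default: pick long/short factors";
    }
};

struct context_custom : context_base {
    // a derived context overrides only when it needs different behaviour
    const char * rope_factors(int /*il*/) const override {
        return "custom: context-specific factors";
    }
};

int main() {
    context_base   base;
    context_custom custom;

    const context_base * contexts[] = { &base, &custom };
    for (const context_base * c : contexts) {
        printf("%s\n", c->rope_factors(0));
    }
    return 0;
}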
@@ -182,8 +185,6 @@ struct llama_context {
             ggml_tensor * kq,
                    float   kq_scale) = 0;
 
-    virtual ggml_tensor * get_rope_factors(int il) = 0;
-
     virtual void build_k_shift(
             ggml_context * ctx0,
              ggml_cgraph * graph) = 0;
@@ -342,7 +343,7 @@ class llama_context_unified : public llama_context {
 public:
     struct batch_manager;
 
-    // TODO: tmp until llama-model starts implementing the graph build function
+    // TODO: tmp until llama_model starts implementing the graph build function
     typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool worst_case)> build_graph_callback;
 
     llama_context_unified(
@@ -496,8 +497,6 @@ class llama_context_unified : public llama_context {
             ggml_tensor * kq,
                    float   kq_scale) override;
 
-    virtual ggml_tensor * get_rope_factors(int il) override;
-
     virtual void build_k_shift(
             ggml_context * ctx0,
             ggml_cgraph * graph) override;
@@ -601,7 +600,7 @@ class llama_context_unified : public llama_context {
     virtual size_t state_get_data(      uint8_t * dst, size_t size) override;
     virtual size_t state_set_data(const uint8_t * src, size_t size) override;
 
-    virtual size_t state_seq_get_size(llama_seq_id seq_id)                                   override;
+    virtual size_t state_seq_get_size(llama_seq_id seq_id) override;
     virtual size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size) override;
     virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;
 

src/llama.cpp

Lines changed: 7 additions & 7 deletions
@@ -685,7 +685,7 @@ struct llm_build_context {
         // self-attention
         {
             // rope freq factors for llama3; may return nullptr for llama2 and other models
-            struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
+            struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
 
             // compute Q and K and RoPE them
             struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -857,7 +857,7 @@ struct llm_build_context {
         } else if (n_head > 0) {
             // self-attention
             // rope freq factors for llama3; may return nullptr for llama2 and other models
-            struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
+            struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
 
             // compute Q and K and RoPE them
             struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -2999,7 +2999,7 @@ struct llm_build_context {
         // self-attention
         {
             // rope freq factors for 128k context
-            struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
+            struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
 
             struct ggml_tensor* attn_norm_output = build_norm(inpL,
                     model.layers[il].attn_norm,
@@ -3706,7 +3706,7 @@ struct llm_build_context {
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
-            struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
+            struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
             // norm
             cur = build_norm(inpL,
                     model.layers[il].attn_norm, NULL,
@@ -4480,7 +4480,7 @@ struct llm_build_context {
         // self-attention
         {
             // rope freq factors for 128k context
-            struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
+            struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
 
             // compute Q and K and RoPE them
             struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -5373,7 +5373,7 @@ struct llm_build_context {
         // self-attention
        {
             // rope freq factors for llama3; may return nullptr for llama2 and other models
-            struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
+            struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
 
             // compute Q and K and RoPE them
             struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -6572,7 +6572,7 @@ struct llm_build_context {
         // self-attention
         {
             // rope freq factors for llama3; may return nullptr for llama2 and other models
-            struct ggml_tensor * rope_factors = lctx.get_rope_factors(il);
+            struct ggml_tensor * rope_factors = lctx.build_rope_factors(il);
 
             // compute Q and K and RoPE them
             struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
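The recurring comment at these call sites ("may return nullptr for llama2 and other models") is part of the contract: a model without frequency-factor tensors simply passes a null pointer through to the RoPE op, which then behaves as if every factor were the neutral value 1.0. The hypothetical helper below only illustrates that nullptr-as-neutral handling; it is not the actual ggml RoPE kernel.

#include <cstdio>
#include <vector>

// Hypothetical helper: look up a per-dimension frequency factor, treating a
// missing table (nullptr) as a table of neutral 1.0f factors.
static float freq_factor_at(const float * factors, size_t i) {
    return factors != nullptr ? factors[i] : 1.0f;
}

int main() {
    const std::vector<float> factors = { 1.0f, 4.0f, 8.0f };

    for (size_t i = 0; i < factors.size(); ++i) {
        // with a table (e.g. rope_long/rope_short) vs. without one (llama2-style models)
        printf("dim %zu: with factors %.1f, without %.1f\n",
                i, freq_factor_at(factors.data(), i), freq_factor_at(nullptr, i));
    }
    return 0;
}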
