
Commit e633dc1

context : introduce llama_graph_i
ggml-ci
1 parent 5eae8e5 commit e633dc1

File tree

4 files changed (+168 / -132 lines):

src/CMakeLists.txt
src/llama-context.h
src/llama-graph.cpp
src/llama-graph.h

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ add_library(llama
             llama-chat.cpp
             llama-context.cpp
             llama-grammar.cpp
+            llama-graph.cpp
             llama-hparams.cpp
             llama-impl.cpp
             llama-kv-cache.cpp

src/llama-context.h

Lines changed: 2 additions & 132 deletions
@@ -3,6 +3,7 @@
 #include "llama.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
+#include "llama-graph.h"
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 #include "llama-adapter.h"
@@ -16,7 +17,7 @@
 
 using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
 
-struct llama_context {
+struct llama_context : public llama_graph_i {
     llama_context(const llama_model & model);
     virtual ~llama_context();
 
@@ -129,137 +130,6 @@ struct llama_context
 
     virtual ggml_tensor * build_rope_factors(int il);
 
-    // graph build API (context-specific)
-
-    virtual ggml_tensor * build_inp_embd(
-            ggml_context * ctx0,
-            ggml_tensor * tok_embd,
-            const llama_ubatch & ubatch) = 0;
-
-    virtual ggml_tensor * build_inp_pos(
-            ggml_context * ctx0,
-            int32_t n_tokens) = 0;
-
-    virtual ggml_tensor * build_inp_out_ids(
-            ggml_context * ctx0,
-            int32_t n_tokens,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_inp_mean(
-            ggml_context * ctx0,
-            int32_t n_tokens) = 0;
-
-    virtual ggml_tensor * build_inp_cls(
-            ggml_context * ctx0,
-            int32_t n_tokens) = 0;
-
-    virtual void build_attn_inp(
-            ggml_context * ctx0,
-            int32_t n_tokens,
-            bool causal,
-            bool swa,
-            bool worst_case) = 0;
-
-    virtual void build_attn_kv_store(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
-            int32_t n_tokens,
-            int64_t il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_attn_qkv(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            int32_t n_tokens,
-            float kq_scale,
-            int il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_soft_max_ext(
-            ggml_context * ctx0,
-            ggml_tensor * kq,
-            float kq_scale) = 0;
-
-    virtual void build_k_shift(
-            ggml_context * ctx0,
-            ggml_cgraph * graph) = 0;
-
-    // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
-    virtual void build_defrag(
-            ggml_context * ctx0,
-            ggml_cgraph * graph) = 0;
-
-    virtual ggml_tensor * build_inp_embd_enc(
-            ggml_context * ctx0,
-            int32_t n_tokens,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_inp_KQ_mask_cross(
-            ggml_context * ctx0,
-            int32_t n_tokens,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_inp_s_copy(
-            ggml_context * ctx0,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_inp_s_mask(
-            ggml_context * ctx0,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_copy_mask_state(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * s,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            int32_t n_tokens,
-            int32_t n_state,
-            int32_t n_seqs,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_mamba_layer(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            const llama_ubatch & ubatch,
-            int il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_rwkv_token_shift_load(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            const llama_ubatch & ubatch,
-            int il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_rwkv_token_shift_store(
-            ggml_context * ctx0,
-            ggml_tensor * token_shift,
-            const llama_ubatch & ubatch,
-            int il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_rwkv6_time_mix(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * cur,
-            ggml_tensor * x_prev,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            const llama_ubatch & ubatch,
-            int il,
-            bool worst_case) = 0;
-
     // state save/load
 
     virtual size_t state_get_size() = 0;

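The net effect of this file's changes: the graph-build entry points listed above are no longer declared directly on llama_context; the struct now inherits them from llama_graph_i, declared in the new src/llama-graph.h below. As a minimal sketch of what this enables (a hypothetical helper, not part of this commit), graph-building code can now be written against the interface rather than the concrete context:

    // Hypothetical illustration only: build_layer_inputs and its placement are
    // assumptions; it merely shows calling through the llama_graph_i interface.
    #include "llama-graph.h"

    static ggml_tensor * build_layer_inputs(llama_graph_i & gf, ggml_context * ctx0, int32_t n_tokens) {
        // llama_context satisfies llama_graph_i, so a context can be passed here by reference
        return gf.build_inp_pos(ctx0, n_tokens);
    }
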
src/llama-graph.cpp

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+#include "llama-graph.h"

src/llama-graph.h

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
+#pragma once
+
+#include <cstdint>
+
+struct ggml_cgraph;
+struct ggml_context;
+struct ggml_tensor;
+struct llama_ubatch;
+
+// TODO: pass to llama_model graph build
+class llama_graph_i {
+public:
+    // apply control vector for layer il
+    virtual ggml_tensor * build_cvec(
+            ggml_context * ctx0,
+            ggml_tensor * cur,
+            int il) = 0;
+
+    // do mat_mul, while optionally apply lora
+    virtual ggml_tensor * build_lora_mm(
+            ggml_context * ctx0,
+            ggml_tensor * w,
+            ggml_tensor * cur) = 0;
+
+    // do mat_mul_id, while optionally apply lora
+    virtual ggml_tensor * build_lora_mm_id(
+            ggml_context * ctx0,
+            ggml_tensor * w,   // struct ggml_tensor * as
+            ggml_tensor * cur, // struct ggml_tensor * b
+            ggml_tensor * ids) = 0;
+
+    virtual ggml_tensor * build_rope_factors(int il) = 0;
+
+    // graph build API (context-specific)
+
+    virtual ggml_tensor * build_inp_embd(
+            ggml_context * ctx0,
+            ggml_tensor * tok_embd,
+            const llama_ubatch & ubatch) = 0;
+
+    virtual ggml_tensor * build_inp_pos(
+            ggml_context * ctx0,
+            int32_t n_tokens) = 0;
+
+    virtual ggml_tensor * build_inp_out_ids(
+            ggml_context * ctx0,
+            int32_t n_tokens,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_inp_mean(
+            ggml_context * ctx0,
+            int32_t n_tokens) = 0;
+
+    virtual ggml_tensor * build_inp_cls(
+            ggml_context * ctx0,
+            int32_t n_tokens) = 0;
+
+    virtual void build_attn_inp(
+            ggml_context * ctx0,
+            int32_t n_tokens,
+            bool causal,
+            bool swa,
+            bool worst_case) = 0;
+
+    virtual void build_attn_kv_store(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            int32_t n_tokens,
+            int64_t il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_attn_qkv(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            int32_t n_tokens,
+            float kq_scale,
+            int il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_soft_max_ext(
+            ggml_context * ctx0,
+            ggml_tensor * kq,
+            float kq_scale) = 0;
+
+    virtual void build_k_shift(
+            ggml_context * ctx0,
+            ggml_cgraph * graph) = 0;
+
+    // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
+    virtual void build_defrag(
+            ggml_context * ctx0,
+            ggml_cgraph * graph) = 0;
+
+    virtual ggml_tensor * build_inp_embd_enc(
+            ggml_context * ctx0,
+            int32_t n_tokens,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_inp_KQ_mask_cross(
+            ggml_context * ctx0,
+            int32_t n_tokens,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_inp_s_copy(
+            ggml_context * ctx0,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_inp_s_mask(
+            ggml_context * ctx0,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_copy_mask_state(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            int32_t n_tokens,
+            int32_t n_state,
+            int32_t n_seqs,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_mamba_layer(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_rwkv_token_shift_load(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_rwkv_token_shift_store(
+            ggml_context * ctx0,
+            ggml_tensor * token_shift,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_rwkv6_time_mix(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case) = 0;
+};
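
Since llama_graph_i consists entirely of pure-virtual methods, any type that provides the graph-build callbacks derives from it and overrides them; in this commit, llama_context is that type. A rough sketch of the pattern (hypothetical class name, not part of this commit; a class that leaves some builders un-overridden simply remains abstract):

    // Hypothetical sketch only: my_graph_impl is not part of this commit.
    #include "llama-graph.h"

    class my_graph_impl : public llama_graph_i {
    public:
        ggml_tensor * build_inp_pos(ggml_context * ctx0, int32_t n_tokens) override {
            (void) ctx0; (void) n_tokens;
            return nullptr; // a real implementation would create and register the positions input tensor
        }
        // ...the remaining pure-virtual builders would be overridden the same way
    };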
