Skip to content

Commit 624f7bd

Browse files
committed
graph : add comments
ggml-ci
1 parent 0f7daa9 commit 624f7bd

File tree

3 files changed

+52
-16
lines changed

3 files changed

+52
-16
lines changed

src/llama-context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ void llama_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
101101
}
102102
}
103103

104+
// note: this does not depend on the context and can technically be moved to llama-model.cpp
104105
class llama_graph_input_attn_base : public llama_graph_input_attn_i {
105106
public:
106107
llama_graph_input_attn_base(const llama_hparams & hparams, const llama_cparams & cparams) :

src/llama-graph.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ ggml_tensor * llama_graph_input_attn_i::get_kq_mask_cross() {
1919

2020
llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
2121

22+
llama_graph_input_ptr llama_graph_i::build_inp_cross_embd(
23+
ggml_context * ctx0) const {
24+
GGML_UNUSED(ctx0);
25+
26+
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
27+
return nullptr;
28+
}
29+
2230
ggml_tensor * llama_graph_i::build_attn(
2331
llama_graph_input_attn_i * inp,
2432
ggml_context * ctx0,
@@ -67,14 +75,6 @@ ggml_tensor * llama_graph_i::build_attn_cross(
6775
return nullptr;
6876
}
6977

70-
llama_graph_input_ptr llama_graph_i::build_inp_cross_embd(
71-
ggml_context * ctx0) const {
72-
GGML_UNUSED(ctx0);
73-
74-
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
75-
return nullptr;
76-
}
77-
7878
llama_graph_input_ptr llama_graph_i::build_inp_s_copy (
7979
ggml_context * ctx0) const {
8080
GGML_UNUSED(ctx0);

src/llama-graph.h

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,49 @@
1010
struct ggml_cgraph;
1111
struct ggml_context;
1212
struct ggml_tensor;
13-
struct ggml_backend_buffer;
1413

1514
struct llama_ubatch;
1615

16+
// certain models (typically multi-modal) can produce different types of graphs
17+
// the llama_context specifies which type of graph it needs through the llama_graph_i::type member
1718
enum llama_graph_type {
1819
LLAMA_GRAPH_TYPE_DEFAULT,
1920
LLAMA_GRAPH_TYPE_ENCODER,
2021
LLAMA_GRAPH_TYPE_DECODER,
2122
};
2223

24+
2325
//
2426
// llama_graph_input
2527
//
2628

29+
// denotes an input to the graph
30+
// typically, the data of these objects is populated based on the contents of the current llama_ubatch:
31+
//
32+
// - llama_graph_input_pos
33+
// - llama_graph_input_out_ids
34+
// - etc.
35+
//
36+
// some inputs require context-specific data (e.g. KV cache) - such inputs are defined for the specific llama_context:
37+
//
38+
// - llama_graph_input_embd (can apply lora)
39+
// - llama_graph_input_attn_kv_self (requires KV cache instance)
40+
// - etc.
41+
//
42+
2743
class llama_graph_input_i {
2844
public:
2945
virtual ~llama_graph_input_i() = default;
3046

3147
virtual void set_input(const llama_ubatch * ubatch) = 0;
3248

33-
// by default, we produce a single input tensor, but some children could produce more
49+
// by default, we produce a single input tensor, but some implementations could produce more
3450
ggml_tensor * cur = nullptr;
3551
};
3652

3753
using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
3854

55+
3956
class llama_graph_input_attn_i : public llama_graph_input_i {
4057
public:
4158
virtual ~llama_graph_input_attn_i() = default;
@@ -47,10 +64,17 @@ class llama_graph_input_attn_i : public llama_graph_input_i {
4764

4865
using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
4966

67+
5068
//
5169
// llama_graph_result
5270
//
5371

72+
// these objects deliver the result from the graph build process back to the llama_context
73+
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
74+
// specific data, by calling the set_inputs() method
75+
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
76+
// these are used by the llama_context to extract the relevant data, based on the compute parameters
77+
5478
class llama_graph_result_i {
5579
public:
5680
virtual ~llama_graph_result_i() = default;
@@ -64,9 +88,9 @@ class llama_graph_result_i {
6488

6589
using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
6690

91+
6792
class llama_graph_result : public llama_graph_result_i {
6893
public:
69-
llama_graph_result() = default;
7094
virtual ~llama_graph_result() = default;
7195

7296
ggml_tensor * get_logits() override { return t_logits; }
@@ -91,10 +115,19 @@ class llama_graph_result : public llama_graph_result_i {
91115
std::vector<llama_graph_input_ptr> inputs;
92116
};
93117

118+
94119
//
95120
// llama_graph
96121
//
97122

123+
// this interface defines an API for building graphs by abstracting some high-level concepts such as attention, lora, etc.
124+
// functionality that is trivial and does not rely on the llama_context should be directly implemented in llm_build_context
125+
// other context-specific functionality should be declared here and implemented in the llama_context variations
126+
//
127+
// the main goal of this interface is to separate the llama_context specifics from the graph building logic
128+
// this allows to have cleaner model architecture definitions while being able to overload certain complex
129+
// functionality in order to fit different use cases and/or explore new implementations and ideas
130+
98131
// note: keep all methods const
99132
// TODO: can become more granular in the future
100133
class llama_graph_i {
@@ -112,6 +145,10 @@ class llama_graph_i {
112145
public:
113146
virtual int32_t get_n_outputs() const = 0;
114147

148+
//
149+
// context-specific API
150+
//
151+
115152
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
116153
virtual void build_cb(
117154
ggml_tensor * cur,
@@ -141,8 +178,6 @@ class llama_graph_i {
141178
// rope factors based on the current context size
142179
virtual ggml_tensor * build_rope_factors(int il) const = 0;
143180

144-
// graph build API (context-specific)
145-
146181
// input embeddings with optional lora
147182
virtual llama_graph_input_ptr build_inp_embd(
148183
ggml_context * ctx0,
@@ -154,6 +189,9 @@ class llama_graph_i {
154189
ggml_context * ctx0,
155190
int32_t n_tokens) const = 0;
156191

192+
virtual llama_graph_input_ptr build_inp_cross_embd(
193+
ggml_context * ctx0) const;
194+
157195
//
158196
// attention API
159197
//
@@ -186,9 +224,6 @@ class llama_graph_i {
186224
float kq_scale,
187225
int il) const;
188226

189-
virtual llama_graph_input_ptr build_inp_cross_embd(
190-
ggml_context * ctx0) const;
191-
192227
//
193228
// recurrent API
194229
//

0 commit comments

Comments
 (0)