 struct ggml_cgraph;
 struct ggml_context;
 struct ggml_tensor;
-struct ggml_backend_buffer;
 
 struct llama_ubatch;
 
+// certain models (typically multi-modal) can produce different types of graphs
+// the llama_context specifies which type of graph it needs through the llama_graph_i::type member
 enum llama_graph_type {
     LLAMA_GRAPH_TYPE_DEFAULT,
     LLAMA_GRAPH_TYPE_ENCODER,
     LLAMA_GRAPH_TYPE_DECODER,
 };
 
+
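For illustration, a minimal sketch (not part of this diff; the helper functions are invented) of how a context for an encoder-decoder model might dispatch on the requested graph type:

```cpp
// hypothetical sketch: an encoder-decoder context picking a build path
// based on the requested graph type (helper names are illustrative only)
ggml_cgraph * build_graph_for(llama_graph_type gtype) {
    switch (gtype) {
        case LLAMA_GRAPH_TYPE_ENCODER: return build_encoder_graph();
        case LLAMA_GRAPH_TYPE_DECODER: return build_decoder_graph();
        case LLAMA_GRAPH_TYPE_DEFAULT:
        default:                       return build_default_graph();
    }
}
```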
 //
 // llama_graph_input
 //
 
+// denotes an input to the graph
+// typically, the data of these objects is populated based on the contents of the current llama_ubatch:
+//
+//   - llama_graph_input_pos
+//   - llama_graph_input_out_ids
+//   - etc.
+//
+// some inputs require context-specific data (e.g. KV cache) - such inputs are defined for the specific llama_context:
+//
+//   - llama_graph_input_embd         (can apply lora)
+//   - llama_graph_input_attn_kv_self (requires KV cache instance)
+//   - etc.
+//
+
 class llama_graph_input_i {
 public:
     virtual ~llama_graph_input_i() = default;
 
     virtual void set_input(const llama_ubatch * ubatch) = 0;
 
-    // by default, we produce a single input tensor, but some children could produce more
+    // by default, we produce a single input tensor, but some implementations could produce more
     ggml_tensor * cur = nullptr;
 };
 
 using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
 
+
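To make the set_input() contract concrete, a rough sketch of what a positions input could look like (illustrative only; the actual llama_graph_input_pos implementation may differ, and the layout of llama_ubatch is assumed here):

```cpp
// sketch of a concrete input: token positions
// cur is created during graph build; set_input() fills it from the
// current batch once the batch contents are known
class llama_graph_input_pos : public llama_graph_input_i {
public:
    virtual ~llama_graph_input_pos() = default;

    void set_input(const llama_ubatch * ubatch) override {
        // assumes cur was pre-allocated as a 1D I32 tensor with one
        // position per token, and ubatch->pos points at those positions
        if (cur && ubatch->pos) {
            ggml_backend_tensor_set(cur, ubatch->pos, 0, ggml_nbytes(cur));
        }
    }
};
```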
 class llama_graph_input_attn_i : public llama_graph_input_i {
 public:
     virtual ~llama_graph_input_attn_i() = default;
@@ -47,10 +64,17 @@ class llama_graph_input_attn_i : public llama_graph_input_i {
 
 using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
 
+
 //
 // llama_graph_result
 //
 
+// these objects deliver the result from the graph build process back to the llama_context
+// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
+//   specific data by calling the set_inputs() method
+// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
+//   these are used by the llama_context to extract the relevant data, based on the compute parameters
+
 class llama_graph_result_i {
 public:
     virtual ~llama_graph_result_i() = default;
@@ -64,9 +88,9 @@ class llama_graph_result_i {
 
 using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
 
+
 class llama_graph_result : public llama_graph_result_i {
 public:
-    llama_graph_result()          = default;
     virtual ~llama_graph_result() = default;
 
     ggml_tensor * get_logits()      override { return t_logits; }
@@ -91,10 +115,19 @@ class llama_graph_result : public llama_graph_result_i {
     std::vector<llama_graph_input_ptr> inputs;
 };
 
+
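A condensed usage sketch from the llama_context side (hypothetical; it assumes set_inputs() takes the current ubatch, and the compute step is elided since it is context-specific):

```cpp
// hypothetical usage from the llama_context side: build the graph once,
// then populate the recorded inputs for the current ubatch and read the
// common outputs back after the graph has been computed
void process_ubatch(llama_graph_result_i * res, const llama_ubatch & ubatch) {
    res->set_inputs(&ubatch); // fill all input tensors from the batch data

    // ... compute the graph here (context-specific) ...

    ggml_tensor * logits = res->get_logits(); // then extract the outputs
    (void) logits; // e.g. copy the relevant rows based on the compute parameters
}
```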
 //
 // llama_graph
 //
 
+// this interface defines an API for building graphs by abstracting some high-level concepts such as attention, lora, etc.
+// functionality that is trivial and does not rely on the llama_context should be implemented directly in llm_build_context
+//   other context-specific functionality should be declared here and implemented in the llama_context variations
+//
+// the main goal of this interface is to separate the llama_context specifics from the graph building logic
+//   this allows for cleaner model architecture definitions, while still being able to overload certain complex
+//   functionality in order to fit different use cases and/or explore new implementations and ideas
+
 // note: keep all methods const
 // TODO: can become more granular in the future
 class llama_graph_i {
@@ -112,6 +145,10 @@ class llama_graph_i {
 public:
     virtual int32_t get_n_outputs() const = 0;
 
+    //
+    // context-specific API
+    //
+
     // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
     virtual void build_cb(
              ggml_tensor * cur,
@@ -141,8 +178,6 @@ class llama_graph_i {
     // rope factors based on the current context size
     virtual ggml_tensor * build_rope_factors(int il) const = 0;
 
-    // graph build API (context-specific)
-
     // input embeddings with optional lora
     virtual llama_graph_input_ptr build_inp_embd(
             ggml_context * ctx0,
@@ -154,6 +189,9 @@ class llama_graph_i {
             ggml_context * ctx0,
                  int32_t    n_tokens) const = 0;
 
+    virtual llama_graph_input_ptr build_inp_cross_embd(
+            ggml_context * ctx0) const;
+
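As a sketch of the building side, assuming the declaration whose tail is visible above is a build_inp_pos() method, and reusing the hypothetical llama_graph_input_pos from earlier (the context class name is invented):

```cpp
// hypothetical implementation of build_inp_pos() in a context class:
// allocate the tensor now, defer filling it to set_input()
llama_graph_input_ptr llama_context_example::build_inp_pos(
        ggml_context * ctx0,
             int32_t    n_tokens) const {
    auto inp = std::make_shared<llama_graph_input_pos>();

    inp->cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
    ggml_set_input(inp->cur); // mark it so ggml-alloc treats it as a graph input

    return inp;
}
```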
     //
     // attention API
     //
@@ -186,9 +224,6 @@ class llama_graph_i {
                 float      kq_scale,
                 int        il) const;
 
-    virtual llama_graph_input_ptr build_inp_cross_embd(
-            ggml_context * ctx0) const;
-
     //
     // recurrent API
     //
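Putting the pieces together, a llama_context variation would derive from llama_graph_i and override the context-specific builders. A bare-bones sketch (all member names hypothetical; the remaining pure-virtual methods are omitted for brevity):

```cpp
// bare-bones sketch of a llama_context variation implementing the graph
// interface; trivial graph logic stays in llm_build_context, while the
// context-specific pieces are overridden here
class llama_context_example : public llama_graph_i {
public:
    int32_t get_n_outputs() const override { return n_outputs; }

    ggml_tensor * build_rope_factors(int il) const override {
        // e.g. select long/short rope factors based on the current n_ctx
        return rope_factors[il];
    }

    llama_graph_input_ptr build_inp_pos(
            ggml_context * ctx0,
                 int32_t    n_tokens) const override; // defined above

    // ... remaining builders (embd, attention, recurrent, etc.) ...

private:
    int32_t n_outputs = 0;

    std::vector<ggml_tensor *> rope_factors;
};
```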