@@ -1855,37 +1855,14 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     auto ctx = graph_init();
     auto res = graph_build(ctx, ubatch, false);
 
-    auto & gf = res.gf;
+    auto * gf = res.gf;
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
     ggml_backend_sched_alloc_graph(sched.get(), gf);
 
     input_set(ubatch);
 
-    // the output is always the last tensor in the graph
-    struct ggml_tensor * t_logits = ggml_graph_node(gf, -1);
-    struct ggml_tensor * t_embd   = ggml_graph_node(gf, -2);
-
-    if (n_outputs == 0) {
-        // no output
-        t_logits = nullptr;
-        t_embd   = nullptr;
-    } else if (cparams.embeddings) {
-        t_logits = nullptr; // do not extract logits for embedding case
-        t_embd   = nullptr;
-        for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
-            if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-                t_embd = ggml_graph_node(gf, i);
-                break;
-            }
-        }
-        GGML_ASSERT(t_embd != nullptr && "missing embeddings tensor");
-    } else {
-        t_embd = nullptr; // do not extract embeddings when not needed
-        GGML_ASSERT(strcmp(t_logits->name, "result_output") == 0 && "missing result_output tensor");
-    }
-
     const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
     if (compute_status != GGML_STATUS_SUCCESS) {
         switch (compute_status) {
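The removed block located the output tensors positionally and by name in the finished graph: the last node was assumed to be the logits, the one before it the embeddings, with a backwards scan for "result_embd_pooled" in the pooled case. The replacement relies on graph_build() returning the tensors directly in its result object. Below is a minimal sketch of what that result type must contain; only the members actually referenced in this diff (gf, t_logits, t_embd, t_embd_pooled) come from the change itself, everything else is illustrative:

    // Hypothetical shape of the graph_build() result used above; the field
    // types follow ggml conventions, but this is not the PR's definition.
    struct llama_graph_result_sketch {
        ggml_cgraph * gf            = nullptr; // full compute graph
        ggml_tensor * t_logits      = nullptr; // "result_output", when built
        ggml_tensor * t_embd        = nullptr; // raw output embeddings, when built
        ggml_tensor * t_embd_pooled = nullptr; // "result_embd_pooled", when built
    };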
@@ -1914,8 +1891,15 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     // ggml_graph_dump_dot(gf, NULL, "llama.dot");
     // }
 
+    auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
+    auto * t_embd   = cparams.embeddings ? res.t_embd : nullptr;
+
+    if (t_embd && res.t_embd_pooled) {
+        t_embd = res.t_embd_pooled;
+    }
+
     // extract logits
-    if (t_logits) {
+    if (t_logits && n_outputs > 0) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
         GGML_ASSERT(backend_res != nullptr);
         GGML_ASSERT(logits != nullptr);
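With the tensors carried in res, the mode selection collapses to two ternaries plus a pooled-tensor override, and the old n_outputs == 0 branch is absorbed by the new n_outputs > 0 guards at the two extraction sites. The effective mapping, derived from the lines above:

    // cparams.embeddings | t_logits     | t_embd
    // -------------------+--------------+------------------------------------------
    // false              | res.t_logits | nullptr
    // true               | nullptr      | res.t_embd_pooled if set, else res.t_embd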
@@ -1930,7 +1914,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     }
 
     // extract embeddings
-    if (t_embd) {
+    if (t_embd && n_outputs > 0) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
@@ -2103,32 +2087,12 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     auto ctx = graph_init();
     auto res = graph_build(ctx, ubatch, false);
 
-    auto & gf = res.gf;
+    auto * gf = res.gf;
 
     ggml_backend_sched_alloc_graph(sched.get(), gf);
 
     input_set(ubatch);
 
-    // the output embeddings after the final encoder normalization
-    struct ggml_tensor * t_embd = nullptr;
-
-    // there are two cases here
-    if (llama_model_has_decoder(&model)) {
-        // first case is an encoder-decoder T5 model where embeddings are passed to decoder
-        t_embd = ggml_graph_node(gf, -1);
-        GGML_ASSERT(strcmp(t_embd->name, "result_norm") == 0 && "missing result_output tensor");
-    } else {
-        // second case is an encoder-only T5 model
-        if (cparams.embeddings) {
-            // only output embeddings if required
-            t_embd = ggml_graph_node(gf, -1);
-            if (strcmp(t_embd->name, "result_embd_pooled") != 0) {
-                t_embd = ggml_graph_node(gf, -2);
-            }
-            GGML_ASSERT(strcmp(t_embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
-        }
-    }
-
     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
         case GGML_STATUS_SUCCESS:
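encode() gets the same simplification. The removed branch distinguished an encoder-decoder model (take the last graph node and expect "result_norm") from an encoder-only model with cparams.embeddings set (take the last or second-to-last node and expect "result_embd_pooled"). The single line added in the next hunk subsumes both cases, on the presumption that graph_build() only populates t_embd_pooled when a pooled output was actually built, so the encoder-decoder path falls through to res.t_embd.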
@@ -2142,6 +2106,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
             return -3;
     }
 
+    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
+
     // extract embeddings
     if (t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
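For context on what follows the backend lookup in the last line above: the usual ggml pattern is to copy the tensor's contents out of whichever backend produced it. A hedged sketch, assuming a float destination buffer (the names embd_out and embd are illustrative, not taken from this diff):

    // Copy the embeddings tensor from its backend into a host buffer.
    // ggml_backend_tensor_get_async() and ggml_nbytes() are existing ggml
    // APIs; the surrounding variable names are assumptions.
    float * embd_out = embd; // assumed per-context output buffer
    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, ggml_nbytes(t_embd));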