Commit bc6f187

cont : use returned tensors from the graph build
ggml-ci
1 parent 172f616 commit bc6f187
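
In short: decode() and encode() used to locate the output tensors by scanning the built graph for nodes named "result_output", "result_norm" and "result_embd_pooled"; graph_build() now returns those tensors directly alongside the graph. A minimal sketch of what the returned result presumably carries, inferred solely from the fields this diff touches (res.gf, res.t_logits, res.t_embd, res.t_embd_pooled); the real struct name and definition live elsewhere in the llama.cpp sources:

    // hypothetical shape of the graph_build() result; inferred from this
    // diff, not the verbatim llama.cpp definition
    struct llama_graph_result {
        ggml_cgraph * gf            = nullptr; // the built compute graph
        ggml_tensor * t_logits      = nullptr; // "result_output" (decode path)
        ggml_tensor * t_embd        = nullptr; // "result_norm" (raw output embeddings)
        ggml_tensor * t_embd_pooled = nullptr; // "result_embd_pooled", set only when pooling is enabled
    };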

File tree

1 file changed: +13 -47 lines

src/llama-context.cpp

Lines changed: 13 additions & 47 deletions
@@ -1855,37 +1855,14 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     auto ctx = graph_init();
     auto res = graph_build(ctx, ubatch, false);
 
-    auto & gf = res.gf;
+    auto * gf = res.gf;
 
     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
     ggml_backend_sched_alloc_graph(sched.get(), gf);
 
     input_set(ubatch);
 
-    // the output is always the last tensor in the graph
-    struct ggml_tensor * t_logits = ggml_graph_node(gf, -1);
-    struct ggml_tensor * t_embd   = ggml_graph_node(gf, -2);
-
-    if (n_outputs == 0) {
-        // no output
-        t_logits = nullptr;
-        t_embd   = nullptr;
-    } else if (cparams.embeddings) {
-        t_logits = nullptr; // do not extract logits for embedding case
-        t_embd   = nullptr;
-        for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
-            if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-                t_embd = ggml_graph_node(gf, i);
-                break;
-            }
-        }
-        GGML_ASSERT(t_embd != nullptr && "missing embeddings tensor");
-    } else {
-        t_embd = nullptr; // do not extract embeddings when not needed
-        GGML_ASSERT(strcmp(t_logits->name, "result_output") == 0 && "missing result_output tensor");
-    }
-
     const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
     if (compute_status != GGML_STATUS_SUCCESS) {
         switch (compute_status) {
@@ -1914,8 +1891,15 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
     //}
 
+    auto * t_logits = cparams.embeddings ? nullptr : res.t_logits;
+    auto * t_embd   = cparams.embeddings ? res.t_embd : nullptr;
+
+    if (t_embd && res.t_embd_pooled) {
+        t_embd = res.t_embd_pooled;
+    }
+
     // extract logits
-    if (t_logits) {
+    if (t_logits && n_outputs > 0) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
         GGML_ASSERT(backend_res != nullptr);
         GGML_ASSERT(logits != nullptr);
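
Design note on the decode path: the old code nulled the tensors up front when n_outputs == 0 and, in the embeddings case, walked the graph backwards looking for "result_embd_pooled". The new code keeps whatever graph_build() returned, prefers the pooled tensor when one exists, and folds the output check into the extraction guards (t_logits && n_outputs > 0, and the same for t_embd below).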
@@ -1930,7 +1914,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     }
 
     // extract embeddings
-    if (t_embd) {
+    if (t_embd && n_outputs > 0) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
@@ -2103,32 +2087,12 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     auto ctx = graph_init();
     auto res = graph_build(ctx, ubatch, false);
 
-    auto & gf = res.gf;
+    auto * gf = res.gf;
 
     ggml_backend_sched_alloc_graph(sched.get(), gf);
 
     input_set(ubatch);
 
-    // the output embeddings after the final encoder normalization
-    struct ggml_tensor * t_embd = nullptr;
-
-    // there are two cases here
-    if (llama_model_has_decoder(&model)) {
-        // first case is an encoder-decoder T5 model where embeddings are passed to decoder
-        t_embd = ggml_graph_node(gf, -1);
-        GGML_ASSERT(strcmp(t_embd->name, "result_norm") == 0 && "missing result_output tensor");
-    } else {
-        // second case is an encoder-only T5 model
-        if (cparams.embeddings) {
-            // only output embeddings if required
-            t_embd = ggml_graph_node(gf, -1);
-            if (strcmp(t_embd->name, "result_embd_pooled") != 0) {
-                t_embd = ggml_graph_node(gf, -2);
-            }
-            GGML_ASSERT(strcmp(t_embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
-        }
-    }
-
     const auto compute_status = graph_compute(gf, n_tokens > 1);
     switch (compute_status) {
         case GGML_STATUS_SUCCESS:
@@ -2142,6 +2106,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
             return -3;
     }
 
+    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
+
     // extract embeddings
     if (t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
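
On the encode path the choice collapses to one line: prefer the pooled tensor when the build produced it, otherwise fall back to the plain embeddings, which covers both the encoder-decoder case (result_norm passed on to the decoder) and the encoder-only case from the removed branch.

For context, the extraction that follows in both functions resolves the backend holding the output tensor and copies it out asynchronously. A hedged sketch of that read-back step, written as a fragment using the surrounding sched/t_embd variables; the destination buffer and sizing are illustrative:

    ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
    GGML_ASSERT(backend_embd != nullptr);

    // copy the whole tensor into a host buffer (illustrative; the real code
    // writes into the context's preallocated output buffers)
    std::vector<float> embd_out(ggml_nelements(t_embd));
    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out.data(), 0, ggml_nbytes(t_embd));
    // note: the scheduler must be synchronized before reading embd_out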
