
Commit befe14f

llama : reorder encode/decode in sources
1 parent bc6f187 commit befe14f

File tree

  src/llama-context.cpp
  src/llama-context.h

2 files changed: +172 -172 lines changed

src/llama-context.cpp

Lines changed: 162 additions & 162 deletions
@@ -1655,6 +1655,168 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
     return llama_context::graph_init();
 }
 
+int llama_context_kv_self::encode(llama_batch & inp_batch) {
+    is_encoding = true;
+
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporary allocate memory for the input batch if needed
+    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
+
+    const llama_batch & batch = batch_allocr.batch;
+    const int32_t n_tokens = batch.n_tokens;
+
+    const auto & hparams = model.hparams;
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    if (batch.token) {
+        for (int32_t i = 0; i < n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
+
+    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
+    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
+
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
+    }
+
+    n_queued_tokens += n_tokens;
+
+    const int64_t n_embd = hparams.n_embd;
+
+    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+
+    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
+
+    // reserve output buffer
+    if (output_reserve(n_tokens) < n_tokens) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
+        return -2;
+    };
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        output_ids[i] = i;
+    }
+
+    inp_embd_enc = NULL;
+    n_outputs = n_tokens;
+
+    //batch_manager->prepare(ubatch);
+
+    // TODO: do reserve
+    GGML_ASSERT(need_reserve == false);
+
+    ggml_backend_sched_reset(sched.get());
+    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+    auto ctx = graph_init();
+    auto res = graph_build(ctx, ubatch, false);
+
+    auto * gf = res.gf;
+
+    ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+    input_set(ubatch);
+
+    const auto compute_status = graph_compute(gf, n_tokens > 1);
+    switch (compute_status) {
+        case GGML_STATUS_SUCCESS:
+            break;
+        case GGML_STATUS_ABORTED:
+            return 2;
+        case GGML_STATUS_ALLOC_FAILED:
+            return -2;
+        case GGML_STATUS_FAILED:
+        default:
+            return -3;
+    }
+
+    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
+
+    // extract embeddings
+    if (t_embd) {
+        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+        GGML_ASSERT(backend_embd != nullptr);
+
+        if (llama_model_has_decoder(&model)) {
+            embd_enc.resize(n_tokens*n_embd);
+            float * embd_out = embd_enc.data();
+
+            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+
+            // remember the sequence ids used during the encoding - needed for cross attention later
+            seq_ids_enc.resize(n_tokens);
+            for (int32_t i = 0; i < n_tokens; i++) {
+                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
+                    llama_seq_id seq_id = ubatch.seq_id[i][s];
+                    seq_ids_enc[i].insert(seq_id);
+                }
+            }
+        } else {
+            GGML_ASSERT(embd != nullptr);
+
+            switch (cparams.pooling_type) {
+                case LLAMA_POOLING_TYPE_NONE:
+                    {
+                        // extract token embeddings
+                        GGML_ASSERT(embd != nullptr);
+                        float * embd_out = embd;
+
+                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+                    } break;
+                case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
+                    {
+                        // extract sequence embeddings
+                        auto & embd_seq_out = embd_seq;
+                        embd_seq_out.clear();
+
+                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+
+                        for (int32_t i = 0; i < n_tokens; i++) {
+                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(n_embd);
+                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+                        }
+                    } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
+                        //       wait for an encoder model that requires this pooling type in order to test it
+                        //       https://github.com/ggerganov/llama.cpp/pull/9510
+                        GGML_ABORT("RANK pooling not implemented yet");
+                    }
+                case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                    {
+                        GGML_ABORT("unknown pooling type");
+                    }
+            }
+        }
+    }
+
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    ggml_backend_sched_reset(sched.get());
+
+    return 0;
+}
+
 int llama_context_kv_self::decode(llama_batch & inp_batch) {
     is_encoding = false;
 
@@ -2020,168 +2182,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     return 0;
 }
 
-int llama_context_kv_self::encode(llama_batch & inp_batch) {
-    is_encoding = true;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const int32_t n_tokens = batch.n_tokens;
-
-    const auto & hparams = model.hparams;
-
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-    if (batch.token) {
-        for (int32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
-                return -1;
-            }
-        }
-    }
-
-    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
-
-    if (t_compute_start_us == 0) {
-        t_compute_start_us = ggml_time_us();
-    }
-
-    n_queued_tokens += n_tokens;
-
-    const int64_t n_embd = hparams.n_embd;
-
-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
-
-    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
-    // reserve output buffer
-    if (output_reserve(n_tokens) < n_tokens) {
-        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
-        return -2;
-    };
-
-    for (int32_t i = 0; i < n_tokens; ++i) {
-        output_ids[i] = i;
-    }
-
-    inp_embd_enc = NULL;
-    n_outputs = n_tokens;
-
-    //batch_manager->prepare(ubatch);
-
-    // TODO: do reserve
-    GGML_ASSERT(need_reserve == false);
-
-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
-    auto ctx = graph_init();
-    auto res = graph_build(ctx, ubatch, false);
-
-    auto * gf = res.gf;
-
-    ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-    input_set(ubatch);
-
-    const auto compute_status = graph_compute(gf, n_tokens > 1);
-    switch (compute_status) {
-        case GGML_STATUS_SUCCESS:
-            break;
-        case GGML_STATUS_ABORTED:
-            return 2;
-        case GGML_STATUS_ALLOC_FAILED:
-            return -2;
-        case GGML_STATUS_FAILED:
-        default:
-            return -3;
-    }
-
-    auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd;
-
-    // extract embeddings
-    if (t_embd) {
-        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
-        GGML_ASSERT(backend_embd != nullptr);
-
-        if (llama_model_has_decoder(&model)) {
-            embd_enc.resize(n_tokens*n_embd);
-            float * embd_out = embd_enc.data();
-
-            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
-            GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
-
-            // remember the sequence ids used during the encoding - needed for cross attention later
-            seq_ids_enc.resize(n_tokens);
-            for (int32_t i = 0; i < n_tokens; i++) {
-                for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
-                    llama_seq_id seq_id = ubatch.seq_id[i][s];
-                    seq_ids_enc[i].insert(seq_id);
-                }
-            }
-        } else {
-            GGML_ASSERT(embd != nullptr);
-
-            switch (cparams.pooling_type) {
-                case LLAMA_POOLING_TYPE_NONE:
-                    {
-                        // extract token embeddings
-                        GGML_ASSERT(embd != nullptr);
-                        float * embd_out = embd;
-
-                        GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
-                    } break;
-                case LLAMA_POOLING_TYPE_MEAN:
-                case LLAMA_POOLING_TYPE_CLS:
-                case LLAMA_POOLING_TYPE_LAST:
-                    {
-                        // extract sequence embeddings
-                        auto & embd_seq_out = embd_seq;
-                        embd_seq_out.clear();
-
-                        GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
-
-                        for (int32_t i = 0; i < n_tokens; i++) {
-                            const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                                continue;
-                            }
-                            embd_seq_out[seq_id].resize(n_embd);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_RANK:
-                    {
-                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
-                        //       wait for an encoder model that requires this pooling type in order to test it
-                        //       https://github.com/ggerganov/llama.cpp/pull/9510
-                        GGML_ABORT("RANK pooling not implemented yet");
-                    }
-                case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                    {
-                        GGML_ABORT("unknown pooling type");
-                    }
-            }
-        }
-    }
-
-    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
-    // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
-
-    return 0;
-}
-
 llama_pos llama_context_kv_self::pos_max() const {
     return kv_self.pos_max();
 }
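
For context on the extraction branch in encode() above: when the model has a decoder, the raw encoder output is kept in embd_enc (together with seq_ids_enc) for later cross-attention; otherwise the pooled per-sequence results land in embd_seq, which callers normally read back through the public llama_get_embeddings_seq() accessor. A minimal read-back sketch follows; the helper name and passing n_embd explicitly are illustrative assumptions, not part of this commit.

```cpp
#include <cstdio>
#include <vector>

#include "llama.h"

// Illustrative helper: copy out the pooled embedding of one sequence after a
// successful llama_encode()/llama_decode() call. Assumes the context uses a
// pooling type other than LLAMA_POOLING_TYPE_NONE and that seq_id appeared in
// the evaluated batch; n_embd is supplied by the caller.
static std::vector<float> read_pooled_embedding(llama_context * ctx, llama_seq_id seq_id, int n_embd) {
    const float * data = llama_get_embeddings_seq(ctx, seq_id);
    if (data == nullptr) {
        fprintf(stderr, "no pooled embedding stored for sequence %d\n", seq_id);
        return {};
    }
    return std::vector<float>(data, data + n_embd);
}
```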

src/llama-context.h

Lines changed: 10 additions & 10 deletions
@@ -116,30 +116,30 @@ struct llama_context : public llama_graph_i {
     // TODO: maybe remove this
     virtual void output_reorder();
 
-    // decode a batch of tokens by evaluating the transformer
-    // in case of unsuccessful decoding (error or warning),
-    // the kv_cache state will be returned to its original state
-    // (for non-recurrent models) or cleaned (for recurrent models)
+    // encode a batch of tokens by evaluating the encoder part of the transformer
     //
     // - lctx: llama context
-    // - inp_batch: batch to evaluate
+    // - batch: batch to evaluate
     //
     // return 0 on success
     // return positive int on warning
     // return negative int on error
     //
-    virtual int decode(llama_batch & inp_batch) = 0;
+    virtual int encode(llama_batch & inp_batch) = 0;
 
-    // encode a batch of tokens by evaluating the encoder part of the transformer
+    // decode a batch of tokens by evaluating the transformer
+    // in case of unsuccessful decoding (error or warning),
+    // the kv_cache state will be returned to its original state
+    // (for non-recurrent models) or cleaned (for recurrent models)
     //
     // - lctx: llama context
-    // - batch: batch to evaluate
+    // - inp_batch: batch to evaluate
     //
     // return 0 on success
     // return positive int on warning
     // return negative int on error
     //
-    virtual int encode(llama_batch & inp_batch) = 0;
+    virtual int decode(llama_batch & inp_batch) = 0;
 
     //
     // graph build API (generic)
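
The contract stated in the comments above (0 = success, positive = warning, negative = error) is what llama_context_kv_self::encode()/decode() return and what the public llama_encode()/llama_decode() wrappers pass through. A minimal caller-side sketch of honoring it; the helper name is hypothetical, and the specific codes in the comments follow the implementation shown in the .cpp diff above.

```cpp
#include <cstdio>

#include "llama.h"

// Hypothetical helper: run the encoder on a prepared batch and interpret the
// status code according to the 0 / positive / negative convention.
static bool try_encode(llama_context * ctx, llama_batch batch) {
    const int32_t status = llama_encode(ctx, batch);
    if (status == 0) {
        return true;  // success
    }
    if (status > 0) {
        // warning, e.g. 2 when the compute graph was aborted via the abort callback
        fprintf(stderr, "encode finished with warning %d\n", status);
        return true;
    }
    // error, e.g. -1 invalid batch, -2 allocation failure, -3 compute failure
    fprintf(stderr, "encode failed with error %d\n", status);
    return false;
}
```
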
@@ -336,8 +336,8 @@ class llama_context_kv_self : public llama_context {
 
     virtual void input_set(const llama_ubatch & ubatch) override;
 
-    virtual int decode(llama_batch & inp_batch) override;
     virtual int encode(llama_batch & inp_batch) override;
+    virtual int decode(llama_batch & inp_batch) override;
 
     // max token position across all sequences in the current context
     llama_pos pos_max() const;
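
The new declaration order also matches the runtime order for encoder-decoder models: the input is passed through encode() once (populating embd_enc and seq_ids_enc for cross-attention), after which decode() is called for generation. A condensed sketch of that flow through the public C API, assuming an already loaded model/context and a tokenized prompt batch; error handling is trimmed and the helper name is illustrative.

```cpp
#include <cstdio>

#include "llama.h"

// Condensed encoder-decoder flow (assumption: `model` and `ctx` are already
// created and `prompt` holds the tokenized input as a llama_batch).
static int run_once(llama_model * model, llama_context * ctx, llama_batch prompt) {
    if (llama_model_has_encoder(model)) {
        if (llama_encode(ctx, prompt) != 0) {
            fprintf(stderr, "llama_encode() failed\n");
            return 1;
        }
        // decoding starts from the model's decoder start token, falling back to BOS
        llama_token start = llama_model_decoder_start_token(model);
        if (start == LLAMA_TOKEN_NULL) {
            start = llama_vocab_bos(llama_model_get_vocab(model));
        }
        return llama_decode(ctx, llama_batch_get_one(&start, 1)) == 0 ? 0 : 1;
    }
    // decoder-only models skip the encode step entirely
    return llama_decode(ctx, prompt) == 0 ? 0 : 1;
}
```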
