Skip to content

Commit b52b79b

Browse files
committed
context : move encode/decode to llama-context.cpp
1 parent 02ef4be commit b52b79b

File tree

3 files changed

+48
-57
lines changed

3 files changed

+48
-57
lines changed

src/llama-context.cpp

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3980,6 +3980,31 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
39803980
}
39813981
}
39823982

3983+
///
3984+
3985+
// C-API entry point for encoding: thin wrapper that forwards the batch to the
// context's virtual encode() and logs on failure.
//
//  - ctx:   llama context (dispatches to ctx->encode)
//  - batch: batch of tokens to evaluate
//
// return 0 on success; non-zero return codes are propagated from encode()
// (logged here as an error).
int32_t llama_encode(
        struct llama_context * ctx,
        struct llama_batch     batch) {
    const int ret = ctx->encode(batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
    }

    return ret;
}
3995+
3996+
// C-API entry point for decoding: thin wrapper that forwards the batch to the
// context's virtual decode() and logs on failure.
//
//  - ctx:   llama context (dispatches to ctx->decode)
//  - batch: batch of tokens to evaluate
//
// return 0 on success; non-zero return codes are propagated from decode()
// (logged here as an error).
int32_t llama_decode(
        struct llama_context * ctx,
        struct llama_batch     batch) {
    const int ret = ctx->decode(batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }

    return ret;
}
4006+
4007+
39834008
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
39844009
struct llama_context * ctx
39854010
) {

src/llama-context.h

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,30 @@ struct llama_context {
4545

4646
virtual ggml_context_ptr init();
4747

48+
// decode a batch of tokens by evaluating the transformer
49+
// in case of unsuccessful decoding (error or warning),
50+
// the kv_cache state will be returned to its original state
51+
// (for non-recurrent models) or cleaned (for recurrent models)
52+
//
53+
// - lctx: llama context
54+
// - inp_batch: batch to evaluate
55+
//
56+
// return 0 on success
57+
// return positive int on warning
58+
// return negative int on error
59+
//
4860
virtual int decode(llama_batch & inp_batch) = 0;
61+
62+
63+
// encode a batch of tokens by evaluating the encoder part of the transformer
64+
//
65+
// - lctx: llama context
66+
// - batch: batch to evaluate
67+
//
68+
// return 0 on success
69+
// return positive int on warning
70+
// return negative int on error
71+
//
4972
virtual int encode(llama_batch & inp_batch) = 0;
5073

5174
// graph build API (generic)

src/llama.cpp

Lines changed: 0 additions & 57 deletions
Original file line number | Diff line number | Diff line change
@@ -7401,39 +7401,6 @@ static struct ggml_cgraph * llama_build_graph(
74017401
return result;
74027402
}
74037403

7404-
// decode a batch of tokens by evaluating the transformer
7405-
// in case of unsuccessful decoding (error or warning),
7406-
// the kv_cache state will be returned to its original state
7407-
// (for non-recurrent models) or cleaned (for recurrent models)
7408-
//
7409-
// - lctx: llama context
7410-
// - inp_batch: batch to evaluate
7411-
//
7412-
// return 0 on success
7413-
// return positive int on warning
7414-
// return negative int on error
7415-
//
7416-
static int llama_decode_impl(
7417-
llama_context & lctx,
7418-
llama_batch inp_batch) {
7419-
return lctx.decode(inp_batch);
7420-
}
7421-
7422-
// encode a batch of tokens by evaluating the encoder part of the transformer
7423-
//
7424-
// - lctx: llama context
7425-
// - batch: batch to evaluate
7426-
//
7427-
// return 0 on success
7428-
// return positive int on warning
7429-
// return negative int on error
7430-
//
7431-
static int llama_encode_impl(
7432-
llama_context & lctx,
7433-
llama_batch inp_batch) {
7434-
return lctx.encode(inp_batch);
7435-
}
7436-
74377404
//
74387405
// interface implementation
74397406
//
@@ -7759,30 +7726,6 @@ struct llama_context * llama_new_context_with_model(
77597726
return llama_init_from_model(model, params);
77607727
}
77617728

7762-
///
7763-
7764-
// C-API entry point for encoding: forwards to llama_encode_impl and logs any
// non-zero return code as an error before propagating it to the caller.
int32_t llama_encode(
        struct llama_context * ctx,
        struct llama_batch     batch) {
    const int ret = llama_encode_impl(*ctx, batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
    }

    return ret;
}
7774-
7775-
// C-API entry point for decoding: forwards to llama_decode_impl and logs any
// non-zero return code as an error before propagating it to the caller.
int32_t llama_decode(
        struct llama_context * ctx,
        struct llama_batch     batch) {
    const int ret = llama_decode_impl(*ctx, batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }

    return ret;
}
7785-
77867729
//
77877730
// chat templates
77887731
//

0 commit comments

Comments (0)