Skip to content

Commit 86973cb

Browse files
committed
fix merge errors
1 parent 17f954c commit 86973cb

File tree

3 files changed

+43
-26
lines changed

3 files changed

+43
-26
lines changed

include/llama.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,7 @@ extern "C" {
994994
DEPRECATED(LLAMA_API int32_t llama_encode(
995995
struct llama_context * ctx,
996996
struct llama_batch batch), "use llama_batch_ext API instead");
997+
997998
LLAMA_API int32_t llama_encode_ext(
998999
struct llama_context * ctx,
9991000
struct llama_batch_ext * batch);
@@ -1005,6 +1006,7 @@ extern "C" {
10051006
DEPRECATED(LLAMA_API int32_t llama_decode(
10061007
struct llama_context * ctx,
10071008
struct llama_batch batch), "use llama_batch_ext API instead");
1009+
10081010
LLAMA_API int32_t llama_decode_ext(
10091011
struct llama_context * ctx,
10101012
struct llama_batch_ext * batch);

src/llama-context.cpp

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "llama-io.h"
55
#include "llama-mmap.h"
66
#include "llama-model.h"
7+
#include "llama-batch.h"
78
#include "llama-kv-cache.h"
89

910
#include <cassert>
@@ -980,16 +981,26 @@ bool llama_context::apply_adapter_cvec(
980981
}
981982

982983
int llama_context::encode(llama_batch & inp_batch) {
984+
// temporary allocate memory and convert llama_batch to llama_batch_ext
985+
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
986+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
987+
return encode(*batch_allocr.batch);
988+
}
989+
990+
int llama_context::decode(llama_batch & inp_batch) {
991+
// temporary allocate memory and convert llama_batch to llama_batch_ext
992+
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
993+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
994+
return decode(*batch_allocr.batch);
995+
}
996+
997+
int llama_context::encode(llama_batch_ext & inp_batch) {
983998
if (inp_batch.n_tokens == 0) {
984999
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
9851000
return -1;
9861001
}
9871002

988-
// temporary allocate memory for the input batch if needed
989-
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
990-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
991-
992-
const llama_batch & batch = batch_allocr.batch;
1003+
llama_batch_ext & batch = inp_batch;
9931004
const int32_t n_tokens = batch.n_tokens;
9941005

9951006
const auto & hparams = model.hparams;
@@ -1132,17 +1143,13 @@ int llama_context::encode(llama_batch & inp_batch) {
11321143
return 0;
11331144
}
11341145

1135-
int llama_context::decode(llama_batch & inp_batch) {
1146+
int llama_context::decode(llama_batch_ext & inp_batch) {
11361147
if (inp_batch.n_tokens == 0) {
11371148
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
11381149
return -1;
11391150
}
11401151

1141-
// temporary allocate memory for the input batch if needed
1142-
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
1143-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
1144-
1145-
const llama_batch & batch = batch_allocr.batch;
1152+
llama_batch_ext & batch = inp_batch;
11461153

11471154
const auto & vocab = model.vocab;
11481155
const auto & hparams = model.hparams;
@@ -2714,26 +2721,30 @@ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, lla
27142721

27152722
///
27162723

2724+
// deprecated
27172725
int32_t llama_encode(
2718-
llama_context * ctx,
2719-
llama_batch batch) {
2720-
const int ret = ctx->encode(batch);
2721-
if (ret != 0) {
2722-
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
2723-
}
2724-
2725-
return ret;
2726+
struct llama_context * ctx,
2727+
struct llama_batch inp_batch) {
2728+
return ctx->encode(inp_batch);
27262729
}
27272730

2731+
// deprecated
27282732
int32_t llama_decode(
2729-
llama_context * ctx,
2730-
llama_batch batch) {
2731-
const int ret = ctx->decode(batch);
2732-
if (ret != 0) {
2733-
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
2734-
}
2733+
struct llama_context * ctx,
2734+
struct llama_batch inp_batch) {
2735+
return ctx->decode(inp_batch);
2736+
}
2737+
2738+
int32_t llama_encode_ext(
2739+
struct llama_context * ctx,
2740+
struct llama_batch_ext * inp_batch) {
2741+
return ctx->encode(*inp_batch);
2742+
}
27352743

2736-
return ret;
2744+
int32_t llama_decode_ext(
2745+
struct llama_context * ctx,
2746+
struct llama_batch_ext * inp_batch) {
2747+
return ctx->decode(*inp_batch);
27372748
}
27382749

27392750
//

src/llama-context.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,13 @@ struct llama_context {
8181
int32_t il_start,
8282
int32_t il_end);
8383

84+
// deprecated
8485
int encode(llama_batch & inp_batch);
8586
int decode(llama_batch & inp_batch);
8687

88+
int encode(llama_batch_ext & inp_batch);
89+
int decode(llama_batch_ext & inp_batch);
90+
8791
//
8892
// state save/load
8993
//

0 commit comments

Comments
 (0)