4 | 4 | #include "llama-io.h" |
5 | 5 | #include "llama-mmap.h" |
6 | 6 | #include "llama-model.h" |
| 7 | +#include "llama-batch.h" |
7 | 8 | #include "llama-kv-cache.h" |
8 | 9 |
9 | 10 | #include <cassert> |
@@ -980,16 +981,26 @@ bool llama_context::apply_adapter_cvec( |
980 | 981 | } |
981 | 982 |
982 | 983 | int llama_context::encode(llama_batch & inp_batch) { |
| 984 | + // temporarily allocate memory and convert llama_batch to llama_batch_ext |
| 985 | + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences |
| 986 | + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); |
| 987 | + return encode(*batch_allocr.batch); |
| 988 | +} |
| 989 | + |
| 990 | +int llama_context::decode(llama_batch & inp_batch) { |
| 992 | + // temporarily allocate memory and convert llama_batch to llama_batch_ext |
| 992 | + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences |
| 993 | + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); |
| 994 | + return decode(*batch_allocr.batch); |
| 995 | +} |
| 996 | + |
| 997 | +int llama_context::encode(llama_batch_ext & inp_batch) { |
983 | 998 | if (inp_batch.n_tokens == 0) { |
984 | 999 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
985 | 1000 | return -1; |
986 | 1001 | } |
987 | 1002 |
988 | | - // temporary allocate memory for the input batch if needed |
989 | | - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences |
990 | | - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); |
991 | | - |
992 | | - const llama_batch & batch = batch_allocr.batch; |
| 1003 | + llama_batch_ext & batch = inp_batch; |
993 | 1004 | const int32_t n_tokens = batch.n_tokens; |
994 | 1005 |
995 | 1006 | const auto & hparams = model.hparams; |
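For context, a minimal sketch of what the conversion shim above has to do, assuming llama_batch_allocr (declared in llama-batch.h, which is not part of this diff) owns the synthesized positions and exposes the converted batch. make_batch_ext below is a hypothetical stand-in for the actual llama_batch to llama_batch_ext conversion:

    #include <vector>
    #include "llama.h"

    // hypothetical helper standing in for the real conversion in llama-batch.cpp
    llama_batch_ext * make_batch_ext(const llama_batch & in);

    struct batch_allocr_sketch {
        std::vector<llama_pos> pos;    // storage for synthesized positions
        llama_batch_ext *      batch;  // converted batch consumed by encode()/decode()

        batch_allocr_sketch(llama_batch & in, llama_pos p0) {
            // p0 == -1 means the caller supplied positions; otherwise positions
            // start at pos_max() + 1, as in the calls above. Per the TODO, a
            // single pos_max() across all sequences is only correct when the
            // batch holds one sequence.
            if (in.pos == nullptr && p0 >= 0) {
                pos.resize(in.n_tokens);
                for (int32_t i = 0; i < in.n_tokens; ++i) {
                    pos[i] = p0 + i;
                }
                in.pos = pos.data();
            }
            batch = make_batch_ext(in);
        }
    };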
@@ -1132,17 +1143,13 @@ int llama_context::encode(llama_batch & inp_batch) { |
1132 | 1143 | return 0; |
1133 | 1144 | } |
1134 | 1145 |
1135 | | -int llama_context::decode(llama_batch & inp_batch) { |
| 1146 | +int llama_context::decode(llama_batch_ext & inp_batch) { |
1136 | 1147 | if (inp_batch.n_tokens == 0) { |
1137 | 1148 | LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); |
1138 | 1149 | return -1; |
1139 | 1150 | } |
1140 | 1151 |
1141 | | - // temporary allocate memory for the input batch if needed |
1142 | | - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences |
1143 | | - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); |
1144 | | - |
1145 | | - const llama_batch & batch = batch_allocr.batch; |
| 1152 | + llama_batch_ext & batch = inp_batch; |
1146 | 1153 |
1147 | 1154 | const auto & vocab = model.vocab; |
1148 | 1155 | const auto & hparams = model.hparams; |
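After this hunk, the temporary allocation happens only in the deprecated llama_batch overloads. The relevant member overload set implied by these definitions (the llama-context.h declarations are not shown in this diff) reads roughly:

    struct llama_context {
        // deprecated entry points: convert llama_batch -> llama_batch_ext, then forward
        int encode(llama_batch & inp_batch);
        int decode(llama_batch & inp_batch);

        // new entry points: operate on the extended batch directly, no conversion
        int encode(llama_batch_ext & inp_batch);
        int decode(llama_batch_ext & inp_batch);
    };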
@@ -2714,26 +2721,30 @@ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, lla |
2714 | 2721 |
2715 | 2722 | /// |
2716 | 2723 |
| 2724 | +// deprecated |
2717 | 2725 | int32_t llama_encode( |
2718 | | - llama_context * ctx, |
2719 | | - llama_batch batch) { |
2720 | | - const int ret = ctx->encode(batch); |
2721 | | - if (ret != 0) { |
2722 | | - LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); |
2723 | | - } |
2724 | | - |
2725 | | - return ret; |
| 2726 | + struct llama_context * ctx, |
| 2727 | + struct llama_batch inp_batch) { |
| 2728 | + return ctx->encode(inp_batch); |
2726 | 2729 | } |
2727 | 2730 |
| 2731 | +// deprecated |
2728 | 2732 | int32_t llama_decode( |
2729 | | - llama_context * ctx, |
2730 | | - llama_batch batch) { |
2731 | | - const int ret = ctx->decode(batch); |
2732 | | - if (ret != 0) { |
2733 | | - LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); |
2734 | | - } |
| 2733 | + struct llama_context * ctx, |
| 2734 | + struct llama_batch inp_batch) { |
| 2735 | + return ctx->decode(inp_batch); |
| 2736 | +} |
| 2737 | + |
| 2738 | +int32_t llama_encode_ext( |
| 2739 | + struct llama_context * ctx, |
| 2740 | + struct llama_batch_ext * inp_batch) { |
| 2741 | + return ctx->encode(*inp_batch); |
| 2742 | +} |
2735 | 2743 |
2736 | | - return ret; |
| 2744 | +int32_t llama_decode_ext( |
| 2745 | + struct llama_context * ctx, |
| 2746 | + struct llama_batch_ext * inp_batch) { |
| 2747 | + return ctx->decode(*inp_batch); |
2737 | 2748 | } |
2738 | 2749 |
2739 | 2750 | // |
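Two things worth noting for callers migrating to the new entry points: the deprecated llama_encode()/llama_decode() wrappers no longer emit LLAMA_LOG_ERROR on failure, so reporting errors is now the caller's job, and llama_encode_ext()/llama_decode_ext() take a llama_batch_ext pointer. A caller-side sketch follows; llama_batch_ext_init_from_text() and llama_batch_ext_free() are assumed helper names in the spirit of this change, not APIs confirmed by the diff above:

    #include <cstdio>
    #include "llama.h"

    // assumed helpers, not confirmed by this diff:
    //   llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id)
    //   llama_batch_ext_free(batch)
    int32_t run_decode(llama_context * ctx, llama_token * tokens, int32_t n_tokens) {
        // deprecated path (still compiles, but no longer logs on failure):
        //   llama_batch batch = llama_batch_get_one(tokens, n_tokens);
        //   return llama_decode(ctx, batch);

        // new path: build an extended batch, decode, free it when done
        llama_batch_ext * batch =
            llama_batch_ext_init_from_text(tokens, n_tokens, /*pos0*/ 0, /*seq_id*/ 0);
        const int32_t ret = llama_decode_ext(ctx, batch);
        if (ret != 0) {
            fprintf(stderr, "decode failed, ret = %d\n", ret);  // caller-side reporting
        }
        llama_batch_ext_free(batch);
        return ret;
    }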