
Commit 46596ca

apply various in places
1 parent 1d6ba97 commit 46596ca

12 files changed: +144 additions, -135 deletions

common/common.h

Lines changed: 46 additions & 0 deletions
@@ -565,6 +565,52 @@ void common_batch_add(
     const std::vector<llama_seq_id> & seq_ids,
     bool logits);
 
+// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
+// this is meant to be temporary
+struct common_batch {
+    llama_batch_ext_ptr batch;
+    struct batch_token {
+        llama_token  token;
+        llama_seq_id seq_id;
+        bool         logits;
+    };
+    std::vector<batch_token> tokens;
+    common_batch() = default;
+    common_batch(int32_t n_tokens, int32_t n_seq_max) {
+        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
+        tokens.reserve(n_tokens);
+    }
+    void clear() {
+        llama_batch_ext_clear(batch.get());
+        tokens.clear();
+    }
+    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
+        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
+        tokens.push_back({token, seq_id, logits});
+    }
+    void set_logits_last() {
+        if (!tokens.empty()) {
+            llama_batch_ext_set_logits_last(batch.get());
+            tokens.back().logits = true;
+        }
+    }
+    int32_t get_n_tokens() const {
+        return (int32_t) tokens.size();
+    }
+    llama_batch_ext * get() {
+        return batch.get();
+    }
+    common_batch get_view(int32_t offset, int32_t n_tokens) {
+        common_batch view;
+        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
+        view.tokens.reserve(n_tokens);
+        for (int32_t i = 0; i < n_tokens; i++) {
+            view.tokens.push_back(tokens[offset + i]);
+        }
+        return view;
+    }
+};
+
 //
 // Token utils
 //
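
A minimal usage sketch of this wrapper (not part of the diff; it assumes an already-initialized llama_context * ctx and a tokenized prompt, and the helper name is illustrative): tokens are queued on one sequence, logits are requested only for the last position, and the per-token seq_id/logits bookkeeping stays readable on the host side, which is what the embedding changes below rely on.

// usage sketch only: eval_prompt and prompt_tokens are illustrative names, not part of this commit
static bool eval_prompt(llama_context * ctx, const std::vector<llama_token> & prompt_tokens) {
    common_batch batch((int32_t) prompt_tokens.size(), 1);

    // queue the whole prompt on sequence 0; intermediate tokens do not need logits
    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
        batch.add_text(prompt_tokens[i], (llama_pos) i, 0, false);
    }

    // request logits only for the last position
    batch.set_logits_last();

    // the wrapper still hands out the raw llama_batch_ext * for decoding
    if (llama_decode_ext(ctx, batch.get()) != 0) {
        return false;
    }

    // per-token metadata (seq_id, logits flag) remains visible to the caller via batch.tokens
    return true;
}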

examples/batched-bench/batched-bench.cpp

Lines changed: 17 additions & 23 deletions
@@ -59,24 +59,17 @@ int main(int argc, char ** argv) {
 
     const int32_t n_kv_max = llama_n_ctx(ctx);
 
-    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
 
     // decode in batches of ctx_params.n_batch tokens
-    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
+    auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
+        const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch);
+        for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i));
+
+            llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens));
+
+            const int ret = llama_decode_ext(ctx, batch_view.get());
             if (ret != 0) {
                 LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
                 return false;
@@ -91,7 +84,8 @@ int main(int argc, char ** argv) {
     // warm up
     {
         for (int i = 0; i < 16; ++i) {
-            common_batch_add(batch, 0, i, { 0 }, false);
+            const llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
         }
 
         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -121,14 +115,14 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch);
 
             for (int i = 0; i < pp; ++i) {
                 for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
-                    common_batch_add(batch, 0, i, { j }, false);
+                    llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
                 }
             }
-            batch.logits[batch.n_tokens - 1] = true;
+            llama_batch_ext_set_logits_last(batch);
 
            const auto t_pp_start = ggml_time_us();
 
@@ -150,10 +144,10 @@ int main(int argc, char ** argv) {
            const auto t_tg_start = ggml_time_us();
 
            for (int i = 0; i < tg; ++i) {
-                common_batch_clear(batch);
+                llama_batch_ext_clear(batch);
 
                for (int j = 0; j < pl; ++j) {
-                    common_batch_add(batch, 0, pp + i, { j }, true);
+                    llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, false);
                }
 
                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -191,7 +185,7 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
 
     llama_free(ctx);
     llama_model_free(model);
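
The rewritten decode_helper replaces the hand-built llama_batch view (pointer arithmetic over the token/pos/seq_id/logits arrays) with llama_batch_ext_get_view. A standalone sketch of the same chunked-decode pattern, using only the llama_batch_ext_* calls that appear in this diff (the function name is illustrative):

// sketch: decode a pre-filled llama_batch_ext in chunks of at most n_batch tokens
static bool decode_in_chunks(llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
    const int32_t n_tokens_total = llama_batch_ext_get_n_tokens(batch);
    for (int32_t i = 0; i < n_tokens_total; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, n_tokens_total - i);

        // as in the diff, the view is wrapped in llama_batch_ext_ptr and released after the call
        llama_batch_ext_ptr view(llama_batch_ext_get_view(batch, i, n_tokens));

        if (llama_decode_ext(ctx, view.get()) != 0) {
            return false;
        }
    }
    return true;
}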

examples/batched/batched.cpp

Lines changed: 16 additions & 16 deletions
@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {
 
     // create a llama_batch
     // we use this object to submit token data for decoding
-    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
+    llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
 
     std::vector<llama_seq_id> seq_ids(n_parallel, 0);
     for (int32_t i = 0; i < n_parallel; ++i) {
@@ -111,12 +111,12 @@ int main(int argc, char ** argv) {
 
     // evaluate the initial prompt
     for (size_t i = 0; i < tokens_list.size(); ++i) {
-        common_batch_add(batch, tokens_list[i], i, seq_ids, false);
+        llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false);
     }
-    GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
+    GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size());
 
     if (llama_model_has_encoder(model)) {
-        if (llama_encode(ctx, batch)) {
+        if (llama_encode_ext(ctx, batch)) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return 1;
        }
@@ -126,14 +126,14 @@ int main(int argc, char ** argv) {
            decoder_start_token_id = llama_vocab_bos(vocab);
        }
 
-        common_batch_clear(batch);
-        common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
+        llama_batch_ext_clear(batch);
+        llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false);
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    llama_batch_ext_set_logits_last(batch);
 
-    if (llama_decode(ctx, batch) != 0) {
+    if (llama_decode_ext(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode() failed\n", __func__);
        return 1;
    }
@@ -155,16 +155,16 @@ int main(int argc, char ** argv) {
 
    // remember the batch index of the last token for each parallel sequence
    // we need this to determine which logits to sample from
-    std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
+    std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
 
-    int n_cur    = batch.n_tokens;
+    int n_cur    = llama_batch_ext_get_n_tokens(batch);
    int n_decode = 0;
 
    const auto t_main_start = ggml_time_us();
 
    while (n_cur <= n_predict) {
        // prepare the next batch
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch);
 
        // sample the next token for each parallel sequence / stream
        for (int32_t i = 0; i < n_parallel; ++i) {
@@ -193,23 +193,23 @@ int main(int argc, char ** argv) {
 
            streams[i] += common_token_to_piece(ctx, new_token_id);
 
-            i_batch[i] = batch.n_tokens;
+            i_batch[i] = llama_batch_ext_get_n_tokens(batch);
 
            // push this new token for next evaluation
-            common_batch_add(batch, new_token_id, n_cur, { i }, true);
+            llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, false);
 
            n_decode += 1;
        }
 
        // all streams are finished
-        if (batch.n_tokens == 0) {
+        if (llama_batch_ext_get_n_tokens(batch) == 0) {
            break;
        }
 
        n_cur += 1;
 
        // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch)) {
            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }
@@ -234,7 +234,7 @@ int main(int argc, char ** argv) {
 
    fprintf(stderr, "\n");
 
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
 
    llama_sampler_free(smpl);
    llama_free(ctx);
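
The seq-id argument is the main signature change here: llama_batch_ext_add_text takes a pointer plus a count, so the shared prompt is broadcast to all n_parallel sequences with seq_ids.data()/seq_ids.size(), while each generated token is attached to a single sequence by passing the address of one id. A reduced sketch of that shape (helper names are illustrative; model setup and sampling are omitted):

// broadcast a shared prompt to several sequences
static void queue_prompt(llama_batch_ext * batch,
                         const std::vector<llama_token> & prompt,
                         const std::vector<llama_seq_id> & seq_ids) {
    llama_batch_ext_clear(batch);
    for (size_t i = 0; i < prompt.size(); ++i) {
        // one token entry shared by every sequence id in the array
        llama_batch_ext_add_text(batch, prompt[i], (llama_pos) i, seq_ids.data(), seq_ids.size(), false);
    }
    // only the last prompt position needs logits
    llama_batch_ext_set_logits_last(batch);
}

// extend a single sequence with its sampled token
static void queue_continuation(llama_batch_ext * batch, llama_token tok, llama_pos pos, llama_seq_id seq) {
    // single-sequence add: address of one id, count of 1; logits requested here so the
    // next token can be sampled from this position
    llama_batch_ext_add_text(batch, tok, pos, &seq, 1, true);
}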

examples/cvector-generator/cvector-generator.cpp

Lines changed: 2 additions & 1 deletion
@@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_cache_clear(ctx);
-    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    if (llama_decode_ext(ctx, batch.get())) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return false;
    }
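
llama_batch_ext_init_from_text replaces the llama_batch_get_one one-liner. Judging only by the call sites in this commit, the two trailing zeros look like a start position and a sequence id, but that reading is inferred from usage, not documented here. A minimal sketch of the resulting eval helper (the function name is illustrative):

// sketch: evaluate a token vector with the _ext API and a scope-owned batch
static bool eval_tokens(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_kv_cache_clear(ctx);
    // trailing arguments taken to be start position 0 and sequence id 0 (inferred from this commit)
    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
    return llama_decode_ext(ctx, batch.get()) == 0;
}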

examples/embedding/embedding.cpp

Lines changed: 14 additions & 15 deletions
@@ -25,36 +25,36 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
     return lines;
 }
 
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        common_batch_add(batch, tokens[i], i, { seq_id }, true);
+        batch.add_text(tokens[i], i, seq_id, true);
     }
 }
 
-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
     const struct llama_model * model = llama_get_model(ctx);
 
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
 
     // run model
-    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
        // encoder-only model
-        if (llama_encode(ctx, batch) < 0) {
+        if (llama_encode_ext(ctx, batch.get()) < 0) {
            LOG_ERR("%s : failed to encode\n", __func__);
        }
    } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
        // decoder-only model
-        if (llama_decode(ctx, batch) < 0) {
+        if (llama_decode_ext(ctx, batch.get()) < 0) {
            LOG_ERR("%s : failed to decode\n", __func__);
        }
    }
 
-    for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
+        if (!batch.tokens[i].logits) {
            continue;
        }
 
@@ -68,8 +68,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
        } else {
            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-            embd_pos = batch.seq_id[i][0];
+            embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
+            embd_pos = batch.tokens[i].seq_id;
            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
        }
 
@@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
 
    // initialize batch
    const int n_prompts = prompts.size();
-    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    struct common_batch batch = common_batch(n_batch, 1);
 
    // count number of embeddings
    int n_embd_count = 0;
@@ -197,12 +197,12 @@ int main(int argc, char ** argv) {
        const uint64_t n_toks = inp.size();
 
        // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        if (batch.get_n_tokens() + n_toks > n_batch) {
            float * out = emb + e * n_embd;
            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
-            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s;
            s = 0;
-            common_batch_clear(batch);
+            batch.clear();
        }
 
        // add to batch
@@ -318,7 +318,6 @@ int main(int argc, char ** argv) {
    llama_perf_context_print(ctx);
 
    // clean up
-    llama_batch_free(batch);
    llama_backend_free();
 
    return 0;
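
This file is where the bookkeeping in common_batch earns its keep: the output loop needs each flagged token's seq_id, which the opaque llama_batch_ext no longer exposes, so it is read back from batch.tokens. A reduced sketch of the pooled read-back path (normalization and most error handling omitted; assumes a context with a pooling type other than NONE and sequence ids numbered from 0, and the function name is illustrative):

// sketch: copy one pooled embedding per sequence after decoding a common_batch
static void read_pooled_embeddings(llama_context * ctx, common_batch & batch, float * out, int n_embd) {
    for (int i = 0; i < batch.get_n_tokens(); i++) {
        if (!batch.tokens[i].logits) {
            continue; // only positions flagged for output carry an embedding
        }
        const llama_seq_id seq = batch.tokens[i].seq_id;
        const float * embd = llama_get_embeddings_seq(ctx, seq);
        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
        for (int k = 0; k < n_embd; k++) {
            out[(size_t) seq * n_embd + k] = embd[k];
        }
    }
}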

examples/eval-callback/eval-callback.cpp

Lines changed: 2 additions & 1 deletion
@@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) {
 
     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
 
-    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    if (llama_decode_ext(ctx, batch.get())) {
        LOG_ERR("%s : failed to eval\n", __func__);
        return false;
    }
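
Two ownership styles coexist after this commit: the long-lived batches in batched.cpp and batched-bench.cpp stay raw llama_batch_ext * and are released with llama_batch_ext_free, while the short-lived ones here and in cvector-generator.cpp (and inside common_batch) go through llama_batch_ext_ptr, which appears to be a unique_ptr-style wrapper that frees the batch when it leaves scope. That behaviour is an inference from the dropped llama_batch_free call in embedding.cpp, not something stated in this diff. Side by side, as an illustration (the capacity of 512 is arbitrary):

// explicit lifetime (as in batched.cpp / batched-bench.cpp): the caller frees
{
    llama_batch_ext * batch = llama_batch_ext_init(512, 1);
    // ... fill and decode ...
    llama_batch_ext_free(batch);
}

// scope-based lifetime (as here): the wrapper is assumed to free the batch on destruction
{
    llama_batch_ext_ptr batch(llama_batch_ext_init(512, 1));
    // ... fill via batch.get() and decode ...
} // released here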
