Commit 13b0247

batch : add TODOs

ggml-ci

1 parent 4c07964

4 files changed: +28 -15 lines

src/llama-batch.cpp

Lines changed: 16 additions & 3 deletions
@@ -292,16 +292,29 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
 
     GGML_ASSERT(batch.n_tokens > 0);
 
+    if (!batch.pos) {
+        if (batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return false;
+        }
+    }
+
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return false;
             }
+        }
+    }
 
-            if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
-                return false;
+    if (batch.seq_id) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                    return false;
+                }
             }
         }
     }
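
The extended check above validates every seq_id[i][s] entry, not just seq_id[i][0]. A minimal caller-side sketch (not part of this commit) of the kind of input that now gets range-checked, using the public llama.h batch helpers; the function name make_two_seq_batch and the token/sequence values are illustrative placeholders:

    // Sketch only: a batch whose tokens belong to two sequences at once.
    // Every seq_id[i][s] must be in [0, LLAMA_MAX_PARALLEL_SEQUENCES).
    #include "llama.h"

    static llama_batch make_two_seq_batch(const llama_token * tokens, int32_t n_tokens) {
        // room for up to 2 sequence ids per token, no embeddings
        llama_batch batch = llama_batch_init(n_tokens, /*embd*/ 0, /*n_seq_max*/ 2);

        batch.n_tokens = n_tokens;
        for (int32_t i = 0; i < n_tokens; ++i) {
            batch.token[i]     = tokens[i];
            batch.pos[i]       = i;    // positions must be set when seq_id is set
            batch.n_seq_id[i]  = 2;
            batch.seq_id[i][0] = 0;    // every entry is now checked, not only [i][0]
            batch.seq_id[i][1] = 1;
            batch.logits[i]    = (i == n_tokens - 1);
        }
        return batch;                  // caller frees with llama_batch_free()
    }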

src/llama-context.cpp

Lines changed: 4 additions & 12 deletions
@@ -822,7 +822,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
                     GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
 
-                    // TODO: fix sequence indexing
+                    // TODO: fix indexing [UBATCH_IDX]
                     for (uint32_t i = 0; i < n_tokens; i++) {
                         const llama_seq_id seq_id = ubatch.seq_id[i][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
@@ -838,6 +838,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
                     auto & embd_seq_out = embd_seq;
                     const uint32_t n_cls_out = hparams.n_cls_out;
 
+                    // TODO: fix indexing [UBATCH_IDX]
                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
@@ -870,13 +871,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
         memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
 
         // remember the sequence ids used during the encoding - needed for cross attention later
-        // TODO: the seuqence indexing here is likely not correct in the general case
-        //       probably works only for split_simple
         cross.seq_ids_enc.resize(n_tokens);
         for (uint32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();
-            for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
-                llama_seq_id seq_id = ubatch.seq_id[i][s];
+            for (int s = 0; s < batch.n_seq_id[i]; s++) {
+                llama_seq_id seq_id = batch.seq_id[i][s];
                 cross.seq_ids_enc[i].insert(seq_id);
             }
         }
@@ -896,13 +895,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }
 
-    if (!batch_inp.pos) {
-        if (batch_inp.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return -1;
-        }
-    }
-
     // temporary allocate memory for the input batch if needed
     if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);

src/llama-graph.cpp

Lines changed: 7 additions & 0 deletions
@@ -139,6 +139,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
 
         std::vector<uint64_t> sum(n_tokens, 0);
 
+        // TODO: fix indexing [UBATCH_IDX]
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
@@ -156,6 +157,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
             }
         }
 
+        // TODO: fix indexing [UBATCH_IDX]
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
@@ -180,6 +182,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         uint32_t * data = (uint32_t *) cls->data;
         memset(cls->data, 0, n_tokens * ggml_element_size(cls));
 
+        // TODO: fix indexing [UBATCH_IDX]
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
@@ -210,6 +213,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         std::vector<int> last_pos(n_tokens, -1);
         std::vector<int> last_row(n_tokens, -1);
 
+        // TODO: fix indexing [UBATCH_IDX]
         for (int s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
@@ -283,6 +287,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                         const int32_t ti = s0*n_seq_tokens + i;
                         float f = -INFINITY;
 
+                        // TODO: fix indexing [UBATCH_IDX]
                         for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                             if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
                                 if (hparams.use_alibi) {
@@ -322,6 +327,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                         const int32_t ti = s0*n_seq_tokens + i;
                         float f = -INFINITY;
 
+                        // TODO: fix indexing [UBATCH_IDX]
                         for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
                             if (ubatch->seq_id[s0][s] == seq_id) {
                                 if (hparams.use_alibi) {
@@ -377,6 +383,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
         for (int j = 0; j < n_tokens; ++j) {
             for (int i = 0; i < n_enc; ++i) {
                 float f = -INFINITY;
+                // TODO: fix indexing [UBATCH_IDX]
                 for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
                     const llama_seq_id seq_id = ubatch->seq_id[j][s];
                     if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {

src/llama-kv-cache-unified.cpp

Lines changed: 1 addition & 0 deletions
@@ -674,6 +674,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
 
         cells.pos_set(head_cur + idx, ubatch.pos[idx]);
 
+        // TODO: fix indexing [UBATCH_IDX]
         for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
             cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
         }
