@@ -3018,7 +3018,7 @@ struct llama_sbatch {
3018
3018
return;
3019
3019
}
3020
3020
std::sort(ids.begin(), ids.end(),
3021
- [batch](size_t a, size_t b) {
3021
+ [& batch](size_t a, size_t b) {
3022
3022
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
3023
3023
int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
3024
3024
// sort by seq_id, then by pos
@@ -3050,7 +3050,6 @@ struct llama_sbatch {
3050
3050
if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
3051
3051
for (size_t i = 0; i < n_tokens; ++i) {
3052
3052
const size_t bi = ids[i];
3053
- const size_t s_len = seq.size();
3054
3053
const int32_t n_seqs = batch.n_seq_id[bi];
3055
3054
llama_seq_id * seq_ids = batch.seq_id[bi];
3056
3055
if (last_seq != nullptr) {
@@ -3067,7 +3066,7 @@ struct llama_sbatch {
3067
3066
}
3068
3067
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
3069
3068
seq.push_back(new_seq);
3070
- last_seq = &seq[s_len] ;
3069
+ last_seq = &seq.back() ;
3071
3070
}
3072
3071
} else {
3073
3072
llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
@@ -15089,8 +15088,8 @@ static int llama_decode_internal(
15089
15088
15090
15089
while (lctx.sbatch.n_tokens > 0) {
15091
15090
// For now, only use equal splits for recurrent model architectures
15092
- llama_ubatch u_batch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
15093
- const uint32_t n_tokens = u_batch .n_tokens;
15091
+ llama_ubatch ubatch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
15092
+ const uint32_t n_tokens = ubatch .n_tokens;
15094
15093
15095
15094
// count the outputs in this u_batch
15096
15095
{
@@ -15099,9 +15098,9 @@ static int llama_decode_internal(
15099
15098
if (n_outputs == n_tokens_all) {
15100
15099
n_outputs_new = n_tokens;
15101
15100
} else {
15102
- GGML_ASSERT(u_batch .output);
15101
+ GGML_ASSERT(ubatch .output);
15103
15102
for (uint32_t i = 0; i < n_tokens; i++) {
15104
- n_outputs_new += (int32_t) (u_batch .output[i] != 0);
15103
+ n_outputs_new += (int32_t) (ubatch .output[i] != 0);
15105
15104
}
15106
15105
}
15107
15106
@@ -15122,7 +15121,7 @@ static int llama_decode_internal(
15122
15121
kv_self.head = 0;
15123
15122
}
15124
15123
15125
- if (!llama_kv_cache_find_slot(kv_self, u_batch )) {
15124
+ if (!llama_kv_cache_find_slot(kv_self, ubatch )) {
15126
15125
return 1;
15127
15126
}
15128
15127
@@ -15141,7 +15140,7 @@ static int llama_decode_internal(
15141
15140
ggml_backend_sched_reset(lctx.sched);
15142
15141
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
15143
15142
15144
- ggml_cgraph * gf = llama_build_graph(lctx, u_batch , false);
15143
+ ggml_cgraph * gf = llama_build_graph(lctx, ubatch , false);
15145
15144
15146
15145
// the output is always the last tensor in the graph
15147
15146
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
@@ -15166,7 +15165,7 @@ static int llama_decode_internal(
15166
15165
15167
15166
ggml_backend_sched_alloc_graph(lctx.sched, gf);
15168
15167
15169
- llama_set_inputs(lctx, u_batch );
15168
+ llama_set_inputs(lctx, ubatch );
15170
15169
15171
15170
llama_graph_compute(lctx, gf, n_threads);
15172
15171
@@ -15229,7 +15228,7 @@ static int llama_decode_internal(
15229
15228
embd_seq_out.clear();
15230
15229
15231
15230
for (uint32_t i = 0; i < n_tokens; i++) {
15232
- const llama_seq_id seq_id = u_batch .seq_id[i][0];
15231
+ const llama_seq_id seq_id = ubatch .seq_id[i][0];
15233
15232
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
15234
15233
continue;
15235
15234
}
0 commit comments