
Commit 1fb5d4f

compilade and ggerganov committed
llama : apply suggestions
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 7b7db0b commit 1fb5d4f

File tree: 1 file changed (+10, -11)


src/llama.cpp

Lines changed: 10 additions & 11 deletions
@@ -3018,7 +3018,7 @@ struct llama_sbatch {
             return;
         }
         std::sort(ids.begin(), ids.end(),
-            [batch](size_t a, size_t b) {
+            [&batch](size_t a, size_t b) {
                 int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
                 int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
                 // sort by seq_id, then by pos
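
The substantive change in this hunk is the lambda capture: `[batch]` copies the whole batch struct into the comparator, while `[&batch]` captures it by reference and avoids the copy. Since the lambda does not outlive the `std::sort` call, the reference capture is safe. A minimal self-contained sketch of the pattern, using a simplified stand-in type rather than the real `llama_batch`:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for llama_batch (an assumption, for illustration only).
struct batch_view {
    const int32_t * n_seq_id; // per-token sequence counts, may be null
};

void sort_by_n_seq(std::vector<size_t> & ids, const batch_view & batch) {
    // [&batch] captures by reference: no copy of the struct is made.
    // [batch] would copy it once per sort call, which is wasted work
    // for a struct that is only read.
    std::sort(ids.begin(), ids.end(),
        [&batch](size_t a, size_t b) {
            const int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
            const int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
            // simplified: the real comparator also orders by seq_id, then pos
            return n_seq_a > n_seq_b;
        });
}
```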
@@ -3050,7 +3050,6 @@ struct llama_sbatch {
         if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
             for (size_t i = 0; i < n_tokens; ++i) {
                 const size_t bi = ids[i];
-                const size_t s_len = seq.size();
                 const int32_t n_seqs = batch.n_seq_id[bi];
                 llama_seq_id * seq_ids = batch.seq_id[bi];
                 if (last_seq != nullptr) {
@@ -3067,7 +3066,7 @@ struct llama_sbatch {
                 }
                 llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
                 seq.push_back(new_seq);
-                last_seq = &seq[s_len];
+                last_seq = &seq.back();
             }
         } else {
             llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
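
The two removals above go together: `&seq[s_len]` (with `s_len` recorded as `seq.size()` before the `push_back`) and `&seq.back()` address the same freshly appended element, so `back()` expresses the intent without the extra variable. A small sketch of the equivalence, with toy types:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct seq_info { int n_seqs; }; // toy stand-in for llama_sbatch_seq

int main() {
    std::vector<seq_info> seq;

    const std::size_t s_len = seq.size(); // index the next element will get
    seq.push_back({1});

    seq_info * by_index = &seq[s_len]; // old form: needs s_len bookkeeping
    seq_info * by_back  = &seq.back(); // new form: no extra variable
    assert(by_index == by_back);

    // Either pointer is invalidated if a later push_back reallocates the
    // vector, so it must be re-taken after every insertion, as the loop
    // in the diff does.
    return 0;
}
```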
@@ -15089,8 +15088,8 @@ static int llama_decode_internal(

     while (lctx.sbatch.n_tokens > 0) {
         // For now, only use equal splits for recurrent model architectures
-        llama_ubatch u_batch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
-        const uint32_t n_tokens = u_batch.n_tokens;
+        llama_ubatch ubatch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
+        const uint32_t n_tokens = ubatch.n_tokens;

         // count the outputs in this u_batch
         {
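
From this hunk on, the commit renames `u_batch` to `ubatch`, presumably for consistency with the `n_ubatch` naming used nearby. For context, the loop being edited drains the sbatch into micro-batches, choosing `split_equal` for recurrent architectures and `split_simple` otherwise, per the comment in the diff. A toy sketch of that loop shape; the stand-in types below are assumptions, not the real `llama_sbatch`/`llama_ubatch`:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct toy_ubatch { uint32_t n_tokens; };

struct toy_sbatch {
    uint32_t n_tokens;

    // split_simple: take up to n tokens, mixing sequences freely
    toy_ubatch split_simple(uint32_t n) {
        const uint32_t take = std::min(n, n_tokens);
        n_tokens -= take;
        return {take};
    }

    // split_equal: same size cap; the real version also takes an equal
    // number of tokens from each sequence, which recurrent
    // (state-carrying) models require
    toy_ubatch split_equal(uint32_t n) { return split_simple(n); }
};

int main() {
    toy_sbatch sbatch = {10};
    const bool     recurrent = false; // assumption: non-recurrent model
    const uint32_t n_ubatch  = 4;

    // same shape as the decode loop: drain the sbatch into micro-batches
    while (sbatch.n_tokens > 0) {
        toy_ubatch ubatch = recurrent ? sbatch.split_equal(n_ubatch)
                                      : sbatch.split_simple(n_ubatch);
        std::printf("ubatch of %u tokens\n", (unsigned) ubatch.n_tokens);
    }
    return 0;
}
```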
@@ -15099,9 +15098,9 @@
             if (n_outputs == n_tokens_all) {
                 n_outputs_new = n_tokens;
             } else {
-                GGML_ASSERT(u_batch.output);
+                GGML_ASSERT(ubatch.output);
                 for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (u_batch.output[i] != 0);
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
                 }
             }

@@ -15122,7 +15121,7 @@
                 kv_self.head = 0;
             }

-            if (!llama_kv_cache_find_slot(kv_self, u_batch)) {
+            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
                 return 1;
             }

@@ -15141,7 +15140,7 @@
         ggml_backend_sched_reset(lctx.sched);
         ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
+        ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);

         // the output is always the last tensor in the graph
         struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
@@ -15166,7 +15165,7 @@

         ggml_backend_sched_alloc_graph(lctx.sched, gf);

-        llama_set_inputs(lctx, u_batch);
+        llama_set_inputs(lctx, ubatch);

         llama_graph_compute(lctx, gf, n_threads);

@@ -15229,7 +15228,7 @@
             embd_seq_out.clear();

             for (uint32_t i = 0; i < n_tokens; i++) {
-                const llama_seq_id seq_id = u_batch.seq_id[i][0];
+                const llama_seq_id seq_id = ubatch.seq_id[i][0];
                 if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                     continue;
                 }

0 commit comments
