
Commit 8fbd7e5

refactor(embeddings): get embeddings at the last token only to capture whole input and save time
-- noticed that the embedding was the same for the full sequence, but we were looping through all tokens, which was inefficient
1 parent 80a2f0d commit 8fbd7e5

File tree

1 file changed: +6 −1 lines changed


code/ac/llama/InstanceEmbedding.cpp

Lines changed: 6 additions & 1 deletion
@@ -100,10 +100,15 @@ void batchAddSeq(llama_batch& batch, std::span<const Token> tokens, llama_seq_id
         batch.pos     [batch.n_tokens] = llama_pos(i);
         batch.n_seq_id[batch.n_tokens] = 1;
         batch.seq_id  [batch.n_tokens][0] = seq_id;
-        batch.logits  [batch.n_tokens] = true;
+        batch.logits  [batch.n_tokens] = false;
 
         batch.n_tokens++;
     }
+
+    // We want to extract the embedding
+    // for the last token in the sequence because
+    // it captures all the tokens in the sequence.
+    batch.logits[batch.n_tokens - 1] = true;
 }

0 commit comments
