Skip to content

Commit 371fdcd

Browse files
committed
optimize LLamaReranker function
1 parent 8a34866 commit 371fdcd

File tree

2 files changed

+65
-26
lines changed

2 files changed

+65
-26
lines changed

LLama.Unittest/LLamaRerankerTests.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ public async Task MostRelevantDocument()
6161
};
6262
var scores = await _reranker.GetRelevanceScores(input, documents, normalize: true);
6363

64+
Assert.NotNull(scores);
6465
Assert.True(documents.Length == scores.Count);
6566

66-
int maxIndex = scores
67-
.Select((score, index) => new { Score = score, Index = index })
68-
.MaxBy(x => x.Score)
69-
.Index;
67+
int maxIndex = scores.Select((score, index) => (score, index))
68+
.MaxBy(x => x.score)
69+
.index;
7070

7171
var maxScoreDocument = documents[maxIndex];
7272
Assert.Equal(documents[2], maxScoreDocument);

LLama/LLamaReranker.cs

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,6 @@ namespace LLama;
1818
public sealed partial class LLamaReranker
1919
: IDisposable
2020
{
21-
/// <summary>
22-
/// string BOS
23-
/// </summary>
24-
public string StrBOS { get; }
25-
/// <summary>
26-
/// string EOS
27-
/// </summary>
28-
public string StrEOS { get; }
29-
30-
3121
/// <summary>
3222
/// Dimension of embedding vectors
3323
/// </summary>
@@ -54,8 +44,6 @@ public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logg
5444
throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");
5545
Context = weights.CreateContext(@params, logger);
5646
NativeApi.llama_set_embeddings(Context.NativeHandle, true);
57-
StrBOS = Context.Vocab.LLamaTokenToString(Context.Vocab.BOS, true) ?? "<s>";
58-
StrEOS = Context.Vocab.LLamaTokenToString(Context.Vocab.EOS, true) ?? "</s>";
5947
}
6048

6149
/// <inheritdoc />
@@ -65,7 +53,7 @@ public void Dispose()
6553
}
6654

6755
/// <summary>
68-
/// Retrieve relevance scores for input and document by reranking
56+
/// Retrieve relevance scores for the input and all documents by reranking, executed in a single batch.
6957
/// </summary>
7058
/// <param name="input"></param>
7159
/// <param name="documents"></param>
@@ -74,22 +62,73 @@ public void Dispose()
7462
/// <returns></returns>
7563
/// <exception cref="RuntimeError"></exception>
7664
/// <exception cref="NotSupportedException"></exception>
77-
public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default) {
65+
public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
66+
{
7867
List<float> scores = new List<float>(documents.Count);
79-
foreach (var document in documents)
68+
var batch = new LLamaBatch();
69+
var inputTokens = Context.Tokenize(input);
70+
foreach (var (index, document) in documents.Select((item, index) => (index, item)))
71+
{
72+
var docTokens = Context.Tokenize(document);
73+
LLamaToken[] tokens = [.. inputTokens, .. docTokens];
74+
for (var i = 0; i < tokens.Length; i++)
75+
batch.Add(tokens[i], i, (LLamaSeqId)index, true);
76+
}
77+
78+
// clear previous kv_cache values
79+
Context.NativeHandle.KvCacheClear();
80+
81+
// Check if we should cancel the work, just before doing anything expensive (encode/decode)
82+
cancellationToken.ThrowIfCancellationRequested();
83+
84+
// Run model
85+
switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
8086
{
81-
var score = (await GetRelevanceScoreWithTokenCount(input, document, cancellationToken).ConfigureAwait(false)).Score;
87+
case (true, false):
88+
{
89+
var result = await Context.EncodeAsync(batch, cancellationToken);
90+
if (result != EncodeResult.Ok)
91+
throw new RuntimeError($"Failed to encode: {result}");
92+
break;
93+
}
94+
95+
case (false, true):
96+
{
97+
var result = await Context.DecodeAsync(batch, cancellationToken);
98+
if (result != DecodeResult.Ok)
99+
throw new RuntimeError($"Failed to decode: {result}");
100+
break;
101+
}
102+
103+
default:
104+
throw new NotSupportedException("Unsupported model type");
105+
}
106+
107+
for (var i = 0; i < documents.Count; i++)
108+
{
109+
var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)i)[0];
82110
scores.Add(normalize ? Sigmoid(score) : score);
83111
}
112+
113+
Context.NativeHandle.KvCacheClear();
114+
84115
return scores;
85116
}
86117

87-
88-
private async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, CancellationToken cancellationToken = default)
118+
/// <summary>
119+
/// Retrieve the relevance score for an input and a single document by reranking
120+
/// </summary>
121+
/// <param name="input"></param>
122+
/// <param name="document"></param>
123+
/// <param name="cancellationToken"></param>
124+
/// <returns></returns>
125+
/// <exception cref="RuntimeError"></exception>
126+
/// <exception cref="NotSupportedException"></exception>
127+
public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
89128
{
90-
var prompt = $"{input}</s><s>{document}";
91-
// Add all of the tokens to the batch
92-
var tokens = Context.Tokenize(prompt, special: true);
129+
var inputTokens = Context.Tokenize(input);
130+
var docTokens = Context.Tokenize(document);
131+
LLamaToken[] tokens = [..inputTokens, ..docTokens];
93132
var batch = new LLamaBatch();
94133
for (var i = 0; i < tokens.Length; i++)
95134
batch.Add(tokens[i], i, LLamaSeqId.Zero, true);
@@ -127,7 +166,7 @@ public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOn
127166

128167
Context.NativeHandle.KvCacheClear();
129168

130-
return (score, tokens.Length);
169+
return (normalize ? Sigmoid(score) : score, tokens.Length);
131170
}
132171

133172
private float Sigmoid(float x)

0 commit comments

Comments (0)