Skip to content

Commit a69f814

Browse files
committed
Merge branch 'feature-llamareranker'
2 parents 5996b40 + 6f4c53c commit a69f814

File tree

6 files changed

+269
-8
lines changed

6 files changed

+269
-8
lines changed

LLama.Unittest/Constants.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ internal static class Constants
77
public static readonly string GenerativeModelPath = "Models/Llama-3.2-1B-Instruct-Q4_0.gguf";
88
public static readonly string GenerativeModelPath2 = "Models/smollm-360m-instruct-add-basics-q8_0.gguf";
99
public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf";
10+
public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf";
1011

1112
public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf";
1213
public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf";

LLama.Unittest/LLama.Unittest.csproj

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@
3434

3535
<DownloadFile SourceUrl="https://huggingface.co/HuggingFaceTB/smollm-360M-instruct-v0.2-Q8_0-GGUF/resolve/main/smollm-360m-instruct-add-basics-q8_0.gguf" DestinationFolder="Models" DestinationFileName="smollm-360m-instruct-add-basics-q8_0.gguf" SkipUnchangedFiles="true">
3636
</DownloadFile>
37-
38-
<DownloadFile SourceUrl="https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf" DestinationFolder="Models" DestinationFileName="llava-v1.6-mistral-7b.Q3_K_XS.gguf" SkipUnchangedFiles="true">
37+
38+
<DownloadFile SourceUrl="https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-FP16.gguf" DestinationFolder="Models" DestinationFileName="jina-reranker-v1-tiny-en-FP16.gguf" SkipUnchangedFiles="true">
39+
</DownloadFile>
40+
41+
<DownloadFile SourceUrl="https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf" DestinationFolder="Models" DestinationFileName="llava-v1.6-mistral-7b.Q3_K_XS.gguf" SkipUnchangedFiles="true">
3942
</DownloadFile>
4043

4144
<DownloadFile SourceUrl="https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/mmproj-model-f16.gguf" DestinationFolder="Models" DestinationFileName="mmproj-model-f16.gguf" SkipUnchangedFiles="true">
@@ -63,6 +66,9 @@
6366
<None Update="Models\Llama-3.2-1B-Instruct-Q4_0.gguf">
6467
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
6568
</None>
69+
<None Update="Models\jina-reranker-v1-tiny-en-FP16.gguf">
70+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
71+
</None>
6672
<None Update="Models\smollm-360m-instruct-add-basics-q8_0.gguf">
6773
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
6874
</None>

LLama.Unittest/LLamaRerankerTests.cs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
using LLama.Common;
2+
using LLama.Extensions;
3+
using LLama.Native;
4+
using Microsoft.Extensions.AI;
5+
using System.Runtime.InteropServices;
6+
using Xunit.Abstractions;
7+
8+
namespace LLama.Unittest;
9+
10+
public sealed class LLamaRerankerTests
    : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly LLamaReranker _reranker;

    public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // 0 lets the runtime use the model's own context length
            ContextSize = 0,
            // LLamaReranker requires rank pooling
            PoolingType = LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };

        // The reranker creates its own context from the weights, so the weights
        // can be released as soon as construction completes.
        using var weights = LLamaWeights.LoadFromFile(@params);
        _reranker = new LLamaReranker(weights, @params);
    }

    // xUnit calls Dispose after each test; release the native context held by
    // the reranker (the original never disposed it, leaking native memory).
    public void Dispose()
    {
        _reranker.Dispose();
    }

    [Fact]
    public async Task CompareRerankingScore()
    {
        var input = "what is panda?";
        var documents = new string[] {
            "hi",
            "it's a bear",
            string.Join(", ","The giant panda (Ailuropoda melanoleuca)",
            "sometimes called a panda bear or simply panda",
            "is a bear species endemic to China.")
        };

        var scores = await _reranker.GetRelevanceScores(input, documents, normalize: false);

        // Assert.Equal reports both values on failure, unlike Assert.True(a == b)
        Assert.Equal(documents.Length, scores.Count);

        _testOutputHelper.WriteLine($"Rerank score 0: {scores[0]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 1: {scores[1]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 2: {scores[2]:F4}");
    }

    [Fact]
    public async Task MostRelevantDocument()
    {
        var input = "what is panda?";
        var documents = new string[] {
            "hi",
            "it's a bear",
            string.Join(", ","The giant panda (Ailuropoda melanoleuca)",
            "sometimes called a panda bear or simply panda",
            "is a bear species endemic to China.")
        };

        var scores = await _reranker.GetRelevanceScores(input, documents, normalize: true);

        Assert.Equal(documents.Length, scores.Count);

        // Find the index of the highest score with a plain scan, avoiding the
        // nullable result of MaxBy over an anonymous reference type.
        var maxIndex = 0;
        for (var i = 1; i < scores.Count; i++)
        {
            if (scores[i] > scores[maxIndex])
                maxIndex = i;
        }

        // The detailed panda description should rank as the most relevant document.
        var maxScoreDocument = documents[maxIndex];
        Assert.Equal(documents[2], maxScoreDocument);
    }
}
LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
using System.Text;
2+
using System.Xml.Linq;
3+
using LLama.Common;
4+
using LLama.Extensions;
5+
using Microsoft.Extensions.Logging;
6+
7+
8+
namespace LLama.Unittest.Native;
9+
10+
public class SafeLlamaModelHandleVocabularyTests
    : IDisposable
{
    private readonly LLamaWeights _model;

    public SafeLlamaModelHandleVocabularyTests()
    {
        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // 0 lets the runtime use the model's own context length
            ContextSize = 0,
            PoolingType = LLama.Native.LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    // xUnit calls Dispose after each test; release the loaded weights
    // (the original leaked the native model handle).
    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void GetLLamaTokenString()
    {
        var bos = _model.Vocab.BOS;
        var eos = _model.Vocab.EOS;

        var bosStr = _model.Vocab.LLamaTokenToString(bos, true);
        var eosStr = _model.Vocab.LLamaTokenToString(eos, true);

        // The jina reranker vocabulary uses sentencepiece-style markers
        Assert.Equal("<s>", bosStr);
        Assert.Equal("</s>", eosStr);
    }
}

LLama/LLamaReranker.cs

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using System.Linq;
5+
using System.Text;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using LLama.Abstractions;
9+
using LLama.Exceptions;
10+
using LLama.Native;
11+
using Microsoft.Extensions.Logging;
12+
13+
namespace LLama;
14+
15+
/// <summary>
/// Get rank scores between prompt and documents
/// </summary>
public sealed partial class LLamaReranker
    : IDisposable
{
    /// <summary>
    /// String form of the vocabulary's BOS token ("&lt;s&gt;" when the vocabulary defines none)
    /// </summary>
    public string StrBOS { get; }

    /// <summary>
    /// String form of the vocabulary's EOS token ("&lt;/s&gt;" when the vocabulary defines none)
    /// </summary>
    public string StrEOS { get; }

    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingSize => Context.EmbeddingSize;

    /// <summary>
    /// LLama Context
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// Create a new reranker, using the given LLamaWeights
    /// </summary>
    /// <param name="weights">Model weights to create the reranking context from</param>
    /// <param name="params">Context parameters; PoolingType must be <see cref="LLamaPoolingType.Rank"/></param>
    /// <param name="logger">Optional logger</param>
    /// <exception cref="ArgumentException">Thrown when batch size and ubatch size differ</exception>
    /// <exception cref="NotSupportedException">Thrown for encoder-decoder models or a non-Rank pooling type</exception>
    public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        // Non-causal (embeddings) mode requires the logical batch and the
        // physical micro-batch to be the same size.
        if (@params.UBatchSize != @params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
        if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
            throw new NotSupportedException("Computing rank in encoder-decoder models is not supported");
        if (@params.PoolingType != LLamaPoolingType.Rank)
            throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");

        Context = weights.CreateContext(@params, logger);
        NativeApi.llama_set_embeddings(Context.NativeHandle, true);

        // Fall back to the conventional sentencepiece markers when the
        // vocabulary has no printable BOS/EOS string.
        StrBOS = Context.Vocab.LLamaTokenToString(Context.Vocab.BOS, true) ?? "<s>";
        StrEOS = Context.Vocab.LLamaTokenToString(Context.Vocab.EOS, true) ?? "</s>";
    }

    /// <inheritdoc />
    public void Dispose()
    {
        Context.Dispose();
    }

    /// <summary>
    /// Retrieve relevance scores for input and document by reranking
    /// </summary>
    /// <param name="input">The query to rank documents against</param>
    /// <param name="documents">The documents to score, one score per document (in order)</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
    /// <param name="cancellationToken"></param>
    /// <returns>One relevance score per document, in the same order as <paramref name="documents"/></returns>
    /// <exception cref="RuntimeError"></exception>
    /// <exception cref="NotSupportedException"></exception>
    public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var scores = new List<float>(documents.Count);
        foreach (var document in documents)
        {
            var score = (await GetRelevanceScoreWithTokenCount(input, document, cancellationToken).ConfigureAwait(false)).Score;
            scores.Add(normalize ? Sigmoid(score) : score);
        }
        return scores;
    }

    /// <summary>
    /// Score a single (input, document) pair, returning the raw rank score and
    /// the number of tokens the combined prompt consumed.
    /// </summary>
    private async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, CancellationToken cancellationToken = default)
    {
        // NOTE(review): separators are hard-coded rather than using StrEOS/StrBOS;
        // identical for sentencepiece vocabularies, but verify for other models.
        var prompt = $"{input}</s><s>{document}";

        // Add all of the tokens to the batch
        var tokens = Context.Tokenize(prompt, special: true);
        var batch = new LLamaBatch();
        for (var i = 0; i < tokens.Length; i++)
            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Run model: encoder-only models encode, decoder-only models decode
        // (encoder-decoder models were rejected in the constructor).
        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                // ConfigureAwait(false): library code, no sync-context needed
                var result = await Context.EncodeAsync(batch, cancellationToken).ConfigureAwait(false);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken).ConfigureAwait(false);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }

        // With rank pooling the first element of the sequence embedding is the score
        var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];

        Context.NativeHandle.KvCacheClear();

        return (score, tokens.Length);
    }

    // Logistic function mapping a raw score into (0, 1); pure, so static.
    private static float Sigmoid(float x)
    {
        return (float)(1 / (1 + Math.Exp(-x)));
    }
}

LLama/Native/SafeLlamaModelHandle.cs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,18 @@ internal Vocabulary(SafeLlamaModelHandle model)
651651
_model = model;
652652
}
653653

654-
private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
654+
// Map the native sentinel -1 (token absent from this vocabulary) to null.
private static LLamaToken? Normalize(LLamaToken token) => token == -1 ? null : token;
658+
659+
/// <summary>
660+
/// Translate LLamaToken to String
661+
/// </summary>
662+
/// <param name="token"></param>
663+
/// <param name="isSpecialToken"></param>
664+
/// <returns></returns>
665+
public string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
655666
{
656667
if (!token.HasValue)
657668
return null;
@@ -676,11 +687,6 @@ internal Vocabulary(SafeLlamaModelHandle model)
676687
return Encoding.UTF8.GetStringFromSpan(slice);
677688
}
678689

679-
private static LLamaToken? Normalize(LLamaToken token)
680-
{
681-
return token == -1 ? null : token;
682-
}
683-
684690
/// <summary>
685691
/// Total number of tokens in this vocabulary
686692
/// </summary>

0 commit comments

Comments
 (0)