@@ -18,16 +18,6 @@ namespace LLama;
18
18
public sealed partial class LLamaReranker
19
19
: IDisposable
20
20
{
21
- /// <summary>
22
- /// string BOS
23
- /// </summary>
24
- public string StrBOS { get ; }
25
- /// <summary>
26
- /// string EOS
27
- /// </summary>
28
- public string StrEOS { get ; }
29
-
30
-
31
21
/// <summary>
32
22
/// Dimension of embedding vectors
33
23
/// </summary>
@@ -54,8 +44,6 @@ public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logg
54
44
throw new NotSupportedException ( "Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank" ) ;
55
45
Context = weights . CreateContext ( @params , logger ) ;
56
46
NativeApi . llama_set_embeddings ( Context . NativeHandle , true ) ;
57
- StrBOS = Context . Vocab . LLamaTokenToString ( Context . Vocab . BOS , true ) ?? "<s>" ;
58
- StrEOS = Context . Vocab . LLamaTokenToString ( Context . Vocab . EOS , true ) ?? "</s>" ;
59
47
}
60
48
61
49
/// <inheritdoc />
@@ -65,7 +53,7 @@ public void Dispose()
65
53
}
66
54
67
55
/// <summary>
/// Retrieve relevance scores for the input against a list of documents by reranking,
/// executed in a single model pass (one sequence id per document).
/// </summary>
/// <param name="input">The query text to score each document against.</param>
/// <param name="documents">The candidate documents to rank.</param>
/// <param name="normalize">When true, each raw score is passed through a sigmoid to map it into (0, 1).</param>
/// <param name="cancellationToken">Checked once, just before the expensive encode/decode step.</param>
/// <returns>One relevance score per document, in the same order as <paramref name="documents"/>.</returns>
/// <exception cref="RuntimeError">Thrown when the model fails to encode or decode the batch.</exception>
/// <exception cref="NotSupportedException">Thrown when the model is neither encoder-only nor decoder-only.</exception>
public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
{
    List<float> scores = new List<float>(documents.Count);

    // Build one batch containing every (input + document) pair, each on its own
    // sequence id, so a single encode/decode call scores all documents at once.
    var batch = new LLamaBatch();
    var inputTokens = Context.Tokenize(input);
    foreach (var (index, document) in documents.Select((item, index) => (index, item)))
    {
        var docTokens = Context.Tokenize(document);
        LLamaToken[] tokens = [.. inputTokens, .. docTokens];
        for (var i = 0; i < tokens.Length; i++)
            batch.Add(tokens[i], i, (LLamaSeqId)index, true);
    }

    // clear previous kv_cache values
    Context.NativeHandle.KvCacheClear();

    // Check if we should cancel the work, just before doing anything expensive (encode/decode)
    cancellationToken.ThrowIfCancellationRequested();

    // Run model. Rerankers are either encoder-only or decoder-only; anything
    // else (both or neither) is unsupported.
    // NOTE: ConfigureAwait(false) is required here — this is library code and
    // must not capture the caller's synchronization context (the previous
    // per-document implementation already did this).
    switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
    {
        case (true, false):
        {
            var result = await Context.EncodeAsync(batch, cancellationToken).ConfigureAwait(false);
            if (result != EncodeResult.Ok)
                throw new RuntimeError($"Failed to encode: {result}");
            break;
        }

        case (false, true):
        {
            var result = await Context.DecodeAsync(batch, cancellationToken).ConfigureAwait(false);
            if (result != DecodeResult.Ok)
                throw new RuntimeError($"Failed to decode: {result}");
            break;
        }

        default:
            throw new NotSupportedException("Unsupported model type");
    }

    // With LLamaPoolingType.Rank (enforced in the constructor), the pooled
    // embedding of each sequence carries the rank score in its first element.
    for (var i = 0; i < documents.Count; i++)
    {
        var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)i)[0];
        scores.Add(normalize ? Sigmoid(score) : score);
    }

    Context.NativeHandle.KvCacheClear();

    return scores;
}
86
117
87
-
88
- private async Task < ( float Score , int Tokens ) > GetRelevanceScoreWithTokenCount ( string input , string document , CancellationToken cancellationToken = default )
118
+ /// <summary>
119
+ /// Retrieve relevance score for input and document by reranking
120
+ /// </summary>
121
+ /// <param name="input"></param>
122
+ /// <param name="document"></param>
123
+ /// <param name="cancellationToken"></param>
124
+ /// <returns></returns>
125
+ /// <exception cref="RuntimeError"></exception>
126
+ /// <exception cref="NotSupportedException"></exception>
127
+ public async Task < ( float Score , int Tokens ) > GetRelevanceScoreWithTokenCount ( string input , string document , bool normalize = false , CancellationToken cancellationToken = default )
89
128
{
90
- var prompt = $ " { input } </s><s> { document } " ;
91
- // Add all of the tokens to the batch
92
- var tokens = Context . Tokenize ( prompt , special : true ) ;
129
+ var inputTokens = Context . Tokenize ( input ) ;
130
+ var docTokens = Context . Tokenize ( document ) ;
131
+ LLamaToken [ ] tokens = [ .. inputTokens , .. docTokens ] ;
93
132
var batch = new LLamaBatch ( ) ;
94
133
for ( var i = 0 ; i < tokens . Length ; i ++ )
95
134
batch . Add ( tokens [ i ] , i , LLamaSeqId . Zero , true ) ;
@@ -127,7 +166,7 @@ public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOn
127
166
128
167
Context . NativeHandle . KvCacheClear ( ) ;
129
168
130
- return ( score , tokens . Length ) ;
169
+ return ( normalize ? Sigmoid ( score ) : score , tokens . Length ) ;
131
170
}
132
171
133
172
private float Sigmoid ( float x )
0 commit comments