11using System ;
22using System . Collections . Generic ;
3+ using System . Diagnostics ;
34
45namespace LLama . Native ;
56
@@ -105,6 +106,25 @@ private void GrowMaxSequences(int atLeast)
105106
106107 internal GroupDisposable ToNativeBatch ( out LLamaNativeBatch batch )
107108 {
109+ // Sanity checking
110+ #if DEBUG
111+ // Check every output logit position is actually generating logits for exactly one sequence
112+ foreach ( var ( seq , idx ) in _logitPositions )
113+ {
114+ Debug . Assert ( _logits [ idx ] != 0 ) ;
115+ Debug . Assert ( _sequenceIdCount [ idx ] == 1 ) ;
116+ Debug . Assert ( _sequenceIds [ idx ] [ 0 ] == seq ) ;
117+ }
118+
119+ // Check the reverse: an index has an entry in _logitPositions if and only if its _logits flag is non-zero
120+ for ( var i = 0 ; i < _logits . Length ; i ++ )
121+ {
122+ var actual = _logitPositions . FindIndex ( x => x . Item2 == i ) >= 0 ;
123+ var expected = _logits [ i ] != 0 ;
124+ Debug . Assert ( actual == expected ) ;
125+ }
126+ #endif
127+
108128 // This group holds all of the memory pins
109129 var group = new GroupDisposable ( ) ;
110130
@@ -146,6 +166,7 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
146166 /// <returns>The index that the token was added at. Use this for GetLogitsIth</returns>
147167 public int Add ( LLamaToken token , LLamaPos pos , ReadOnlySpan < LLamaSeqId > sequences , bool logits )
148168 {
169+ // todo: token sharing in batch is broken?
149170 // Try to find this (token, position) combo somewhere in the batch to re-use it by adding this
150171 // sequence ID to the list.
151172 // Do **not** do this if this token wants logits, to prevent logits being shared between sequences.
@@ -171,9 +192,9 @@ public int Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequence
171192 if ( sequences . Length > SequenceCapacity )
172193 GrowMaxSequences ( sequences . Length ) ;
173194
174- // Store the position in the index, so it can be found later.
175- // We need to check that it's not already there in case we skipped the check above (because logits is true) .
176- if ( ! _index . ContainsKey ( ( token , pos ) ) )
195+ // Store the position in the index, so it can be found later. We don't want to share tokens when logits are being generated so
196+ // do not add to the index in that case .
197+ if ( ! logits && ! _index . ContainsKey ( ( token , pos ) ) )
177198 _index . Add ( ( token , pos ) , TokenCount ) ;
178199
179200 // Add the items to the arrays
0 commit comments