Skip to content

Commit 5f35b8e

Browse files
committed
Memory efficient context handling
1 parent 925ca06 commit 5f35b8e

File tree

5 files changed

+58
-101
lines changed

5 files changed

+58
-101
lines changed

LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ public sealed class LLamaSharpTextEmbeddingGenerator
1818
private readonly LLamaEmbedder _embedder;
1919
private readonly bool _ownsEmbedder;
2020

21+
private readonly ModelParams? @params;
22+
2123
/// <inheritdoc/>
2224
public int MaxTokens { get; }
2325

@@ -29,7 +31,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config)
2931
{
3032
MaxTokens = (int?)config.ContextSize ?? 2048;
3133

32-
var @params = new ModelParams(config.ModelPath)
34+
@params = new ModelParams(config.ModelPath)
3335
{
3436
ContextSize = config?.ContextSize ?? 2048,
3537
GpuLayerCount = config?.GpuLayerCount ?? 20,
@@ -57,7 +59,7 @@ public LLamaSharpTextEmbeddingGenerator(LLamaSharpConfig config, LLamaWeights we
5759
{
5860
MaxTokens = (int?)config.ContextSize ?? 2048;
5961

60-
var @params = new ModelParams(config.ModelPath)
62+
@params = new ModelParams(config.ModelPath)
6163
{
6264
ContextSize = config?.ContextSize ?? 2048,
6365
GpuLayerCount = config?.GpuLayerCount ?? 20,
@@ -103,8 +105,12 @@ public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationTok
103105
return new Embedding(embeddings.First());
104106
}
105107

106-
/// <inheritdoc/>
107-
public int CountTokens(string text) => _embedder.CountTokens(text);
108+
/// <summary>
109+
/// Count tokens in the input text
110+
/// </summary>
111+
/// <param name="text">input text</param>
112+
/// <returns></returns>
113+
public int CountTokens(string text) => _weights?.CountTokens(text, @params!) ?? 0;
108114

109115
/// <summary>
110116
/// Get the list of tokens for the input text
@@ -114,6 +120,6 @@ public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationTok
114120
/// <remarks>
115121
/// It throws if text is null and includes an empty stop token because addBos is left true to be consistent with the CountTokens implementation.</remarks>
116122
/// <see cref="CountTokens(string)"/>
117-
public IReadOnlyList<string> GetTokens(string text) => _embedder.GetTokens(text);
123+
public IReadOnlyList<string> GetTokens(string text) => _weights?.GetTokens(text, @params!) ?? new List<string>();
118124
}
119125
}

LLama.KernelMemory/LlamaSharpTextGenerator.cs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public sealed class LlamaSharpTextGenerator
1919

2020
private readonly InferenceParams? _defaultInferenceParams;
2121

22+
private readonly ModelParams? @params;
23+
2224
public int MaxTokenTotal { get; }
2325

2426
/// <summary>
@@ -27,7 +29,7 @@ public sealed class LlamaSharpTextGenerator
2729
/// <param name="config">The configuration for LLamaSharp.</param>
2830
public LlamaSharpTextGenerator(LLamaSharpConfig config)
2931
{
30-
var parameters = new ModelParams(config.ModelPath)
32+
@params = new ModelParams(config.ModelPath)
3133
{
3234
ContextSize = config?.ContextSize ?? 2048,
3335
GpuLayerCount = config?.GpuLayerCount ?? 20,
@@ -38,11 +40,11 @@ public LlamaSharpTextGenerator(LLamaSharpConfig config)
3840
FlashAttention = true,
3941
UseMemorymap = true
4042
};
41-
_weights = LLamaWeights.LoadFromFile(parameters);
42-
_executor = new StatelessExecutor(_weights, parameters);
43+
_weights = LLamaWeights.LoadFromFile(@params);
44+
_executor = new StatelessExecutor(_weights, @params);
4345
_defaultInferenceParams = config!.DefaultInferenceParams;
4446
_ownsWeights = true;
45-
MaxTokenTotal = (int)parameters.ContextSize;
47+
MaxTokenTotal = (int)@params.ContextSize;
4648
}
4749

4850
/// <summary>
@@ -55,7 +57,7 @@ public LlamaSharpTextGenerator(LLamaWeights weights, LLamaSharpConfig config, St
5557
{
5658
InferenceParams? inferenceParams = config.DefaultInferenceParams;
5759
_weights = weights;
58-
var parameters = new ModelParams("")
60+
@params = new ModelParams("")
5961
{
6062
ContextSize = config?.ContextSize ?? 2048,
6163
GpuLayerCount = config?.GpuLayerCount ?? 20,
@@ -66,9 +68,9 @@ public LlamaSharpTextGenerator(LLamaWeights weights, LLamaSharpConfig config, St
6668
FlashAttention = true,
6769
UseMemorymap = true
6870
};
69-
_executor = executor ?? new StatelessExecutor(_weights, parameters);
71+
_executor = executor ?? new StatelessExecutor(_weights, @params);
7072
_defaultInferenceParams = inferenceParams;
71-
MaxTokenTotal = (int)parameters.ContextSize;
73+
MaxTokenTotal = (int)@params.ContextSize;
7274
}
7375

7476
/// <inheritdoc/>
@@ -122,8 +124,12 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
122124
};
123125
}
124126

125-
/// <inheritdoc/>
126-
public int CountTokens(string text) => _executor.CountTokens(text);
127+
/// <summary>
128+
/// Count tokens in the input text
129+
/// </summary>
130+
/// <param name="text">input text</param>
131+
/// <returns></returns>
132+
public int CountTokens(string text) => _weights?.CountTokens(text, @params!) ?? 0;
127133

128134
/// <summary>
129135
/// Get the list of tokens for the input text
@@ -133,7 +139,6 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
133139
/// <remarks>
134140
/// It throws if text is null and includes an empty stop token because addBos is left true to be consistent with the CountTokens implementation.</remarks>
135141
/// <see cref="CountTokens(string)"/>
136-
public IReadOnlyList<string> GetTokens(string text) => _executor.GetTokens(text);
137-
142+
public IReadOnlyList<string> GetTokens(string text) => _weights?.GetTokens(text, @params!) ?? new List<string>();
138143
}
139144
}

LLama/LLamaEmbedder.cs

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -154,50 +154,4 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, Cancellati
154154

155155
return (results, tokens.Length);
156156
}
157-
158-
/// <summary>
159-
///
160-
/// </summary>
161-
/// <param name="text"></param>
162-
/// <returns></returns>
163-
public int CountTokens(string text)
164-
{
165-
// Ensure the context from last time is disposed (it always should be)
166-
if (!Context.NativeHandle.IsClosed)
167-
Context.Dispose();
168-
Context = _weights.CreateContext(_params, _logger);
169-
NativeApi.llama_set_embeddings(Context.NativeHandle, true);
170-
int count = Context.Tokenize(text, special: true).Length;
171-
Context.Dispose();
172-
173-
return count;
174-
}
175-
176-
/// <summary>
177-
/// Get the list of tokens for the input text
178-
/// </summary>
179-
/// <param name="text">Input string to be tokenized</param>
180-
/// <returns>Read-only list of tokens for the input text</returns>
181-
/// <remarks>
182-
/// It throws if text is null and includes an empty stop token because addBos is left true to be consistent with the CountTokens implementation.</remarks>
183-
/// <see cref="CountTokens(string)"/>
184-
public IReadOnlyList<string> GetTokens(string text)
185-
{
186-
// Ensure the context from last time is disposed (it always should be)
187-
if (!Context.NativeHandle.IsClosed)
188-
Context.Dispose();
189-
Context = _weights.CreateContext(_params, _logger);
190-
NativeApi.llama_set_embeddings(Context.NativeHandle, true);
191-
192-
/* see relevant unit tests for important implementation notes regarding unicode */
193-
var context = Context;
194-
var numericTokens = context.Tokenize(text, special: true);
195-
var decoder = new StreamingTokenDecoder(context);
196-
var tokens = numericTokens
197-
.Select(x => { decoder.Add(x); return decoder.Read(); })
198-
.ToList();
199-
Context.Dispose();
200-
201-
return tokens;
202-
}
203157
}

LLama/LLamaStatelessExecutor.cs

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -169,44 +169,5 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
169169
throw new LLamaDecodeError(returnCode);
170170
}
171171
}
172-
173-
/// <inheritdoc/>
174-
public int CountTokens(string text)
175-
{
176-
// Ensure the context from last time is disposed (it always should be)
177-
if (!Context.NativeHandle.IsClosed)
178-
Context.Dispose();
179-
Context = _weights.CreateContext(_params, _logger);
180-
int count = Context.Tokenize(text, special: true).Length;
181-
Context.Dispose();
182-
183-
return count;
184-
}
185-
186-
/// <summary>
187-
/// Get the list of tokens for the input text
188-
/// </summary>
189-
/// <param name="text">Input string to be tokenized</param>
190-
/// <returns>Read-only list of tokens for the input text</returns>
191-
/// <remarks>
192-
/// It throws if text is null and includes an empty stop token because addBos is left true to be consistent with the CountTokens implementation.</remarks>
193-
/// <see cref="CountTokens(string)"/>
194-
public IReadOnlyList<string> GetTokens(string text)
195-
{
196-
// Ensure the context from last time is disposed (it always should be)
197-
if (!Context.NativeHandle.IsClosed)
198-
Context.Dispose();
199-
Context = _weights.CreateContext(_params, _logger);
200-
201-
/* see relevant unit tests for important implementation notes regarding unicode */
202-
var numericTokens = Context.Tokenize(text, special: true);
203-
var decoder = new StreamingTokenDecoder(Context);
204-
var tokens = numericTokens
205-
.Select(x => { decoder.Add(x); return decoder.Read(); })
206-
.ToList();
207-
Context.Dispose();
208-
209-
return tokens ?? new List<string>();
210-
}
211172
}
212173
}

LLama/LLamaWeights.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using System.Threading;
56
using System.Threading.Tasks;
@@ -165,5 +166,35 @@ public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding e
165166
{
166167
return NativeHandle.Tokenize(text, add_bos, special, encoding);
167168
}
169+
170+
/// <summary>
171+
/// Count the tokens in the input text
172+
/// </summary>
173+
/// <param name="text">input text</param>
174+
/// <param name="parameters">context parameters</param>
175+
/// <returns></returns>
176+
public int CountTokens(string text, IContextParams parameters)
177+
{
178+
using var context = CreateContext(parameters);
179+
var count = context.Tokenize(text, special: true).Length;
180+
return count;
181+
}
182+
183+
/// <summary>
184+
/// Get the list of tokens for the input text
185+
/// </summary>
186+
/// <param name="text">Input string to be tokenized</param>
187+
/// <param name="parameters">Context parameters</param>
188+
/// <returns>Read-only list of tokens for the input test</returns> becomes: /// <returns>Read-only list of tokens for the input text</returns>
189+
/// <remarks>
190+
/// It throws if text is null and includes an empty stop token because addBos is left true to be consistent with the CountTokens implementation.</remarks>
191+
/// <see cref="CountTokens(string, IContextParams)"/>
192+
public IReadOnlyList<string> GetTokens(string text, IContextParams parameters)
193+
{
194+
using var context = CreateContext(parameters);
195+
var numericTokens = context.Tokenize(text, special: true);
196+
var decoder = new StreamingTokenDecoder(context);
197+
return numericTokens.Select(x => { decoder.Add(x); return decoder.Read(); }).ToList();
198+
}
168199
}
169200
}

0 commit comments

Comments
 (0)