2 changes: 2 additions & 0 deletions LLama.Examples/Examples/QuantizeModel.cs
@@ -20,6 +20,8 @@ public static async Task Run()
{
Console.WriteLine("Quantization failed!");
}

+    await Task.CompletedTask;
}
}
}
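Note (illustrative, not part of the diff): the added `await Task.CompletedTask;` is the usual way to keep an `async Task` method warning-free when it has no real awaitable work; it completes synchronously and suppresses compiler warning CS1998. A minimal sketch of the pattern, with a hypothetical method name:

    using System;
    using System.Threading.Tasks;

    public static class Example
    {
        // Hypothetical method: async by signature, but all work is synchronous.
        public static async Task RunAsync()
        {
            Console.WriteLine("synchronous work inside an async method");

            // Completes immediately; avoids CS1998 ("async method lacks 'await'")
            // without changing the method's observable behaviour.
            await Task.CompletedTask;
        }
    }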
2 changes: 1 addition & 1 deletion LLama/Batched/Conversation.cs
@@ -410,7 +410,7 @@ public void Remove(LLamaPos start, LLamaPos end)
}

/// <summary>
-    /// Removes <see cref="count"/> tokens starting from <see cref="start"/>
+    /// Removes <paramref name="count"/> tokens starting from <paramref name="start"/>
/// </summary>
/// <param name="start">Start position (inclusive)</param>
/// <param name="count">Number of tokens</param>
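Note on the doc-comment fix above: `<see cref="..."/>` links to a type or member, so using it on parameter names produced broken references; `<paramref name="..."/>` is the tag that refers to the method's own parameters. A small sketch of the convention (body elided):

    /// <summary>
    /// Removes <paramref name="count"/> tokens starting from <paramref name="start"/>.
    /// See <see cref="Remove(LLamaPos, LLamaPos)"/> for the range-based overload.
    /// </summary>
    /// <param name="start">Start position (inclusive)</param>
    /// <param name="count">Number of tokens</param>
    public void Remove(LLamaPos start, int count) { /* body elided */ }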
1 change: 0 additions & 1 deletion LLama/Common/FixedSizeQueue.cs
@@ -14,7 +14,6 @@ public class FixedSizeQueue<T>
private readonly T[] _buffer;
private int _start;
private int _count;
-    private T[]? _window;

// Minimum capacity for the temporary buffer used to expose a contiguous view.
private const int MinimumWindowSize = 4;
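Note (hypothetical sketch, not from this PR): with the unused `_window` field gone, a contiguous view can be built on demand from `_buffer`, `_start`, and `_count`. Assuming a conventional ring-buffer layout, it might look like this:

    // Copies the logically contiguous contents of the ring buffer into a
    // temporary array. Only the first _count entries are meaningful; the
    // buffer is padded up to MinimumWindowSize to avoid tiny allocations.
    private T[] GetContiguousWindow()
    {
        var window = new T[Math.Max(_count, MinimumWindowSize)];

        // First segment: from _start to the end of the backing array.
        var firstPart = Math.Min(_count, _buffer.Length - _start);
        Array.Copy(_buffer, _start, window, 0, firstPart);

        // Second segment: whatever wrapped around to the front.
        Array.Copy(_buffer, 0, window, firstPart, _count - firstPart);
        return window;
    }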
4 changes: 2 additions & 2 deletions LLama/LLamaExecutorBase.cs
@@ -262,7 +262,7 @@ protected virtual void TryReuseMatchingPrefix()
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
-    protected abstract Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
Member review comment: This is a breaking change for anything inheriting from executor base. So it needs to stay async.

+    protected abstract (bool, IReadOnlyList<string>) PostProcess(IInferenceParams inferenceParams, InferStateArgs args);
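Illustration of the reviewer's point (hypothetical code, not in the PR): any external executor that overrides the old async member stops compiling once the base signature loses its `Task`. Assuming the base class in LLamaExecutorBase.cs is `StatefulExecutorBase` and a made-up awaited helper:

    // Hypothetical third-party subclass written against the old API.
    // (Other abstract members omitted for brevity.)
    public class MyExecutor : StatefulExecutorBase
    {
        // Against the new synchronous base member this override fails with
        // CS0115 ("no suitable method found to override"), because
        // Task<(bool, ...)> and (bool, ...) no longer match.
        protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(
            IInferenceParams inferenceParams, InferStateArgs args)
        {
            await CheckStopSignalAsync(); // made-up awaited work
            return (false, Array.Empty<string>());
        }
    }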

/// <summary>
/// The core inference logic.
@@ -338,7 +338,7 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
yield return decoded;
}

-                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
+                var (breakGeneration, extraOutputs) = PostProcess(inferenceParams, args);
if (extraOutputs is { Count: > 0 })
{
foreach (var item in extraOutputs)
6 changes: 5 additions & 1 deletion LLama/LLamaInstructExecutor.cs
@@ -99,6 +99,7 @@ public override async Task SaveState(string filename)
await JsonSerializer.SerializeAsync(fs, state);
}
}

/// <inheritdoc />
public override async Task LoadState(string filename)
{
@@ -154,7 +155,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
}

/// <inheritdoc />
-    protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+    protected override (bool, IReadOnlyList<string>) PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
Expand Down Expand Up @@ -205,7 +206,9 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
_pastTokensCount = pastTokensCount;

        if (result != DecodeResult.Ok)
+       {
            throw new LLamaDecodeError(result);
+       }

if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
@@ -250,6 +253,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In

return;
}

/// <summary>
/// The descriptor of the state of the instruct executor.
/// </summary>
25 changes: 17 additions & 8 deletions LLama/LLamaInteractExecutor.cs
@@ -67,6 +67,7 @@ public override ExecutorBaseState GetStateData()
};
return state;
}

/// <inheritdoc />
public override Task LoadState(ExecutorBaseState data)
{
@@ -88,23 +89,23 @@ public override Task LoadState(ExecutorBaseState data)

return Task.CompletedTask;
}

/// <inheritdoc />
public override async Task SaveState(string filename)
{
var state = (InteractiveExecutorState)GetStateData();
-        using(var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
+        using (var fs = new FileStream(filename, FileMode.Create, FileAccess.Write))
{
await JsonSerializer.SerializeAsync(fs, state);
}
}

/// <inheritdoc />
public override async Task LoadState(string filename)
{
-        using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
-        {
-            var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
-            await LoadState(state!);
-        }
+        using var fs = new FileStream(filename, FileMode.Open, FileAccess.Read);
+        var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
+        await LoadState(state!);
}
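Note on the rewrite above: a C# 8 `using` declaration disposes at the end of the enclosing scope, so the explicit block is redundant when the method ends right after the read. Equivalent forms side by side (hypothetical file name):

    // using statement: fs is disposed at the closing brace.
    using (var fs = new FileStream("state.json", FileMode.Open, FileAccess.Read))
    {
        // ... read from fs ...
    } // disposed here

    // using declaration (C# 8+): fs2 is disposed when the method returns.
    using var fs2 = new FileStream("state.json", FileMode.Open, FileAccess.Read);
    // ... read from fs2 ...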

/// <summary>
@@ -122,7 +123,11 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
if (_is_prompt_run)
{
// When running the first input (prompt) in interactive mode, we should specially process it.
-            if (text == null) throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
+            if (text == null)
+            {
+                throw new ArgumentException("Prompt cannot be null to trigger continuation if a prompt has not been provided previously.");
+            }

if (!IsMultiModal)
{
_embed_inps = Context.Tokenize(text, true, true).ToList();
@@ -203,15 +208,19 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
-    protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
+    protected override (bool, IReadOnlyList<string>) PostProcess(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (!string.IsNullOrEmpty(args.LastOutput) && AntipromptProcessor.Add(args.LastOutput))
{
args.WaitForInput = true;
}

if (_pastTokensCount > 0 && args.WaitForInput)
{
return (true, Array.Empty<string>());
}
}

if (_embeds.Count > 0 && _embeds.Last().IsEndOfGeneration(Context.Vocab))
2 changes: 2 additions & 0 deletions LLama/Native/SafeLlamaModelHandle.cs
@@ -436,6 +436,7 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
/// </summary>
/// <param name="model"></param>
/// <returns></returns>
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern uint llama_model_n_cls_out(SafeLlamaModelHandle model);

/// <summary>
@@ -444,6 +445,7 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
/// <param name="model"></param>
/// <param name="i"></param>
/// <returns></returns>
+    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern string? llama_model_cls_label(SafeLlamaModelHandle model, uint i);
#endregion
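Note on the two additions above: a `static extern` method has no body of its own, and without an attribute such as `[DllImport]` the compiler rejects it (error CS0626), so these bindings only become callable with the attribute in place. A hypothetical wrapper showing how they might be surfaced on the handle (member names are illustrative):

    // Hypothetical convenience members (not part of the PR) built on the
    // two native bindings declared above.
    // Requires: using System.Collections.Generic;
    public uint ClassifierHeadCount => llama_model_n_cls_out(this);

    // Enumerates the classifier output labels, skipping unnamed outputs
    // (llama_model_cls_label may return null).
    public IEnumerable<string> ClassifierLabels()
    {
        for (uint i = 0; i < llama_model_n_cls_out(this); i++)
        {
            var label = llama_model_cls_label(this, i);
            if (label != null)
                yield return label;
        }
    }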
