
Commit 079410c

Merge pull request #973 from martindevans/fix_null_sampler_pipeline
Non-Null Default `SamplingPipeline`
2 parents 3f176be + 20f5485

9 files changed (+62, -99 lines)

LLama.Benchmark/LLamaExecutorBenchmark/Prefill.cs

Lines changed: 0 additions & 1 deletion
@@ -83,7 +83,6 @@ private void InitializeParamsAndModel()
         Prompt = File.ReadAllText(Constants.TextCompletionPromptsFilePath).Substring(0, PromptAndContextLength.Item1);
         InferenceParams = new InferenceParams()
         {
-            Temperature = 0.6f,
             MaxTokens = 1 // Only prefill, no generation here.
         };

LLama.Examples/Examples/ChatChineseGB2312.cs

Lines changed: 1 addition & 2 deletions
@@ -55,9 +55,8 @@ public static async Task Run()
             session
                 .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤"));

-            InferenceParams inferenceParams = new InferenceParams()
+            var inferenceParams = new InferenceParams
             {
-                Temperature = 0.9f,
                 AntiPrompts = new List<string> { "用户:" }
             };

LLama.Examples/Examples/InteractiveModeExecute.cs

Lines changed: 11 additions & 1 deletion
@@ -1,4 +1,5 @@
 using LLama.Common;
+using LLama.Sampling;

 namespace LLama.Examples.Examples
 {
@@ -25,7 +26,16 @@ public static async Task Run()

             Console.Write(prompt);

-            var inferenceParams = new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" }, MaxTokens = 128 };
+            var inferenceParams = new InferenceParams
+            {
+                AntiPrompts = new List<string> { "User:" },
+                MaxTokens = 128,
+
+                SamplingPipeline = new DefaultSamplingPipeline
+                {
+                    Temperature = 0.6f
+                }
+            };

             while (true)
             {
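Since `SamplingPipeline` now defaults to a `DefaultSamplingPipeline` (see the `LLama/Common/InferenceParams.cs` diff at the end of this commit), callers that are happy with default sampling can omit the pipeline entirely, while callers with custom settings configure them on the pipeline as this example now does. A minimal sketch of the two resulting call styles, using only types that appear in this commit:

using LLama.Common;
using LLama.Sampling;

// Default sampling: the non-null default pipeline is used automatically.
var plain = new InferenceParams
{
    AntiPrompts = new List<string> { "User:" },
    MaxTokens = 128
};

// Custom sampling: override only the knobs you care about.
var custom = new InferenceParams
{
    AntiPrompts = new List<string> { "User:" },
    MaxTokens = 128,
    SamplingPipeline = new DefaultSamplingPipeline { Temperature = 0.6f }
};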

LLama.KernelMemory/LlamaSharpTextGenerator.cs

Lines changed: 19 additions & 12 deletions
@@ -1,5 +1,6 @@
 using LLama;
 using LLama.Common;
+using LLama.Sampling;
 using Microsoft.KernelMemory.AI;

 namespace LLamaSharp.KernelMemory
@@ -86,25 +87,31 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
             return defaultParams with
             {
                 AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(),
-                Temperature = (float)options.Temperature,
                 MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens,
-                FrequencyPenalty = (float)options.FrequencyPenalty,
-                PresencePenalty = (float)options.PresencePenalty,
-                TopP = (float)options.NucleusSampling
+
+                SamplingPipeline = new DefaultSamplingPipeline()
+                {
+                    Temperature = (float)options.Temperature,
+                    AlphaFrequency = (float)options.FrequencyPenalty,
+                    AlphaPresence = (float)options.PresencePenalty,
+                    TopP = (float)options.NucleusSampling,
+                }
             };
         }
-        else
+
+        return new InferenceParams
         {
-            return new InferenceParams
+            AntiPrompts = options.StopSequences.ToList().AsReadOnly(),
+            MaxTokens = options.MaxTokens ?? 1024,
+
+            SamplingPipeline = new DefaultSamplingPipeline()
             {
-                AntiPrompts = options.StopSequences.ToList().AsReadOnly(),
                 Temperature = (float)options.Temperature,
-                MaxTokens = options.MaxTokens ?? 1024,
-                FrequencyPenalty = (float)options.FrequencyPenalty,
-                PresencePenalty = (float)options.PresencePenalty,
+                AlphaFrequency = (float)options.FrequencyPenalty,
+                AlphaPresence = (float)options.PresencePenalty,
                 TopP = (float)options.NucleusSampling,
-            };
-        }
+            }
+        };
     }

     /// <inheritdoc/>

LLama.SemanticKernel/ExtensionMethods.cs

Lines changed: 10 additions & 5 deletions
@@ -1,3 +1,4 @@
+using LLama.Sampling;
 using Microsoft.SemanticKernel.ChatCompletion;
 using AuthorRole = LLama.Common.AuthorRole;

@@ -45,12 +46,16 @@ internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LL
         };
         return new LLama.Common.InferenceParams
         {
-            Temperature = (float)requestSettings.Temperature,
-            TopP = (float)requestSettings.TopP,
-            PresencePenalty = (float)requestSettings.PresencePenalty,
-            FrequencyPenalty = (float)requestSettings.FrequencyPenalty,
             AntiPrompts = antiPrompts,
-            MaxTokens = requestSettings.MaxTokens ?? -1
+            MaxTokens = requestSettings.MaxTokens ?? -1,
+
+            SamplingPipeline = new DefaultSamplingPipeline()
+            {
+                Temperature = (float)requestSettings.Temperature,
+                TopP = (float)requestSettings.TopP,
+                AlphaPresence = (float)requestSettings.PresencePenalty,
+                AlphaFrequency = (float)requestSettings.FrequencyPenalty,
+            }
         };
     }
 }
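Both adapter layers above rename the OpenAI-style penalties as they move into the pipeline: `FrequencyPenalty` maps to `AlphaFrequency` and `PresencePenalty` maps to `AlphaPresence` on `DefaultSamplingPipeline`. A small sketch of that mapping in isolation (the `RequestSettings` record is a hypothetical stand-in for the Kernel Memory / Semantic Kernel option types, not part of this commit):

using LLama.Sampling;

// Hypothetical holder for OpenAI-style sampling settings, for illustration only.
public record RequestSettings(double Temperature, double TopP, double FrequencyPenalty, double PresencePenalty);

public static class SamplingMapping
{
    public static DefaultSamplingPipeline ToPipeline(RequestSettings s) => new()
    {
        Temperature = (float)s.Temperature,
        TopP = (float)s.TopP,
        AlphaFrequency = (float)s.FrequencyPenalty, // OpenAI "frequency penalty"
        AlphaPresence = (float)s.PresencePenalty    // OpenAI "presence penalty"
    };
}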

LLama.WebAPI/Controllers/ChatController.cs

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 using LLama.WebAPI.Models;
 using LLama.WebAPI.Services;
 using Microsoft.AspNetCore.Mvc;
-using System;

 namespace LLama.WebAPI.Controllers
 {

LLama.WebAPI/Services/StatefulChatService.cs

Lines changed: 18 additions & 11 deletions
@@ -1,11 +1,10 @@
-
 using LLama.WebAPI.Models;
-using Microsoft;
-using System.Runtime.CompilerServices;
+using LLama.Sampling;

 namespace LLama.WebAPI.Services;

-public class StatefulChatService : IDisposable
+public sealed class StatefulChatService
+    : IDisposable
 {
     private readonly ChatSession _session;
     private readonly LLamaContext _context;
@@ -47,10 +46,14 @@ public async Task<string> Send(SendMessageInput input)
         _logger.LogInformation("Input: {text}", input.Text);
         var outputs = _session.ChatAsync(
             new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
-            new Common.InferenceParams()
+            new Common.InferenceParams
             {
-                RepeatPenalty = 1.0f,
-                AntiPrompts = new string[] { "User:" },
+                AntiPrompts = [ "User:" ],
+
+                SamplingPipeline = new DefaultSamplingPipeline
+                {
+                    RepeatPenalty = 1.0f
+                }
             });

         var result = "";
@@ -74,11 +77,15 @@ public async IAsyncEnumerable<string> SendStream(SendMessageInput input)
         _logger.LogInformation(input.Text);

         var outputs = _session.ChatAsync(
-            new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text!)
-            , new Common.InferenceParams()
+            new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
+            new Common.InferenceParams
             {
-                RepeatPenalty = 1.0f,
-                AntiPrompts = new string[] { "User:" },
+                AntiPrompts = [ "User:" ],
+
+                SamplingPipeline = new DefaultSamplingPipeline
+                {
+                    RepeatPenalty = 1.0f
+                }
             });

         await foreach (var output in outputs)

LLama.WebAPI/Services/StatelessChatService.cs

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-using LLama.Common;
-using Microsoft.AspNetCore.Http;
+using LLama.Common;
 using System.Text;
 using static LLama.LLamaTransforms;

LLama/Common/InferenceParams.cs

Lines changed: 2 additions & 64 deletions
@@ -13,7 +13,7 @@ public record InferenceParams
     : IInferenceParams
 {
     /// <summary>
-    /// number of tokens to keep from initial prompt
+    /// number of tokens to keep from initial prompt when applying context shifting
     /// </summary>
     public int TokensKeep { get; set; } = 0;

@@ -23,75 +23,13 @@ public record InferenceParams
     /// </summary>
     public int MaxTokens { get; set; } = -1;

-    /// <summary>
-    /// logit bias for specific tokens
-    /// </summary>
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;
-
     /// <summary>
     /// Sequences where the model will stop generating further tokens.
     /// </summary>
     public IReadOnlyList<string> AntiPrompts { get; set; } = [];

     /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public int TopK { get; set; } = 40;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public float TopP { get; set; } = 0.95f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public float MinP { get; set; } = 0.05f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public float TfsZ { get; set; } = 1.0f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public float TypicalP { get; set; } = 1.0f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public float Temperature { get; set; } = 0.8f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public float RepeatPenalty { get; set; } = 1.1f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public int RepeatLastTokensCount { get; set; } = 64;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public float FrequencyPenalty { get; set; } = .0f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public float PresencePenalty { get; set; } = .0f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")]
-    public MirostatType Mirostat { get; set; } = MirostatType.Disable;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")]
-    public float MirostatTau { get; set; } = 5.0f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")]
-    public float MirostatEta { get; set; } = 0.1f;
-
-    /// <inheritdoc />
-    [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
-    public bool PenalizeNL { get; set; } = true;
-
-    /// <inheritdoc />
-    public ISamplingPipeline? SamplingPipeline { get; set; }
+    public ISamplingPipeline SamplingPipeline { get; set; } = new DefaultSamplingPipeline();
 }

 /// <summary>