Skip to content

Commit ac902ad

Browse files
authored
Merge pull request #1202 from OoLunar/oolunar/preserve-cut-off-text
Preserve cut off generated text
2 parents e52f5ef + 2b74c17 commit ac902ad

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

LLama/ChatSession.cs

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.IO;
44
using System.Linq;
55
using System.Runtime.CompilerServices;
6+
using System.Text;
67
using System.Text.Json;
78
using System.Threading;
89
using System.Threading.Tasks;
@@ -171,7 +172,7 @@ public SessionState GetSessionState()
171172
{
172173
var executorState = ((StatefulExecutorBase)Executor).GetStateData();
173174
return new SessionState(
174-
executorState.PastTokensCount > 0
175+
executorState.PastTokensCount > 0
175176
? Executor.Context.GetState() : null,
176177
executorState,
177178
History,
@@ -227,7 +228,7 @@ public void LoadSession(string path, bool loadTransforms = true)
227228
if (state.ExecutorState is null)
228229
{
229230
var executorPath = Path.Combine(path, EXECUTOR_STATE_FILENAME);
230-
((StatefulExecutorBase) Executor).LoadState(filename: executorPath);
231+
((StatefulExecutorBase)Executor).LoadState(filename: executorPath);
231232
}
232233
LoadSession(state, loadTransforms);
233234
}
@@ -441,21 +442,21 @@ public async IAsyncEnumerable<string> ChatAsync(
441442
prompt = HistoryTransform.HistoryToText(singleMessageHistory);
442443
}
443444

444-
string assistantMessage = string.Empty;
445+
StringBuilder assistantMessage = new();
445446

446-
await foreach (
447-
string textToken
448-
in ChatAsyncInternal(
449-
prompt,
450-
inferenceParams,
451-
cancellationToken))
447+
try
452448
{
453-
assistantMessage += textToken;
454-
yield return textToken;
449+
await foreach (var textToken in ChatAsyncInternal(prompt, inferenceParams, cancellationToken))
450+
{
451+
assistantMessage.Append(textToken);
452+
yield return textToken;
453+
}
454+
}
455+
finally
456+
{
457+
// Add the assistant message to the history
458+
AddAssistantMessage(assistantMessage.ToString());
455459
}
456-
457-
// Add the assistant message to the history
458-
AddAssistantMessage(assistantMessage);
459460
}
460461

461462
/// <summary>
@@ -624,7 +625,7 @@ public record SessionState
624625
/// <summary>
625626
/// The input transform pipeline used in this session.
626627
/// </summary>
627-
public ITextTransform[] InputTransformPipeline { get; set; } = [ ];
628+
public ITextTransform[] InputTransformPipeline { get; set; } = [];
628629

629630
/// <summary>
630631
/// The output transform used in this session.
@@ -635,11 +636,11 @@ public record SessionState
635636
/// The history transform used in this session.
636637
/// </summary>
637638
public IHistoryTransform HistoryTransform { get; set; } = new LLamaTransforms.DefaultHistoryTransform();
638-
639+
639640
/// <summary>
640641
/// The chat history messages for this session.
641642
/// </summary>
642-
public ChatHistory.Message[] History { get; set; } = [ ];
643+
public ChatHistory.Message[] History { get; set; } = [];
643644

644645
/// <summary>
645646
/// Create a new session state.
@@ -651,7 +652,7 @@ public record SessionState
651652
/// <param name="outputTransform"></param>
652653
/// <param name="historyTransform"></param>
653654
public SessionState(
654-
State? contextState, ExecutorBaseState executorState,
655+
State? contextState, ExecutorBaseState executorState,
655656
ChatHistory history, List<ITextTransform> inputTransformPipeline,
656657
ITextStreamTransform outputTransform, IHistoryTransform historyTransform)
657658
{
@@ -738,22 +739,22 @@ public static SessionState Load(string path)
738739
ITextTransform[] inputTransforms;
739740
try
740741
{
741-
inputTransforms = File.Exists(inputTransformFilepath) ?
742+
inputTransforms = File.Exists(inputTransformFilepath) ?
742743
(JsonSerializer.Deserialize<ITextTransform[]>(File.ReadAllText(inputTransformFilepath))
743744
?? throw new ArgumentException("Input transform file is invalid", nameof(path)))
744-
: [ ];
745+
: [];
745746
}
746747
catch (JsonException)
747748
{
748749
throw new ArgumentException("Input transform file is invalid", nameof(path));
749750
}
750751

751752
string outputTransformFilepath = Path.Combine(path, ChatSession.OUTPUT_TRANSFORM_FILENAME);
752-
753+
753754
ITextStreamTransform outputTransform;
754755
try
755756
{
756-
outputTransform = File.Exists(outputTransformFilepath) ?
757+
outputTransform = File.Exists(outputTransformFilepath) ?
757758
(JsonSerializer.Deserialize<ITextStreamTransform>(File.ReadAllText(outputTransformFilepath))
758759
?? throw new ArgumentException("Output transform file is invalid", nameof(path)))
759760
: new LLamaTransforms.EmptyTextOutputStreamTransform();
@@ -767,7 +768,7 @@ public static SessionState Load(string path)
767768
IHistoryTransform historyTransform;
768769
try
769770
{
770-
historyTransform = File.Exists(historyTransformFilepath) ?
771+
historyTransform = File.Exists(historyTransformFilepath) ?
771772
(JsonSerializer.Deserialize<IHistoryTransform>(File.ReadAllText(historyTransformFilepath))
772773
?? throw new ArgumentException("History transform file is invalid", nameof(path)))
773774
: new LLamaTransforms.DefaultHistoryTransform();

0 commit comments

Comments
 (0)