Issues with cancellation tokens in IChatCompletionService. #11146
-
Hi everyone, this is possibly quite a simple one, but it has me banging my head against the wall. I have this code:

using System.Collections.Concurrent;
using System.Runtime.CompilerServices;
using Interfaces;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
public class LLMChatCompletionsService(
    Kernel kernel) : ILLMChatCompletionsService
{
    private static readonly ConcurrentDictionary<string, ChatHistory> __chatHistories = new();
    private IChatCompletionService _chatService;
    private OpenAIPromptExecutionSettings _executionSettings;

    public OpenAIPromptExecutionSettings ExecutionSettings => _executionSettings ??= GetExecutionSettings();

    public IChatCompletionService ChatService =>
        _chatService ??= kernel.GetRequiredService<IChatCompletionService>();

    public ChatHistory this[string chatHistoryId]
    {
        get => __chatHistories.GetOrAdd(chatHistoryId, _ => []);
        set => __chatHistories[chatHistoryId] = value;
    }

    public ConcurrentDictionary<string, ChatHistory> ChatHistories => __chatHistories;

    public async Task<string> GetChatResponseAsync(string chatHistoryId, string userMessage,
        CancellationToken cancellationToken)
    {
        return await GetChatResponseAsync(GetOrCreateChatHistory(chatHistoryId), userMessage, cancellationToken);
    }

    public async Task<string> GetChatResponseAsync(ChatHistory chatHistory, string userMessage,
        CancellationToken cancellationToken)
    {
        chatHistory.AddUserMessage(userMessage);
        ChatMessageContent llmResponse =
            await ChatService.GetChatMessageContentAsync(chatHistory, ExecutionSettings, kernel, cancellationToken);
        chatHistory.Add(llmResponse);
        return chatHistory.Last().Content;
    }

    public async IAsyncEnumerable<string> GetChatStreamResponseAsync(string chatHistoryId, string userMessage,
        [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        ChatHistory chatHistory = GetOrCreateChatHistory(chatHistoryId);
        await foreach (string chunk in GetChatStreamResponseAsync(chatHistory, userMessage, cancellationToken))
        {
            yield return chunk;
        }
    }

    public async IAsyncEnumerable<string> GetChatStreamResponseAsync(ChatHistory chatHistory, string userMessage,
        [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        chatHistory.AddUserMessage(userMessage);
        await foreach (StreamingChatMessageContent streamingResponse in ChatService
            .GetStreamingChatMessageContentsAsync(chatHistory)
            .ConfigureAwait(false)
            .WithCancellation(cancellationToken))
        {
            if (cancellationToken.IsCancellationRequested)
            {
                yield break;
            }
            // Yield each chunk as it arrives.
            yield return streamingResponse.Content;
        }
    }

    public ChatHistory CreateNewChatHistory(string chatHistoryId, string systemPrompt)
    {
        ChatHistory newHistory = __chatHistories[chatHistoryId] = new(systemPrompt);
        return newHistory;
    }

    public ChatHistory GetOrCreateChatHistory(string chatHistoryId)
    {
        return __chatHistories.GetOrAdd(chatHistoryId, []);
    }

    private static OpenAIPromptExecutionSettings GetExecutionSettings()
    {
        return new()
        {
            //// The maximum number of tokens to generate in the response.
            MaxTokens = 20000,
            //// The temperature of the model. Must be between 0 and 1.
            //// Higher values mean the model will take more risks.
            Temperature = 0.5f
            //// The maximum number of tokens to generate in the response.
            //TopP = 1.0f,
            //// The maximum number of tokens to generate in the response.
            //FrequencyPenalty = 0.0f,
            //// The maximum number of tokens to generate in the response.
            //PresencePenalty = 0.0f
        };
    }
}

So nothing crazy, and it works OK. That is, until I cancel via the cancellation token, at which point I get a TaskCanceledException, as expected. The actual issue is that the exception is thrown on this line:

ChatService
    .GetStreamingChatMessageContentsAsync(chatHistory)
    .ConfigureAwait(false)
    .WithCancellation(cancellationToken)

and for the life of me I don't understand why I cannot handle it gracefully, or why the safeguard

if (cancellationToken.IsCancellationRequested)
{
    yield break;
}

is completely ignored. Am I doing something wrong here? Thank you in advance for your help & insights.
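For anyone who wants to reproduce the behaviour described above without an OpenAI endpoint, here is a minimal sketch. `CancellationRepro`, `FakeStreamAsync`, and `RunAsync` are hypothetical names invented for this example; `FakeStreamAsync` only stands in for the streaming call and is not part of Semantic Kernel. It assumes the real stream observes the token while it waits for the next chunk, which is what `Task.Delay` does here.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

public static class CancellationRepro
{
    // Hypothetical stand-in for the streaming chat call: it waits a little before
    // each chunk and observes the token while waiting.
    public static async IAsyncEnumerable<string> FakeStreamAsync(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        for (int i = 0; ; i++)
        {
            // With an already-cancelled token this throws TaskCanceledException.
            await Task.Delay(10, cancellationToken);
            yield return $"chunk {i}";
        }
    }

    // Reports how many chunks were consumed, whether the in-loop guard ever saw
    // the cancellation, and whether a cancellation exception escaped the loop.
    public static async Task<(int Chunks, bool GuardHit, bool Threw)> RunAsync()
    {
        using var cts = new CancellationTokenSource();
        int chunks = 0;
        bool guardHit = false;
        try
        {
            await foreach (string chunk in FakeStreamAsync().WithCancellation(cts.Token))
            {
                if (cts.Token.IsCancellationRequested)
                {
                    guardHit = true;
                    break;
                }

                chunks++;
                if (chunks == 3)
                {
                    cts.Cancel(); // simulate the caller cancelling mid-stream
                }
            }
        }
        catch (OperationCanceledException)
        {
            return (chunks, guardHit, true);
        }

        return (chunks, guardHit, false);
    }
}
```

In this sketch `RunAsync` returns `(3, false, true)`: after the third chunk the token is cancelled, the next `MoveNextAsync` throws while fetching chunk four, and the loop body (including the guard) never runs again.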
Replies: 2 comments 3 replies
-
@SergeyMenshykh can you take a look at this one?
-
Hi @MenaceFOP3, I think the reason the `TaskCanceledException` is thrown by the `WithCancellation` method and the expected code is not executed is that the enumerator created by the `WithCancellation` method detects that cancellation was requested well before your code does and throws the exception. As a result, because of the exception, the execution control leaves the `await foreach (StreamingChatMessageContent streamingResponse in ChatService)` loop and never reaches the `if (cancellationToken.IsCancellationRequested)` line.

If, for your scenario, you need to detect cancellation via `cancellationToken.IsCancellationRequested` rather than a cancellation exception, consider this alternative implementation:

public async IAsyncEnumerable<string> GetChatStreamResponseAsync(ChatHistory chatHistory, string userMessage, [EnumeratorCancellation] CancellationToken cancellationToken)
{
    chatHistory.AddUserMessage(userMessage);
    await using var enumerator = ChatService.GetStreamingChatMessageContentsAsync(chatHistory).GetAsyncEnumerator();
    while (await enumerator.MoveNextAsync())
    {
        if (cancellationToken.IsCancellationRequested)
        {
            yield break;
        }
        yield return enumerator.Current.Content;
    }
}
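A possible variation on the idea above, in case the underlying stream should still receive the token: since `yield return` cannot sit inside a `try` block that has a `catch`, one option is to wrap only `MoveNextAsync` in the `try` and yield outside it. This is a sketch under assumptions, not the library's recommended pattern; `GracefulStream`, `FakeStreamAsync`, and `StreamUntilCancelledAsync` are hypothetical names, and `FakeStreamAsync` merely stands in for the real streaming call.

```csharp
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

public static class GracefulStream
{
    // Hypothetical stand-in for the streaming chat call; it observes the token
    // while waiting for the next chunk.
    public static async IAsyncEnumerable<string> FakeStreamAsync(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        for (int i = 0; ; i++)
        {
            await Task.Delay(10, cancellationToken);
            yield return $"chunk {i}";
        }
    }

    // Ends the sequence quietly when the caller's token is cancelled, instead of
    // letting the cancellation exception reach the consumer.
    public static async IAsyncEnumerable<string> StreamUntilCancelledAsync(
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // The token is flowed to the source so the pending wait is really aborted.
        await using IAsyncEnumerator<string> enumerator =
            FakeStreamAsync().GetAsyncEnumerator(cancellationToken);

        while (true)
        {
            bool hasNext;
            try
            {
                hasNext = await enumerator.MoveNextAsync();
            }
            catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
            {
                // Treat our own cancellation as a normal end of the stream.
                hasNext = false;
            }

            if (!hasNext)
            {
                yield break;
            }

            // The yield lives outside the try/catch, which the compiler requires.
            yield return enumerator.Current;
        }
    }
}
```

With this shape, a consumer's `await foreach` over `StreamUntilCancelledAsync(token)` simply ends after cancellation. Whether swallowing the exception is desirable depends on the caller: it can no longer tell a cancelled stream from a completed one unless it checks the token itself.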