using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryStorage;
using Microsoft.ML.Tokenizers;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

namespace KernelMemory.Extensions.QueryPipeline;

public class OpenAIRagQueryExecutorConfiguration
{
    /// <summary>
    /// Name of the OpenAI model, used to create the correct tokenizer.
    /// If you keep the default value, the tokenizer of a standard
    /// gpt-3.5 model is used.
    /// </summary>
    public string ModelName { get; set; } = "gpt-3.5-turbo";

    /// <summary>
    /// Model id configured in Semantic Kernel, used to select the
    /// correct model. If it is null, the default model is used.
    /// </summary>
    public string? ModelId { get; set; }

    /// <summary>
    /// Maximum number of tokens available for the question plus the
    /// facts injected in the request; if not specified, 3000 tokens
    /// will be used.
    /// </summary>
    public int MaxTokens { get; set; } = 3000;

    /// <summary>
    /// Temperature of the request.
    /// </summary>
    public double Temperature { get; set; }

    /// <summary>
    /// If the model returns no citations, we can remove the answer.
    /// </summary>
    public bool RemoveAnswerIfNoCitations { get; set; } = false;
}
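
/// <summary>
/// Illustrative sketch, not part of the original code: shows how the options above
/// combine. The "gpt4" ModelId is a placeholder and must match a chat completion
/// service configured in the Semantic Kernel instance.
/// </summary>
internal static class OpenAIRagQueryExecutorConfigurationExample
{
    internal static OpenAIRagQueryExecutorConfiguration CreateSample() => new()
    {
        ModelName = "gpt-4",               //drives tokenizer selection for the fact budget
        ModelId = "gpt4",                  //Semantic Kernel model id (placeholder)
        MaxTokens = 4000,                  //token budget for question + facts
        Temperature = 0,                   //deterministic answers
        RemoveAnswerIfNoCitations = true,  //drop answers the model cannot ground
    };
}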

/// <summary>
/// Executes the query part: it starts with a predefined prompt, adds all the
/// retrieved memories to the prompt as facts, then uses the LLM to answer the
/// user query with those facts (grounding).
/// </summary>
public class OpenaiRagQueryExecutor : BasicQueryHandler
{
    public override string Name => "OpenaiRagQueryExecutor";

    private readonly Kernel _kernel;
    private readonly OpenAIRagQueryExecutorConfiguration _config;
    private readonly Tokenizer _tokenizer;
    private readonly ILogger<OpenaiRagQueryExecutor> _log;

    public OpenaiRagQueryExecutor(
        Kernel kernel,
        OpenAIRagQueryExecutorConfiguration? config = null,
        ILogger<OpenaiRagQueryExecutor>? log = null)
    {
        _kernel = kernel;
        _config = config ?? new OpenAIRagQueryExecutorConfiguration();
        _tokenizer = Tiktoken.CreateTiktokenForModel(_config.ModelName);
        _log = log ?? DefaultLogger<OpenaiRagQueryExecutor>.Instance;
    }

    protected override async Task OnHandleAsync(
        UserQuestion userQuestion,
        CancellationToken cancellationToken)
    {
        var memoryRecords = await userQuestion.GetMemoryOrdered();
        if (memoryRecords.Count == 0)
        {
            //We have no memories: we can simply return.
            return;
        }

        //This code is taken and modified from the original KernelMemory codebase.
        //Create a base of facts in a StringBuilder.
        var facts = new StringBuilder();

        //We need to stop adding facts once we reach the max number of tokens.
        var tokensAvailable = _config.MaxTokens - _tokenizer.CountTokens(userQuestion.Question);

        //TODO: Add the preamble of the prompt to the token count.

        //Some statistics to tell how many facts were available and how many we used.
        int factsAvailableCount = memoryRecords.Count;
        int factsUsedCount = 0;

        //We need to keep the list of all memory records actually used, because
        //we will need them to build citations.
        List<MemoryRecord> memoryRecordToUse = new();
        int docNumber = 0;
        foreach (var mr in memoryRecords)
        {
            var partitionText = mr.GetPartitionText();

            var size = _tokenizer.CountTokens(partitionText);
            if (size >= tokensAvailable)
            {
                //Stop after reaching the max number of tokens.
                break;
            }

            factsUsedCount++;

            //Create a special format for the fact.
            var fact = $"---\nDocument {++docNumber}:\n{partitionText}\n";

            facts.Append(fact);
            memoryRecordToUse.Add(mr);
            tokensAvailable -= size;
        }
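
        //At this point the facts buffer has the shape below (illustrative example,
        //derived from the format string used in the loop above):
        //---
        //Document 1:
        //<text of the first partition>
        //---
        //Document 2:
        //<text of the second partition>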

        if (factsAvailableCount > 0 && factsUsedCount == 0)
        {
            //We had memories available, but none of them fit in the token budget.
            //The case of no memories at all was already handled by the early return above.
            _log.LogError("Unable to inject memories in the prompt, not enough tokens available");
            return;
        }

        var watch = Stopwatch.StartNew();

        var openaiAnswer = await GenerateAnswerAsync(userQuestion.Question, facts.ToString(), cancellationToken);

        if (openaiAnswer == null)
        {
            //No answer is possible, so we let the question flow to another handler.
            return;
        }

        //Ok, now we want to add answer and citations.
        watch.Stop();
        _log.LogTrace("Answer generated in {ElapsedMilliseconds} msecs", watch.ElapsedMilliseconds);

        userQuestion.Answer = openaiAnswer.Answer;

        List<MemoryRecord> usedMemoryRecord = memoryRecordToUse
            .Where((_, i) => openaiAnswer.Documents.Contains(i))
            .ToList();

        //Now we need to clean up the citations, including only the ones used to answer the question.
        userQuestion.Citations = MemoryRecordHelper.BuildCitations(usedMemoryRecord, userQuestion.UserQueryOptions.Index, this._log);

        //Ground if needed.
        if (_config.RemoveAnswerIfNoCitations && userQuestion.Citations!.Count == 0)
        {
            //No answer is possible, because we do not have citations.
            userQuestion.Answer = null;
        }
    }

    private class GptAnswer
    {
        public string Answer { get; set; } = string.Empty;

        public HashSet<int> Documents { get; set; } = new();
    }
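
    //Illustrative example, not a real response: the tool call arguments parsed by
    //GenerateAnswerAsync into GptAnswer look roughly like
    //{ "answer": "The refund window is 30 days.", "documents": [1, 3] }
    //where the document numbers are 1-based and match the "Document N" labels in the facts.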

    /// <summary>
    /// Performs the call to the OpenAI API to get the answer to the question.
    /// It uses a tool call to coerce the model into returning not only the answer
    /// text but also the array of documents used to answer the query.
    /// </summary>
    /// <param name="question">Question of the user.</param>
    /// <param name="documents">Facts, already formatted, used to ground the answer.</param>
    /// <param name="token">Cancellation token.</param>
    /// <returns>The parsed answer, or null if the model did not perform the expected tool call.</returns>
    private async Task<GptAnswer?> GenerateAnswerAsync(
        string question,
        string documents,
        CancellationToken token)
    {
        //First step is creating the function in Semantic Kernel. The body is
        //intentionally empty: the function only describes the schema of the tool call.
        var function = KernelFunctionFactory.CreateFromMethod(
            [Description("Return the result to the user")] (
                [Description("Answer of the question")] string answer,
                [Description("Documents used to formulate the answer")] int[] documents
            ) =>
            {
            }, "return_result");
        var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function]);
        var openAIFunction = plugin.GetFunctionsMetadata().First().ToOpenAIFunction();
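
        //For reference, the function above is exposed to OpenAI as a tool roughly
        //equivalent to the following (illustrative; the exact schema depends on the
        //connector version):
        //{
        //  "name": "MyPlugin-return_result",
        //  "description": "Return the result to the user",
        //  "parameters": {
        //    "type": "object",
        //    "properties": {
        //      "answer": { "type": "string", "description": "Answer of the question" },
        //      "documents": { "type": "array", "items": { "type": "integer" } }
        //    },
        //    "required": ["answer", "documents"]
        //  }
        //}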

        //Create a template for chat with settings.
        var chat = _kernel.CreateFunctionFromPrompt(new PromptTemplateConfig()
        {
            Name = "Rag",
            Description = "Answer the user question with documents.",
            Template = @"You are an AI assistant that helps users answer questions given a specific context. You will be given a context and asked a question based on that context. Your answer should be as precise as possible and should only come from the context.
Please add all documents used as citations.
Question: {{$question}}

Documents:
{{$documents}}",
            TemplateFormat = "semantic-kernel",
            InputVariables =
            [
                new() { Name = "question", Description = "Question of the user.", IsRequired = true },
                new() { Name = "documents", Description = "Documents needed to answer the query.", IsRequired = true }
            ],
            ExecutionSettings =
            {
                { "default", new OpenAIPromptExecutionSettings()
                    {
                        MaxTokens = 1000,
                        Temperature = _config.Temperature,
                        ModelId = _config.ModelId,
                        ChatSystemPrompt = "You will answer the question of the user using only the documents in the prompt",
                        ToolCallBehavior = ToolCallBehavior.RequireFunction(openAIFunction, false),
                    }
                },
            }
        });

        KernelArguments ka = new()
        {
            ["question"] = question,
            ["documents"] = documents
        };
        var result = await _kernel.InvokeAsync(chat, ka, token);
| 236 | + var result = await _kernel.InvokeAsync(chat, ka, token); |
| 237 | + |
| 238 | + var openaiMessageContent = result.GetValue<OpenAIChatMessageContent>(); |
| 239 | + if (result is FunctionResult fre) |
| 240 | + { |
| 241 | + var toolCall = openaiMessageContent.GetOpenAIFunctionToolCalls().Single(); |
| 242 | + var answer = ((JsonElement)toolCall.Arguments["answer"]).GetString()!; |
| 243 | + |
| 244 | + //-1 is because GPT is 1 based with document. |
| 245 | + var citations = ((JsonElement)toolCall.Arguments["documents"]).EnumerateArray().Select(e => e.GetInt32() - 1).ToHashSet(); |
| 246 | + return new GptAnswer() |
| 247 | + { |
| 248 | + Answer = answer, |
| 249 | + Documents = citations |
| 250 | + }; |
| 251 | + } |
| 252 | + |
| 253 | + return null; |
| 254 | + } |
| 255 | +} |
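
/// <summary>
/// Hypothetical usage sketch, not part of the original code: HandleAsync is assumed
/// to be the public entry point that BasicQueryHandler exposes around OnHandleAsync,
/// and a retrieval handler is assumed to have populated the UserQuestion with memory
/// records before this executor runs.
/// </summary>
internal static class OpenaiRagQueryExecutorUsageExample
{
    internal static async Task<string?> AskAsync(Kernel kernel, UserQuestion userQuestion)
    {
        var executor = new OpenaiRagQueryExecutor(
            kernel,
            new OpenAIRagQueryExecutorConfiguration
            {
                ModelName = "gpt-4",              //placeholder model name
                RemoveAnswerIfNoCitations = true, //drop ungrounded answers
            });

        //Assumed entry point; see the summary above.
        await executor.HandleAsync(userQuestion, CancellationToken.None);
        return userQuestion.Answer;
    }
}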