Commit d79948e

Almost completed OpenAI RAG query executor

1 parent cb4ecfe commit d79948e

File tree

4 files changed: +292 -6 lines changed

src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs

Lines changed: 31 additions & 5 deletions

@@ -80,7 +80,10 @@ public async Task RunSample2()
         var queryExecutorToUse = AnsiConsole.Prompt(new SelectionPrompt<string>()
             .Title("Select the query executor to use")
-            .AddChoices(["KernelMemory Default", "Cohere CommandR+"]));
+            .AddChoices([
+                "KernelMemory Default",
+                "Cohere CommandR+",
+                "OpenAI Tool"]));

         var queryRewriterTool = AnsiConsole.Prompt(new SelectionPrompt<string>()
             .Title("Select query rewriter")

@@ -90,7 +93,7 @@ public async Task RunSample2()
         var builder = CreateBasicKernelMemoryBuilder(
             services,
             storageToUse == "elasticsearch",
-            queryExecutorToUse == "Cohere CommandR+",
+            queryExecutorToUse,
             queryRewriterTool == "Semantic Kernel Handlebar");
         var kernelMemory = builder.Build<MemoryServerless>();
         var kernel = kernelBuider.Build();

@@ -136,6 +139,7 @@ public async Task RunSample2()
         var questionEnumerator = questionPipeline.ExecuteQueryAsync(userQuestion);

         Console.WriteLine("\nAnswerStream:\n");
+        int segments = 0;
         await foreach (var step in questionEnumerator)
         {
             if (shouldDumpRewrittenQuery)

@@ -146,9 +150,16 @@ public async Task RunSample2()
             if (step.Type == UserQuestionProgressType.AnswerPart)
             {
                 Console.Write(step.Text);
+                segments++;
             }
         }

+        //ok, did we really get a streaming result?
+        if (segments == 0)
+        {
+            //we have no streaming, so we need to write the whole answer
+            Console.Write(userQuestion.Answer);
+        }
         Console.WriteLine("\n\n");

         //ok we can validate the answer if requested

@@ -176,7 +187,7 @@ public async Task RunSample2()

     private static async Task ManageIndexingOfDocuments(MemoryServerless kernelMemory)
     {
-        var indexDocument = AnsiConsole.Confirm("Do you want to index documents? (y/n)", true);
+        var indexDocument = AnsiConsole.Confirm("Do you want to index documents? (y/n)", false);
         if (indexDocument)
         {
             var singleDocumentIdex = AnsiConsole.Confirm("Do you want to index a single document? (y/n)", true);

@@ -226,7 +237,7 @@ private static async Task IndexDocument(MemoryServerless kernelMemory, string do
     private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
         ServiceCollection services,
         bool useElasticSearch,
-        bool useCohereCommandRPlusForQueryExecutor,
+        string ragToolToUse,
         bool useHandlebarQueryRewriter)
     {
         // we need a series of services to use Kernel Memory, the first one is

@@ -305,6 +316,17 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
         services.AddSingleton<CohereCommandRQueryExecutor>();
         services.AddSingleton<StandardRagQueryExecutor>();

+        //register the OpenAI RAG component
+        var openaiRagQueryExecutorConfiguration = new OpenAIRagQueryExecutorConfiguration()
+        {
+            MaxTokens = 8000,
+            Temperature = 0.0,
+            ModelId = "gpt4o",
+            ModelName = "gpt-4o" //important: this determines the tokenizer
+        };
+        services.AddSingleton(openaiRagQueryExecutorConfiguration);
+        services.AddSingleton<OpenaiRagQueryExecutor>();
+
         //now create the pipeline
         services.AddKernelMemoryUserQuestionPipeline(config =>
         {

@@ -315,10 +337,14 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
                 config.AddHandler<KeywordSearchQueryHandler>();
             }

-            if (useCohereCommandRPlusForQueryExecutor)
+            if (ragToolToUse == "Cohere CommandR+")
             {
                 config.AddHandler<CohereCommandRQueryExecutor>();
             }
+            else if (ragToolToUse == "OpenAI Tool")
+            {
+                config.AddHandler<OpenaiRagQueryExecutor>();
+            }
             else
             {
                 config.AddHandler<StandardRagQueryExecutor>();
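
The int segments counter above exists because tool-call based executors such as the new OpenaiRagQueryExecutor produce the whole answer in a single response instead of streaming AnswerPart segments. A minimal sketch of that fallback pattern in isolation, assuming the same pipeline types used in the sample (PrintAnswerAsync and the UserQuestionProgress element type are hypothetical names introduced here for illustration; the sample above inlines this logic):

// Stream answer segments if the executor emits them; otherwise fall back
// to the fully materialized answer stored on the UserQuestion.
static async Task PrintAnswerAsync(
    IAsyncEnumerable<UserQuestionProgress> steps,
    UserQuestion userQuestion)
{
    int segments = 0;
    await foreach (var step in steps)
    {
        if (step.Type == UserQuestionProgressType.AnswerPart)
        {
            Console.Write(step.Text);
            segments++;
        }
    }

    // No streamed segments arrived: the executor answered in one shot.
    if (segments == 0)
    {
        Console.Write(userQuestion.Answer);
    }
}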

src/KernelMemory.Extensions/QueryPipeline/MemoryRecordHelper.cs

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ namespace KernelMemory.Extensions.QueryPipeline
 {
     internal static class MemoryRecordHelper
     {
-        internal static IReadOnlyCollection<Citation>? BuildCitations(
+        internal static IReadOnlyCollection<Citation> BuildCitations(
             List<MemoryRecord> usedMemoryRecord,
             string index,
             ILogger logger)
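
Dropping the nullable annotation simplifies every caller: BuildCitations now always returns a collection, presumably empty when nothing can be cited, so null checks and null-forgiving operators disappear. A caller-side sketch under that assumption (records, index and log are placeholder variables):

// Before: IReadOnlyCollection<Citation>? forced defensive code such as
// var count = MemoryRecordHelper.BuildCitations(records, index, log)?.Count ?? 0;

// After: the result is always a collection, possibly empty.
userQuestion.Citations = MemoryRecordHelper.BuildCitations(records, index, log);
if (userQuestion.Citations.Count == 0)
{
    // e.g. grounding produced no citations, so drop the answer
    userQuestion.Answer = null;
}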
src/KernelMemory.Extensions/QueryPipeline/OpenaiRagQueryExecutor.cs

Lines changed: 255 additions & 0 deletions

@@ -0,0 +1,255 @@
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryStorage;
using Microsoft.ML.Tokenizers;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;

namespace KernelMemory.Extensions.QueryPipeline;

public class OpenAIRagQueryExecutorConfiguration
{
    /// <summary>
    /// This is the name of the OpenAI model used to
    /// create the correct tokenizer. If you keep the
    /// default <see cref="ModelId"/>, a standard
    /// gpt-3.5 tokenizer will be used.
    /// </summary>
    public string ModelName { get; set; } = "gpt35";

    /// <summary>
    /// This is the modelId configured in Semantic Kernel
    /// that selects the correct model. If it is null
    /// the default model will be used.
    /// </summary>
    public string? ModelId { get; set; }

    /// <summary>
    /// Max tokens in the request; if not specified, 3000 tokens
    /// will be used.
    /// </summary>
    public int MaxTokens { get; set; } = 3000;

    /// <summary>
    /// Temperature of the request.
    /// </summary>
    public double Temperature { get; set; }

    /// <summary>
    /// If GPT returns no citations we can remove the answer.
    /// </summary>
    public bool RemoveAnswerIfNoCitations { get; set; } = false;
}

/// <summary>
/// Executes the query part: it starts with a predefined prompt,
/// adds all the retrieved memories to the prompt as facts, and then
/// uses the LLM to answer the user query with those facts (grounding).
/// </summary>
public class OpenaiRagQueryExecutor : BasicQueryHandler
{
    public override string Name => "OpenaiRagQueryExecutor";

    private readonly Kernel _kernel;
    private readonly OpenAIRagQueryExecutorConfiguration _config;
    private readonly Tokenizer _tokenizer;
    private readonly ILogger<StandardRagQueryExecutor> _log;

    public OpenaiRagQueryExecutor(
        Kernel kernel,
        OpenAIRagQueryExecutorConfiguration? config = null,
        ILogger<StandardRagQueryExecutor>? log = null)
    {
        _kernel = kernel;
        _config = config ?? new OpenAIRagQueryExecutorConfiguration();
        _tokenizer = Tiktoken.CreateTiktokenForModel(_config.ModelName);
        _log = log ?? DefaultLogger<StandardRagQueryExecutor>.Instance;
    }

    protected override async Task OnHandleAsync(
        UserQuestion userQuestion,
        CancellationToken cancellationToken)
    {
        var memoryRecords = await userQuestion.GetMemoryOrdered();
        if (memoryRecords.Count == 0)
        {
            //We have no memories, we can simply return.
            return;
        }

        //This code is taken and modified from the original KernelMemory codebase.
        //Create a base of facts in a StringBuilder.
        var facts = new StringBuilder();

        //We need to stop adding facts when we reach the max number of tokens.
        var tokensAvailable = _config.MaxTokens - _tokenizer.CountTokens(userQuestion.Question);

        //TODO: Add the preamble of the prompt to the token count.

        //Some statistics to tell how many facts we have available and how many we used.
        int factsAvailableCount = memoryRecords.Count;
        int factsUsedCount = 0;

        //We need the list of all memory records used, because we will need them
        //to build citations.
        List<MemoryRecord> memoryRecordToUse = new();
        int docNumber = 0;
        foreach (var mr in memoryRecords)
        {
            var partitionText = mr.GetPartitionText();

            var size = _tokenizer.CountTokens(partitionText);
            if (size >= tokensAvailable)
            {
                // Stop after reaching the max number of tokens
                break;
            }

            factsUsedCount++;

            //Create a special format for the fact.
            var fact = $"---\nDocument {++docNumber}:\n{partitionText}\n";

            facts.Append(fact);
            memoryRecordToUse.Add(mr);
            tokensAvailable -= size;
        }

        if (factsAvailableCount > 0 && factsUsedCount == 0)
        {
            _log.LogError("Unable to inject memories in the prompt, not enough tokens available");
            return;
        }

        if (factsUsedCount == 0)
        {
            _log.LogWarning("No memories available");
            return;
        }
        var watch = new Stopwatch();
        watch.Restart();

        var openaiAnswer = await GenerateAnswerAsync(userQuestion.Question, facts.ToString(), cancellationToken);

        if (openaiAnswer == null)
        {
            //No answer is possible, so we let the question flow to another handler.
            return;
        }

        //Now we want to add answer and citations.
        watch.Stop();
        _log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds);

        userQuestion.Answer = openaiAnswer.Answer;

        List<MemoryRecord> usedMemoryRecord = memoryRecordToUse
            .Where((_, i) => openaiAnswer.Documents.Contains(i))
            .ToList();

        // Now we need to clean up the citations, including only the ones used to answer the question.
        userQuestion.Citations = MemoryRecordHelper.BuildCitations(usedMemoryRecord, userQuestion.UserQueryOptions.Index, this._log);

        // Ground if needed.
        if (_config.RemoveAnswerIfNoCitations && userQuestion.Citations!.Count == 0)
        {
            //No answer is possible, because we have no citations.
            userQuestion.Answer = null;
        }
    }

    private class GptAnswer
    {
        public string Answer { get; set; }

        public HashSet<int> Documents { get; set; }
    }

    /// <summary>
    /// Performs the call to the OpenAI API to get the answer to the question.
    /// It uses a tool call to coerce the model to return not only the answer
    /// text but also the array of documents used to answer the query.
    /// </summary>
    /// <param name="question"></param>
    /// <param name="documents"></param>
    /// <returns></returns>
    private async Task<GptAnswer?> GenerateAnswerAsync(
        string question,
        string documents,
        CancellationToken token)
    {
        //First step is creating the function in Semantic Kernel.
        var function = KernelFunctionFactory.CreateFromMethod(
            [Description("Return the result to the user")] (
                [Description("Answer of the question")] string answer,
                [Description("Documents used to formulate the answer")] int[] documents
            ) =>
            {
            }, "return_result");
        var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function]);
        var openAIFunction = plugin.GetFunctionsMetadata().First().ToOpenAIFunction();

        // Create a template for chat with settings
        var chat = _kernel.CreateFunctionFromPrompt(new PromptTemplateConfig()
        {
            Name = "Rag",
            Description = "Answer user question with documents.",
            Template = @"You are an AI assistant that helps users answer questions given a specific context. You will be given a context and asked a question based on that context. Your answer should be as precise as possible and should only come from the context.
Please add all documents used as citations.
Question: {{$question}}

Documents:
{{$documents}}",
            TemplateFormat = "semantic-kernel",
            InputVariables =
            [
                new() { Name = "question", Description = "Question of the user.", IsRequired = true },
                new() { Name = "documents", Description = "Documents needed to answer the query.", IsRequired = true }
            ],
            ExecutionSettings =
            {
                { "default", new OpenAIPromptExecutionSettings()
                    {
                        MaxTokens = 1000,
                        Temperature = 0,
                        ModelId = _config.ModelId,
                        ChatSystemPrompt = "You will answer question of the user using only documents in the prompt",
                        ToolCallBehavior = ToolCallBehavior.RequireFunction(openAIFunction, false),
                    }
                },
            }
        });

        KernelArguments ka = new();
        ka["question"] = question;
        ka["documents"] = documents;
        var result = await _kernel.InvokeAsync(chat, ka, token);

        var openaiMessageContent = result.GetValue<OpenAIChatMessageContent>();
        if (openaiMessageContent is not null)
        {
            var toolCall = openaiMessageContent.GetOpenAIFunctionToolCalls().Single();
            var answer = ((JsonElement)toolCall.Arguments["answer"]).GetString()!;

            //-1 because GPT numbers documents starting from 1.
            var citations = ((JsonElement)toolCall.Arguments["documents"]).EnumerateArray().Select(e => e.GetInt32() - 1).ToHashSet();
            return new GptAnswer()
            {
                Answer = answer,
                Documents = citations
            };
        }

        return null;
    }
}
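
The heart of the new executor is GenerateAnswerAsync: rather than parsing free text, it declares a no-op kernel function whose signature is the desired output schema and forces the model to call it. A condensed sketch of just that round trip, using the same Semantic Kernel connector APIs as the file above (kernel is an already-configured Kernel and prompt an already-rendered prompt string; both are assumptions of this sketch):

// A no-op kernel function acts purely as a JSON schema the model must fill.
var fn = KernelFunctionFactory.CreateFromMethod(
    [Description("Return the result to the user")] (
        [Description("Answer of the question")] string answer,
        [Description("Documents used to formulate the answer")] int[] documents
    ) => { },
    "return_result");

var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [fn]);
var tool = plugin.GetFunctionsMetadata().First().ToOpenAIFunction();

var settings = new OpenAIPromptExecutionSettings
{
    // Force the model to call return_result; auto-invoke stays off so the
    // raw tool-call arguments remain readable on the returned message.
    ToolCallBehavior = ToolCallBehavior.RequireFunction(tool, false),
};

var result = await kernel.InvokePromptAsync(prompt, new KernelArguments(settings));
var message = result.GetValue<OpenAIChatMessageContent>();
var call = message!.GetOpenAIFunctionToolCalls().Single();

// Arguments arrive as JsonElement values keyed by parameter name.
var answer = ((JsonElement)call.Arguments!["answer"]).GetString();
var docs = ((JsonElement)call.Arguments["documents"])
    .EnumerateArray().Select(e => e.GetInt32()).ToArray();

Keeping auto-invoke disabled is what makes the pattern work: if Semantic Kernel invoked the function itself, the arguments would be consumed by the no-op body instead of surfacing to the caller.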

src/KernelMemory.Extensions/QueryPipeline/StandardRagQueryExecutor.cs

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,11 @@

 namespace KernelMemory.Extensions
 {
+    /// <summary>
+    /// Executes the query part: it starts with a predefined prompt,
+    /// adds all the retrieved memories to the prompt as facts, and then
+    /// uses the LLM to answer the user query with those facts (grounding).
+    /// </summary>
     public class StandardRagQueryExecutor : BasicAsyncQueryHandlerWithProgress
     {
         public override string Name => "StandardRagQueryExecutor";
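
Both this executor and the new OpenAI one pack the retrieved memories into the prompt under a token budget before calling the LLM. A minimal sketch of that packing loop, reusing the same Tiktoken helper the new executor constructs (question and partitions are placeholder inputs, assumed ordered by relevance):

// Pack retrieved partitions into a facts block under a token budget,
// mirroring the loop in OpenaiRagQueryExecutor.OnHandleAsync.
Tokenizer tokenizer = Tiktoken.CreateTiktokenForModel("gpt-4o");
int tokensAvailable = 8000 - tokenizer.CountTokens(question);

var facts = new StringBuilder();
int docNumber = 0;
foreach (string partition in partitions)
{
    int size = tokenizer.CountTokens(partition);
    if (size >= tokensAvailable)
    {
        break; // the next fact no longer fits, stop here
    }

    facts.Append($"---\nDocument {++docNumber}:\n{partition}\n");
    tokensAvailable -= size;
}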
