Skip to content

Commit 477c591

Browse files
committed
Implemented very first vesrion of hyde.
1 parent c2d316c commit 477c591

File tree

2 files changed

+122
-9
lines changed

2 files changed

+122
-9
lines changed

src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,28 @@ public async Task RunSample2()
2727
{
2828
var services = new ServiceCollection();
2929

30-
var apiKey = Dotenv.Get("COHERE_API_KEY")!;
31-
var cohereBaseUrl = Dotenv.Get("COHERE_BASE_API_KEY");
32-
if (string.IsNullOrEmpty(cohereBaseUrl))
30+
var cohereAzureBaseUrl = Dotenv.Get("COHERE_AZURE_BASE_URL");
31+
if (string.IsNullOrEmpty(cohereAzureBaseUrl))
3332
{
33+
var apiKey = Dotenv.Get("COHERE_API_KEY")!;
3434
services.ConfigureCohereChat(apiKey);
3535
}
3636
else
3737
{
38-
services.ConfigureCohereChat(apiKey, cohereBaseUrl);
38+
var azureApiKey = Dotenv.Get("COHERE_AZURE_API_KEY");
39+
services.ConfigureCohereChat(azureApiKey, cohereAzureBaseUrl);
3940
}
4041
//verify if rerank has a different api key (because the apikey point on azure ai studio)
41-
var rerankApiKey = Dotenv.Get("COHERE_RERANK_API_KEY");
42-
if (string.IsNullOrEmpty(rerankApiKey))
42+
var rerankAzureBaseUrl = Dotenv.Get("COHERE_AZURE_RERANK_BASE_URL");
43+
if (string.IsNullOrEmpty(rerankAzureBaseUrl))
4344
{
45+
var apiKey = Dotenv.Get("COHERE_API_KEY")!;
4446
services.ConfigureCohereRerank(apiKey);
4547
}
4648
else
4749
{
48-
services.ConfigureCohereRerank(rerankApiKey);
50+
var azureReRankApiKey = Dotenv.Get("COHERE_AZURE_RERANK_API_KEY");
51+
services.ConfigureCohereRerank(azureReRankApiKey, rerankAzureBaseUrl);
4952
}
5053

5154
services.AddHttpClient<RawCohereChatClient>()
@@ -89,12 +92,15 @@ public async Task RunSample2()
8992
.Title("Select query rewriter")
9093
.AddChoices(["Semantic Kernel Base", "Semantic Kernel Handlebar"]));
9194

95+
var useHyde = AnsiConsole.Confirm("Do you want to use HyDe? (y/n)", false);
96+
9297
var kernelBuider = CreateBasicKernelBuilder();
9398
var builder = CreateBasicKernelMemoryBuilder(
9499
services,
95100
storageToUse == "elasticsearch",
96101
queryExecutorToUse,
97-
queryRewriterTool == "Semantic Kernel Handlebar");
102+
queryRewriterTool == "Semantic Kernel Handlebar",
103+
useHyde);
98104
var kernelMemory = builder.Build<MemoryServerless>();
99105
var kernel = kernelBuider.Build();
100106

@@ -238,7 +244,8 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
238244
ServiceCollection services,
239245
bool useElasticSearch,
240246
string ragToolToUse,
241-
bool useHandlebarQueryRewriter)
247+
bool useHandlebarQueryRewriter,
248+
bool useHyde)
242249
{
243250
// we need a series of services to use Kernel Memory, the first one is
244251
// an embedding service that will be used to create dense vector for
@@ -306,6 +313,12 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
306313
services.AddSingleton<HandlebarSemanticKernelQueryRewriter>();
307314
services.AddSingleton<SemanticKernelQueryRewriter>();
308315
services.AddSingleton<StandardVectorSearchQueryHandler>();
316+
services.AddSingleton<HyDeQueryHandler>();
317+
var hydeConfig = new HiDeQueryHandlerConfiguration()
318+
{
319+
Prompt = "Given a question, generate a paragraph of text that answers the question in the context of computer security and IT security"
320+
};
321+
services.AddSingleton(hydeConfig);
309322
services.AddSingleton<KeywordSearchQueryHandler>();
310323

311324
var rewriterOptions = new SemanticKernelQueryRewriterOptions();
@@ -337,6 +350,11 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
337350
config.AddHandler<KeywordSearchQueryHandler>();
338351
}
339352

353+
if (useHyde)
354+
{
355+
config.AddHandler<HyDeQueryHandler>();
356+
}
357+
340358
if (ragToolToUse == "Cohere CommandR+")
341359
{
342360
config.AddHandler<CohereCommandRQueryExecutor>();
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
using DocumentFormat.OpenXml.Packaging;
2+
using Microsoft.Extensions.Logging;
3+
using Microsoft.KernelMemory.Diagnostics;
4+
using Microsoft.KernelMemory.MemoryStorage;
5+
using Microsoft.SemanticKernel;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using System.Threading;
9+
using System.Threading.Tasks;
10+
11+
namespace KernelMemory.Extensions;
12+
13+
public class HiDeQueryHandlerConfiguration
14+
{
15+
public string Prompt { get; set; } = "Given a question, generate a paragraph of text that answers the question.";
16+
}
17+
18+
public class HyDeQueryHandler : BasicQueryHandler
19+
{
20+
private readonly IMemoryDb _memoryDb;
21+
private readonly Kernel _kernel;
22+
private readonly HiDeQueryHandlerConfiguration _configuration;
23+
private readonly ILogger<StandardVectorSearchQueryHandler> _log;
24+
25+
public override string Name => "StandardVectorSearchQueryHandler";
26+
27+
public HyDeQueryHandler(
28+
IMemoryDb memory,
29+
Kernel kernel,
30+
HiDeQueryHandlerConfiguration? configuration,
31+
ILogger<StandardVectorSearchQueryHandler>? log = null)
32+
{
33+
_memoryDb = memory;
34+
_kernel = kernel;
35+
_configuration = configuration ?? new HiDeQueryHandlerConfiguration();
36+
_log = log ?? DefaultLogger<StandardVectorSearchQueryHandler>.Instance;
37+
}
38+
39+
/// <summary>
40+
/// Perform a vector search in default memory using the hyde principle.
41+
/// </summary>
42+
/// <param name="userQuestion"></param>
43+
/// <param name="cancellationToken"></param>
44+
/// <returns></returns>
45+
protected override async Task OnHandleAsync(UserQuestion userQuestion, CancellationToken cancellationToken)
46+
{
47+
// Perform a vector search in default memory
48+
StringBuilder prompt = new StringBuilder(_configuration.Prompt.Length + userQuestion.Question.Length + "Question: ".Length + "Paragraph: ".Length + 20);
49+
prompt.AppendLine(_configuration.Prompt);
50+
prompt.AppendLine("Question: " + userQuestion.Question);
51+
prompt.AppendLine("Paragraph: ");
52+
53+
var result = await _kernel.InvokePromptAsync(prompt.ToString(), cancellationToken: cancellationToken);
54+
55+
var paragraph = result.ToString();
56+
57+
var list = new List<(MemoryRecord memory, double relevance)>();
58+
59+
IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync(
60+
index: userQuestion.UserQueryOptions.Index,
61+
text: paragraph,
62+
filters: userQuestion.Filters,
63+
minRelevance: userQuestion.UserQueryOptions.MinRelevance,
64+
limit: userQuestion.UserQueryOptions.RetrievalQueryLimit,
65+
withEmbeddings: false,
66+
cancellationToken: cancellationToken);
67+
68+
// Memories are sorted by relevance, starting from the most relevant
69+
await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false))
70+
{
71+
list.Add((memory, relevance));
72+
}
73+
74+
var records = new List<MemoryRecord>();
75+
// Memories are sorted by relevance, starting from the most relevant
76+
foreach ((MemoryRecord memory, double relevance) in list)
77+
{
78+
var partitionText = memory.GetPartitionText(this._log).Trim();
79+
if (string.IsNullOrEmpty(partitionText))
80+
{
81+
this._log.LogError("The document partition is empty, doc: {0}", memory.Id);
82+
continue;
83+
}
84+
85+
if (relevance > float.MinValue)
86+
{
87+
this._log.LogTrace("Adding result with relevance {0}", relevance);
88+
records.Add(memory);
89+
}
90+
}
91+
92+
//ok now that you have all the memory record and citations, add to the object
93+
userQuestion.AddMemoryRecordSource("hyde-vector-search", records);
94+
}
95+
}

0 commit comments

Comments
 (0)