Skip to content

Commit 3643e92

Browse files
authored
Add knowledgebase and conversation APIs (Azure#46893)
1 parent 42e72d7 commit 3643e92

File tree

7 files changed

+351
-0
lines changed

7 files changed

+351
-0
lines changed

eng/Packages.Data.props

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
<PackageReference Update="Azure.Provisioning.KeyVault" Version="1.0.0" />
158158
<PackageReference Update="Azure.Provisioning.ServiceBus" Version="1.0.0" />
159159
<PackageReference Update="Azure.Provisioning.Storage" Version="1.0.0" />
160+
<PackageReference Update="Microsoft.Bcl.Numerics" Version="8.0.0" />
160161

161162
<!-- Other approved packages -->
162163
<PackageReference Update="Microsoft.Azure.Amqp" Version="2.6.7" />

sdk/provisioning/Azure.Provisioning.CloudMachine/api/Azure.Provisioning.CloudMachine.netstandard2.0.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,21 @@ namespace Azure.Provisioning.CloudMachine.OpenAI
138138
{
139139
public static partial class AzureOpenAIExtensions
140140
{
141+
public static Azure.Provisioning.CloudMachine.OpenAI.EmbeddingKnowledgebase CreateEmbeddingKnowledgebase(this Azure.Core.ClientWorkspace workspace) { throw null; }
142+
public static Azure.Provisioning.CloudMachine.OpenAI.OpenAIConversation CreateOpenAIConversation(this Azure.Core.ClientWorkspace workspace) { throw null; }
141143
public static OpenAI.Chat.ChatClient GetOpenAIChatClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
142144
public static OpenAI.Embeddings.EmbeddingClient GetOpenAIEmbeddingsClient(this Azure.Core.ClientWorkspace workspace) { throw null; }
143145
}
146+
public partial class EmbeddingKnowledgebase
147+
{
148+
internal EmbeddingKnowledgebase() { }
149+
public void Add(string fact) { }
150+
}
151+
public partial class OpenAIConversation
152+
{
153+
internal OpenAIConversation() { }
154+
public string Say(string message) { throw null; }
155+
}
144156
public partial class OpenAIFeature : Azure.Provisioning.CloudMachine.CloudMachineFeature
145157
{
146158
public OpenAIFeature(AiModel chatDeployment, AiModel? embeddingsDeployment = null) { }

sdk/provisioning/Azure.Provisioning.CloudMachine/src/Azure.Provisioning.CloudMachine.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
<PackageReference Include="Azure.Provisioning.EventGrid" />
2424
<PackageReference Include="Azure.Security.KeyVault.Secrets" />
2525
<PackageReference Include="Microsoft.Extensions.Configuration.Abstractions" VersionOverride="8.0.0" />
26+
<PackageReference Include="Microsoft.Bcl.Numerics" />
2627
</ItemGroup>
2728

2829
</Project>
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using OpenAI.Embeddings;
7+
8+
namespace Azure.Provisioning.CloudMachine.OpenAI;
9+
10+
/// <summary>
11+
/// Represents a knowledgebase of facts represented by embeddings that can be used to find relevant facts based on a given text.
12+
/// </summary>
13+
public class EmbeddingKnowledgebase
14+
{
15+
private EmbeddingClient _client;
16+
private List<string> _factsToProcess = new List<string>();
17+
18+
private List<ReadOnlyMemory<float>> _vectors = new List<ReadOnlyMemory<float>>();
19+
private List<string> _facts = new List<string>();
20+
21+
internal EmbeddingKnowledgebase(EmbeddingClient client)
22+
{
23+
_client = client;
24+
}
25+
26+
/// <summary>
27+
/// Add a fact to the knowledgebase.
28+
/// </summary>
29+
/// <param name="fact">The fact to add.</param>
30+
public void Add(string fact)
31+
{
32+
ChunkAndAddToFactsToProcess(fact, 1000);
33+
ProcessUnprocessedFacts();
34+
}
35+
36+
internal List<Fact> FindRelevantFacts(string text, float threshold = 0.29f, int top = 3)
37+
{
38+
if (_factsToProcess.Count > 0)
39+
ProcessUnprocessedFacts();
40+
41+
ReadOnlySpan<float> textVector = ProcessFact(text).Span;
42+
43+
var results = new List<Fact>();
44+
var distances = new List<(float Distance, int Index)>();
45+
for (int index = 0; index < _vectors.Count; index++)
46+
{
47+
ReadOnlyMemory<float> dbVector = _vectors[index];
48+
float distance = 1.0f - CosineSimilarity(dbVector.Span, textVector);
49+
distances.Add((distance, index));
50+
}
51+
distances.Sort(((float D1, int I1) v1, (float D2, int I2) v2) => v1.D1.CompareTo(v2.D2));
52+
53+
top = Math.Min(top, distances.Count);
54+
for (int i = 0; i < top; i++)
55+
{
56+
var distance = distances[i].Distance;
57+
if (distance > threshold)
58+
break;
59+
var index = distances[i].Index;
60+
results.Add(new Fact(_facts[index], index));
61+
}
62+
return results;
63+
}
64+
65+
private static float CosineSimilarity(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
66+
{
67+
float dot = 0, xSumSquared = 0, ySumSquared = 0;
68+
69+
for (int i = 0; i < x.Length; i++)
70+
{
71+
dot += x[i] * y[i];
72+
xSumSquared += x[i] * x[i];
73+
ySumSquared += y[i] * y[i];
74+
}
75+
return dot / (MathF.Sqrt(xSumSquared) * MathF.Sqrt(ySumSquared));
76+
}
77+
78+
private void ProcessUnprocessedFacts()
79+
{
80+
if (_factsToProcess.Count == 0)
81+
{
82+
return;
83+
}
84+
var embeddings = _client.GenerateEmbeddings(_factsToProcess);
85+
86+
foreach (var embedding in embeddings.Value)
87+
{
88+
_vectors.Add(embedding.ToFloats());
89+
_facts.Add(_factsToProcess[embedding.Index]);
90+
}
91+
92+
_factsToProcess.Clear();
93+
}
94+
95+
private ReadOnlyMemory<float> ProcessFact(string fact)
96+
{
97+
var embedding = _client.GenerateEmbedding(fact);
98+
99+
return embedding.Value.ToFloats();
100+
}
101+
102+
internal void ChunkAndAddToFactsToProcess(string text, int chunkSize)
103+
{
104+
if (chunkSize <= 0)
105+
{
106+
throw new ArgumentException("Chunk size must be greater than zero.", nameof(chunkSize));
107+
}
108+
109+
int overlapSize = (int)(chunkSize * 0.15);
110+
int stepSize = chunkSize - overlapSize;
111+
ReadOnlySpan<char> textSpan = text.AsSpan();
112+
113+
for (int i = 0; i < text.Length; i += stepSize)
114+
{
115+
while (i > 0 && !char.IsWhiteSpace(textSpan[i]))
116+
{
117+
i--;
118+
}
119+
if (i + chunkSize > text.Length)
120+
{
121+
_factsToProcess.Add(textSpan.Slice(i).ToString());
122+
}
123+
else
124+
{
125+
int end = i + chunkSize;
126+
if (end > text.Length)
127+
{
128+
_factsToProcess.Add(textSpan.Slice(i).ToString());
129+
}
130+
else
131+
{
132+
while (end < text.Length && !char.IsWhiteSpace(textSpan[end]))
133+
{
134+
end++;
135+
}
136+
_factsToProcess.Add(textSpan.Slice(i, end - i).ToString());
137+
}
138+
}
139+
}
140+
}
141+
internal struct Fact
142+
{
143+
public Fact(string text, int id)
144+
{
145+
Text = text;
146+
Id = id;
147+
}
148+
149+
public string Text { get; set; }
150+
public int Id { get; set; }
151+
}
152+
}
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Text;
8+
using OpenAI.Chat;
9+
using static Azure.Provisioning.CloudMachine.OpenAI.EmbeddingKnowledgebase;
10+
11+
namespace Azure.Provisioning.CloudMachine.OpenAI;
12+
13+
/// <summary>
14+
/// Represents a conversation with the OpenAI chat model, incorporating a knowledgebase of embeddings data.
15+
/// </summary>
16+
public class OpenAIConversation
17+
{
18+
private readonly ChatClient _client;
19+
private readonly Prompt _prompt;
20+
private readonly Dictionary<string, ChatTool> _tools = new();
21+
private readonly EmbeddingKnowledgebase _knowledgebase;
22+
private readonly ChatCompletionOptions _options = new ChatCompletionOptions();
23+
24+
/// <summary>
25+
/// Initializes a new instance of the <see cref="OpenAIConversation"/> class.
26+
/// </summary>
27+
/// <param name="client">The ChatClient.</param>
28+
/// <param name="tools">Any ChatTools to be used by the conversation.</param>
29+
/// <param name="knowledgebase">The knowledgebase.</param>
30+
internal OpenAIConversation(ChatClient client, IEnumerable<ChatTool> tools, EmbeddingKnowledgebase knowledgebase)
31+
{
32+
foreach (var tool in tools)
33+
{
34+
_options.Tools.Add(tool);
35+
_tools.Add(tool.FunctionName, tool);
36+
}
37+
_client = client;
38+
_knowledgebase = knowledgebase;
39+
_prompt = new Prompt();
40+
_prompt.AddTools(tools);
41+
}
42+
43+
/// <summary>
44+
/// Sends a message to the OpenAI chat model and returns the response, incorporating any relevant knowledge from the <see cref="EmbeddingKnowledgebase"/>.
45+
/// </summary>
46+
/// <param name="message"></param>
47+
/// <returns></returns>
48+
public string Say(string message)
49+
{
50+
List<Fact> facts = _knowledgebase.FindRelevantFacts(message);
51+
_prompt.AddFacts(facts);
52+
_prompt.AddUserMessage(message);
53+
var response = CallOpenAI();
54+
return response;
55+
}
56+
57+
private string CallOpenAI()
58+
{
59+
bool requiresAction;
60+
do
61+
{
62+
requiresAction = false;
63+
var completion = _client.CompleteChat(_prompt.Messages).Value;
64+
switch (completion.FinishReason)
65+
{
66+
case ChatFinishReason.ToolCalls:
67+
// TODO: Implement tool calls
68+
requiresAction = true;
69+
break;
70+
case ChatFinishReason.Length:
71+
return "Incomplete model output due to MaxTokens parameter or token limit exceeded.";
72+
case ChatFinishReason.ContentFilter:
73+
return "Omitted content due to a content filter flag.";
74+
case ChatFinishReason.Stop:
75+
_prompt.AddAssistantMessage(new AssistantChatMessage(completion));
76+
break;
77+
default:
78+
throw new NotImplementedException("Unknown finish reason.");
79+
}
80+
return _prompt.GetSayResult();
81+
} while (requiresAction);
82+
}
83+
84+
internal class Prompt
85+
{
86+
internal readonly List<UserChatMessage> userChatMessages = new();
87+
internal readonly List<SystemChatMessage> systemChatMessages = new();
88+
internal readonly List<AssistantChatMessage> assistantChatMessages = new();
89+
internal readonly List<ToolChatMessage> toolChatMessages = new();
90+
internal readonly List<ChatCompletion> chatCompletions = new();
91+
internal readonly List<ChatTool> _tools = new();
92+
internal readonly List<int> _factsAlreadyInPrompt = new List<int>();
93+
94+
public Prompt()
95+
{ }
96+
97+
public IEnumerable<ChatMessage> Messages
98+
{
99+
get
100+
{
101+
foreach (var message in systemChatMessages)
102+
{
103+
yield return message;
104+
}
105+
foreach (var message in userChatMessages)
106+
{
107+
yield return message;
108+
}
109+
foreach (var message in assistantChatMessages)
110+
{
111+
yield return message;
112+
}
113+
foreach (var message in toolChatMessages)
114+
{
115+
yield return message;
116+
}
117+
}
118+
}
119+
120+
//public ChatCompletionOptions Current => _prompt;
121+
public void AddTools(IEnumerable<ChatTool> tools)
122+
{
123+
foreach (var tool in tools)
124+
{
125+
_tools.Add(tool);
126+
}
127+
}
128+
public void AddFacts(IEnumerable<Fact> facts)
129+
{
130+
var sb = new StringBuilder();
131+
foreach (var fact in facts)
132+
{
133+
if (_factsAlreadyInPrompt.Contains(fact.Id))
134+
continue;
135+
sb.AppendLine(fact.Text);
136+
_factsAlreadyInPrompt.Add(fact.Id);
137+
}
138+
if (sb.Length > 0)
139+
{
140+
systemChatMessages.Add(ChatMessage.CreateSystemMessage(sb.ToString()));
141+
}
142+
}
143+
public void AddUserMessage(string message)
144+
{
145+
userChatMessages.Add(ChatMessage.CreateUserMessage(message));
146+
}
147+
public void AddAssistantMessage(string message)
148+
{
149+
assistantChatMessages.Add(ChatMessage.CreateAssistantMessage(message));
150+
}
151+
public void AddAssistantMessage(AssistantChatMessage message)
152+
{
153+
assistantChatMessages.Add(message);
154+
}
155+
public void AddToolMessage(ToolChatMessage message)
156+
{
157+
toolChatMessages.Add(message);
158+
}
159+
160+
internal string GetSayResult()
161+
{
162+
var result = string.Join("\n", assistantChatMessages.Select(m => m.Content[0].Text));
163+
assistantChatMessages.Clear();
164+
userChatMessages.Clear();
165+
systemChatMessages.Clear();
166+
return result;
167+
}
168+
}
169+
}

sdk/provisioning/Azure.Provisioning.CloudMachine/src/AzureSdkExtensions/OpenAIFeature.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.ClientModel;
6+
using System.Linq;
67
using Azure.AI.OpenAI;
78
using Azure.Core;
89
using Azure.Provisioning.Authorization;
@@ -113,6 +114,19 @@ public static EmbeddingClient GetOpenAIEmbeddingsClient(this ClientWorkspace wor
113114
return embeddingsClient;
114115
}
115116

117+
public static EmbeddingKnowledgebase CreateEmbeddingKnowledgebase(this ClientWorkspace workspace)
118+
{
119+
EmbeddingClient embeddingsClient = workspace.GetOpenAIEmbeddingsClient();
120+
return new EmbeddingKnowledgebase(embeddingsClient);
121+
}
122+
123+
public static OpenAIConversation CreateOpenAIConversation(this ClientWorkspace workspace)
124+
{
125+
ChatClient chatClient = workspace.GetOpenAIChatClient();
126+
EmbeddingKnowledgebase knowledgebase = workspace.CreateEmbeddingKnowledgebase();
127+
return new OpenAIConversation(chatClient, [], knowledgebase);
128+
}
129+
116130
private static AzureOpenAIClient CreateAzureOpenAIClient(this ClientWorkspace workspace)
117131
{
118132
ClientConnectionOptions connection = workspace.GetConnectionOptions(typeof(AzureOpenAIClient));

sdk/provisioning/Azure.Provisioning.CloudMachine/tests/CloudMachineTests.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ public void Provisioning(string[] args)
3131
CloudMachineWorkspace cm = new();
3232
Console.WriteLine(cm.Id);
3333
var embeddings = cm.GetOpenAIEmbeddingsClient();
34+
var kb = cm.CreateEmbeddingKnowledgebase();
35+
var conversation = cm.CreateOpenAIConversation();
3436
}
3537

3638
[Ignore("no recordings yet")]

0 commit comments

Comments
 (0)