Skip to content

Commit eb65d06

Browse files
Use AzureEmbedFunctionService from EmbedFunction in PrepDoc (#211)
## Purpose
<!-- Describe the intention of the changes being proposed. What problem does it solve or functionality does it add? -->
* ...

## Does this introduce a breaking change?
<!-- Mark one with an "x". -->
```
[ ] Yes
[ ] No
```

## Pull Request Type
What kind of change does this Pull Request introduce?

<!-- Please check the one that applies to this PR using "x". -->
```
[ ] Bugfix
[ ] Feature
[ ] Code style update (formatting, local variables)
[ ] Refactoring (no functional changes, no api changes)
[ ] Documentation content changes
[ ] Other... Please describe:
```

## How to Test
* Get the code

```
git clone [repo-address]
cd [repo-name]
git checkout [branch-name]
npm install
```

* Test the code
<!-- Add steps to run the tests suite and/or manually test -->
```
```

## What to Check
Verify that the following are valid
* ...

## Other Information
<!-- Add any other helpful information that may be needed here. -->
1 parent a47ced9 commit eb65d06

File tree

10 files changed

+138
-480
lines changed

10 files changed

+138
-480
lines changed

app/functions/EmbedFunctions/EmbedFunctions.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
<PackageReference Include="Azure.AI.FormRecognizer" />
1313
<PackageReference Include="Azure.Search.Documents" />
1414
<PackageReference Include="Azure.Storage.Blobs" />
15+
<PackageReference Include="Azure.AI.OpenAI" />
1516
<PackageReference Include="Azure.Storage.Files.Shares" />
1617
<PackageReference Include="Azure.Storage.Queues" />
1718
<PackageReference Include="Microsoft.Azure.Functions.Worker.Extensions.Storage" />

app/functions/EmbedFunctions/EmbeddingFunction.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ public sealed class EmbeddingFunction(
1212
public Task EmbedAsync(
1313
[BlobTrigger(
1414
blobPath: "content/{name}",
15-
Connection = "AzureStorageAccountEndpoint")] Stream blobStream,
16-
string name,
17-
BlobClient client) => embeddingAggregateService.EmbedBlobAsync(client, blobStream, blobName: name);
15+
Connection = "AzureWebJobsStorage")] Stream blobStream,
16+
string name) => embeddingAggregateService.EmbedBlobAsync(blobStream, blobName: name);
1817
}

app/functions/EmbedFunctions/Program.cs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using Azure.AI.OpenAI;
4+
using Microsoft.Extensions.DependencyInjection;
5+
36
var host = new HostBuilder()
47
.ConfigureServices(services =>
58
{
@@ -36,7 +39,7 @@ uri is not null
3639
services.AddSingleton<BlobContainerClient>(_ =>
3740
{
3841
var blobServiceClient = new BlobServiceClient(
39-
GetUriFromEnvironment("AZURE_STORAGE_ACCOUNT_ENDPOINT"),
42+
GetUriFromEnvironment("AZURE_STORAGE_BLOB_ENDPOINT"),
4043
credential);
4144

4245
return blobServiceClient.GetBlobContainerClient("corpus");
@@ -45,10 +48,22 @@ uri is not null
4548
services.AddSingleton<EmbedServiceFactory>();
4649
services.AddSingleton<EmbeddingAggregateService>();
4750

48-
services.AddSingleton<IEmbedService, AzureSearchEmbedService>();
49-
services.AddSingleton<IEmbedService, PineconeEmbedService>();
50-
services.AddSingleton<IEmbedService, QdrantEmbedService>();
51-
services.AddSingleton<IEmbedService, MilvusEmbedService>();
51+
services.AddSingleton<IEmbedService, AzureSearchEmbedService>(provider =>
52+
{
53+
var searchIndexName = Environment.GetEnvironmentVariable("AZURE_SEARCH_INDEX") ?? throw new ArgumentNullException("AZURE_SEARCH_INDEX is null");
54+
var embeddingModelName = Environment.GetEnvironmentVariable("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") ?? throw new ArgumentNullException("AZURE_OPENAI_EMBEDDING_DEPLOYMENT is null");
55+
var openaiEndPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentNullException("AZURE_OPENAI_ENDPOINT is null");
56+
57+
var openAIClient = new OpenAIClient(new Uri(openaiEndPoint), new DefaultAzureCredential());
58+
59+
var searchClient = provider.GetRequiredService<SearchClient>();
60+
var searchIndexClient = provider.GetRequiredService<SearchIndexClient>();
61+
var blobContainerClient = provider.GetRequiredService<BlobContainerClient>();
62+
var documentClient = provider.GetRequiredService<DocumentAnalysisClient>();
63+
var logger = provider.GetRequiredService<ILogger<AzureSearchEmbedService>>();
64+
65+
return new AzureSearchEmbedService(openAIClient, embeddingModelName, searchClient, searchIndexName, searchIndexClient, documentClient, blobContainerClient, logger);
66+
});
5267
})
5368
.ConfigureFunctionsWorkerDefaults()
5469
.Build();

app/functions/EmbedFunctions/Services/AzureSearchEmbedService.cs

Lines changed: 80 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using Azure.AI.OpenAI;
4+
using Google.Protobuf.WellKnownTypes;
5+
using Microsoft.Extensions.Options;
6+
37
namespace EmbedFunctions.Services;
48

5-
internal sealed partial class AzureSearchEmbedService(
9+
public sealed partial class AzureSearchEmbedService(
10+
OpenAIClient openAIClient,
11+
string embeddingModelName,
612
SearchClient indexSectionClient,
13+
string searchIndexName,
714
SearchIndexClient searchIndexClient,
815
DocumentAnalysisClient documentAnalysisClient,
916
BlobContainerClient corpusContainerClient,
10-
ILogger<AzureSearchEmbedService> logger) : IEmbedService
17+
ILogger<AzureSearchEmbedService>? logger) : IEmbedService
1118
{
1219
[GeneratedRegex("[^0-9a-zA-Z_-]")]
1320
private static partial Regex MatchInSetRegex();
@@ -16,9 +23,6 @@ public async Task<bool> EmbedBlobAsync(Stream blobStream, string blobName)
1623
{
1724
try
1825
{
19-
var searchIndexName = Environment.GetEnvironmentVariable(
20-
"AZURE_SEARCH_INDEX") ?? "gptkbindex";
21-
2226
await EnsureSearchIndexAsync(searchIndexName);
2327

2428
var pageMap = await GetDocumentTextAsync(blobStream, blobName);
@@ -41,67 +45,94 @@ public async Task<bool> EmbedBlobAsync(Stream blobStream, string blobName)
4145
}
4246
catch (Exception exception)
4347
{
44-
logger.LogError(
48+
logger?.LogError(
4549
exception, "Failed to embed blob '{BlobName}'", blobName);
4650

4751
return false;
4852
}
4953
}
5054

51-
private async Task EnsureSearchIndexAsync(string searchIndexName)
55+
public async Task CreateSearchIndexAsync(string searchIndexName)
5256
{
53-
var indexNames = searchIndexClient.GetIndexNamesAsync();
54-
await foreach (var page in indexNames.AsPages())
57+
string vectorSearchConfigName = "my-vector-config";
58+
string vectorSearchProfile = "my-vector-profile";
59+
var index = new SearchIndex(searchIndexName)
5560
{
56-
if (page.Values.Any(indexName => indexName == searchIndexName))
61+
VectorSearch = new()
5762
{
58-
logger.LogWarning(
59-
"Search index '{SearchIndexName}' already exists", searchIndexName);
60-
return;
63+
Algorithms =
64+
{
65+
new HnswVectorSearchAlgorithmConfiguration(vectorSearchConfigName)
66+
},
67+
Profiles =
68+
{
69+
new VectorSearchProfile(vectorSearchProfile, vectorSearchConfigName)
6170
}
62-
}
63-
64-
var index = new SearchIndex(searchIndexName)
65-
{
71+
},
6672
Fields =
73+
{
74+
new SimpleField("id", SearchFieldDataType.String) { IsKey = true },
75+
new SearchableField("content") { AnalyzerName = LexicalAnalyzerName.EnMicrosoft },
76+
new SimpleField("category", SearchFieldDataType.String) { IsFacetable = true },
77+
new SimpleField("sourcepage", SearchFieldDataType.String) { IsFacetable = true },
78+
new SimpleField("sourcefile", SearchFieldDataType.String) { IsFacetable = true },
79+
new SearchField("embedding", SearchFieldDataType.Collection(SearchFieldDataType.Single))
6780
{
68-
new SimpleField("id", SearchFieldDataType.String) { IsKey = true },
69-
new SearchableField("content") { AnalyzerName = "en.microsoft" },
70-
new SimpleField("category", SearchFieldDataType.String) { IsFacetable = true },
71-
new SimpleField("sourcepage", SearchFieldDataType.String) { IsFacetable = true },
72-
new SimpleField("sourcefile", SearchFieldDataType.String) { IsFacetable = true }
73-
},
81+
VectorSearchDimensions = 1536,
82+
IsSearchable = true,
83+
VectorSearchProfile = vectorSearchProfile,
84+
}
85+
},
7486
SemanticSettings = new SemanticSettings
7587
{
7688
Configurations =
89+
{
90+
new SemanticConfiguration("default", new PrioritizedFields
7791
{
78-
new SemanticConfiguration("default", new PrioritizedFields
92+
ContentFields =
7993
{
80-
ContentFields =
94+
new SemanticField
8195
{
82-
new SemanticField
83-
{
84-
FieldName = "content"
85-
}
96+
FieldName = "content"
8697
}
87-
})
8898
}
99+
})
100+
}
89101
}
90102
};
91103

92-
logger.LogInformation(
93-
"Creating '{searchIndexName}' search index", searchIndexName);
104+
logger?.LogInformation(
105+
"Creating '{searchIndexName}' search index", searchIndexName);
94106

95107
await searchIndexClient.CreateIndexAsync(index);
96108
}
97109

110+
public async Task EnsureSearchIndexAsync(string searchIndexName)
111+
{
112+
var indexNames = searchIndexClient.GetIndexNamesAsync();
113+
await foreach (var page in indexNames.AsPages())
114+
{
115+
if (page.Values.Any(indexName => indexName == searchIndexName))
116+
{
117+
logger?.LogWarning(
118+
"Search index '{SearchIndexName}' already exists", searchIndexName);
119+
return;
120+
}
121+
}
122+
123+
await CreateSearchIndexAsync(searchIndexName);
124+
}
125+
98126
private async Task<IReadOnlyList<PageDetail>> GetDocumentTextAsync(Stream blobStream, string blobName)
99127
{
100-
logger.LogInformation(
128+
logger?.LogInformation(
101129
"Extracting text from '{Blob}' using Azure Form Recognizer", blobName);
102130

131+
using var ms = new MemoryStream();
132+
blobStream.CopyTo(ms);
133+
ms.Position = 0;
103134
AnalyzeDocumentOperation operation = documentAnalysisClient.AnalyzeDocument(
104-
WaitUntil.Started, "prebuilt-layout", blobStream);
135+
WaitUntil.Started, "prebuilt-layout", ms);
105136

106137
var offset = 0;
107138
List<PageDetail> pageMap = [];
@@ -208,7 +239,7 @@ private async Task UploadCorpusAsync(string corpusBlobName, string text)
208239
return;
209240
}
210241

211-
logger.LogInformation("Uploading corpus '{CorpusBlobName}'", corpusBlobName);
242+
logger?.LogInformation("Uploading corpus '{CorpusBlobName}'", corpusBlobName);
212243

213244
await using var stream = new MemoryStream(Encoding.UTF8.GetBytes(text));
214245
await blobClient.UploadAsync(stream, new BlobHttpHeaders
@@ -231,7 +262,7 @@ private IEnumerable<Section> CreateSections(
231262
var start = 0;
232263
var end = length;
233264

234-
logger.LogInformation("Splitting '{BlobName}' into sections", blobName);
265+
logger?.LogInformation("Splitting '{BlobName}' into sections", blobName);
235266

236267
while (start + SectionOverlap < length)
237268
{
@@ -300,9 +331,9 @@ private IEnumerable<Section> CreateSections(
300331
// If the section ends with an unclosed table, we need to start the next section with the table.
301332
// If table starts inside SentenceSearchLimit, we ignore it, as that will cause an infinite loop for tables longer than MaxSectionLength
302333
// If last table starts inside SectionOverlap, keep overlapping
303-
if (logger.IsEnabled(LogLevel.Warning))
334+
if (logger?.IsEnabled(LogLevel.Warning) is true)
304335
{
305-
logger.LogWarning("""
336+
logger?.LogWarning("""
306337
Section ends with unclosed table, starting next section with the
307338
table at page {Offset} offset {Start} table start {LastTableStart}
308339
""",
@@ -349,10 +380,10 @@ private static string BlobNameFromFilePage(string blobName, int page = 0) => Pat
349380

350381
private async Task IndexSectionsAsync(string searchIndexName, IEnumerable<Section> sections, string blobName)
351382
{
352-
var infoLoggingEnabled = logger.IsEnabled(LogLevel.Information);
353-
if (infoLoggingEnabled)
383+
var infoLoggingEnabled = logger?.IsEnabled(LogLevel.Information);
384+
if (infoLoggingEnabled is true)
354385
{
355-
logger.LogInformation("""
386+
logger?.LogInformation("""
356387
Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
357388
""",
358389
blobName,
@@ -363,6 +394,8 @@ Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
363394
var batch = new IndexDocumentsBatch<SearchDocument>();
364395
foreach (var section in sections)
365396
{
397+
var embeddings = await openAIClient.GetEmbeddingsAsync(embeddingModelName, new Azure.AI.OpenAI.EmbeddingsOptions(section.Content.Replace('\r', ' ')));
398+
var embedding = embeddings.Value.Data.FirstOrDefault()?.Embedding.ToArray() ?? [];
366399
batch.Actions.Add(new IndexDocumentsAction<SearchDocument>(
367400
IndexActionType.MergeOrUpload,
368401
new SearchDocument
@@ -371,7 +404,8 @@ Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
371404
["content"] = section.Content,
372405
["category"] = section.Category,
373406
["sourcepage"] = section.SourcePage,
374-
["sourcefile"] = section.SourceFile
407+
["sourcefile"] = section.SourceFile,
408+
["embedding"] = embedding,
375409
}));
376410

377411
iteration++;
@@ -380,9 +414,9 @@ Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
380414
// Every one thousand documents, batch create.
381415
IndexDocumentsResult result = await indexSectionClient.IndexDocumentsAsync(batch);
382416
int succeeded = result.Results.Count(r => r.Succeeded);
383-
if (infoLoggingEnabled)
417+
if (infoLoggingEnabled is true)
384418
{
385-
logger.LogInformation("""
419+
logger?.LogInformation("""
386420
Indexed {Count} sections, {Succeeded} succeeded
387421
""",
388422
batch.Actions.Count,
@@ -399,9 +433,9 @@ Indexing sections from '{BlobName}' into search index '{SearchIndexName}'
399433
var index = new SearchIndex($"index-{batch.Actions.Count}");
400434
IndexDocumentsResult result = await indexSectionClient.IndexDocumentsAsync(batch);
401435
int succeeded = result.Results.Count(r => r.Succeeded);
402-
if (logger.IsEnabled(LogLevel.Information))
436+
if (logger?.IsEnabled(LogLevel.Information) is true)
403437
{
404-
logger.LogInformation("""
438+
logger?.LogInformation("""
405439
Indexed {Count} sections, {Succeeded} succeeded
406440
""",
407441
batch.Actions.Count,

app/functions/EmbedFunctions/Services/EmbeddingAggregateService.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using System.IO;
4+
35
namespace EmbedFunctions.Services;
46

57
public sealed class EmbeddingAggregateService(
68
EmbedServiceFactory embedServiceFactory,
9+
BlobContainerClient client,
710
ILogger<EmbeddingAggregateService> logger)
811
{
9-
internal async Task EmbedBlobAsync(BlobClient client, Stream blobStream, string blobName)
12+
internal async Task EmbedBlobAsync(Stream blobStream, string blobName)
1013
{
1114
try
1215
{

app/prepdocs/PrepareDocs/PageDetail.cs

Lines changed: 0 additions & 6 deletions
This file was deleted.

app/prepdocs/PrepareDocs/PrepareDocs.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,9 @@
1818
<PackageReference Include="PdfSharpCore" />
1919
<PackageReference Include="System.CommandLine" />
2020
</ItemGroup>
21+
22+
<ItemGroup>
23+
<ProjectReference Include="..\..\functions\EmbedFunctions\EmbedFunctions.csproj" />
24+
</ItemGroup>
2125

2226
</Project>

app/prepdocs/PrepareDocs/Program.Clients.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33

4+
using EmbedFunctions.Services;
5+
using Microsoft.Extensions.Logging;
6+
47
internal static partial class Program
58
{
69
private static BlobContainerClient? s_corpusContainerClient;
@@ -16,6 +19,21 @@ internal static partial class Program
1619
private static readonly SemaphoreSlim s_searchIndexLock = new(1);
1720
private static readonly SemaphoreSlim s_searchLock = new(1);
1821
private static readonly SemaphoreSlim s_openAILock = new(1);
22+
private static readonly SemaphoreSlim s_embeddingLock = new(1);
23+
24+
private static Task<AzureSearchEmbedService> GetAzureSearchEmbedService(AppOptions options) =>
25+
GetLazyClientAsync<AzureSearchEmbedService>(options, s_embeddingLock, async o =>
26+
{
27+
var searchIndexClient = await GetSearchIndexClientAsync(o);
28+
var searchClient = await GetSearchClientAsync(o);
29+
var documentClient = await GetFormRecognizerClientAsync(o);
30+
var blobContainerClient = await GetBlobContainerClientAsync(o);
31+
var openAIClient = await GetAzureOpenAIClientAsync(o);
32+
var embeddingModelName = o.EmbeddingModelName ?? throw new ArgumentNullException(nameof(o.EmbeddingModelName));
33+
var searchIndexName = o.SearchIndexName ?? throw new ArgumentNullException(nameof(o.SearchIndexName));
34+
35+
return new AzureSearchEmbedService(openAIClient, embeddingModelName, searchClient, searchIndexName, searchIndexClient, documentClient, blobContainerClient, null);
36+
});
1937

2038
private static Task<BlobContainerClient> GetCorpusBlobContainerClientAsync(AppOptions options) =>
2139
GetLazyClientAsync<BlobContainerClient>(options, s_corpusContainerLock, static async o =>

0 commit comments

Comments
 (0)