8
8
import time
9
9
10
10
import openai
11
+ import tiktoken
11
12
from azure .ai .formrecognizer import DocumentAnalysisClient
12
13
from azure .core .credentials import AzureKeyCredential
13
14
from azure .identity import AzureDeveloperCliCredential
40
41
CACHE_KEY_CREATED_TIME = 'created_time'
41
42
CACHE_KEY_TOKEN_TYPE = 'token_type'
42
43
44
# Embedding batch support section
# Azure OpenAI embedding models that support batched embedding requests,
# mapped to the limits enforced by update_embeddings_in_batch:
#   token_limit    - max total tokens allowed across one batched request
#                    (8100 appears to leave headroom under the model's
#                    context limit -- TODO confirm against service docs)
#   max_batch_size - max number of input strings per request
SUPPORTED_BATCH_AOAI_MODEL = {
    'text-embedding-ada-002': {
        'token_limit': 8100,
        'max_batch_size': 16
    }
}
51
+
52
def calculate_tokens_emb_aoai(input: str) -> int:
    """Count how many tokens *input* encodes to for the embedding model
    selected by --openaimodelname (used to respect per-batch token limits)."""
    tokenizer = tiktoken.encoding_for_model(args.openaimodelname)
    return len(tokenizer.encode(input))
55
+
43
56
def blob_name_from_file_page (filename , page = 0 ):
44
57
if os .path .splitext (filename )[1 ].lower () == ".pdf" :
45
58
return os .path .splitext (os .path .basename (filename ))[0 ] + f"-{ page } " + ".pdf"
@@ -229,11 +242,17 @@ def create_sections(filename, page_map, use_vectors):
229
242
def before_retry_sleep(retry_state):
    """Hook passed as before_sleep to @retry: note the back-off when verbose."""
    if args.verbose:
        print("Rate limited on the OpenAI embeddings API, sleeping before retrying...")
231
244
232
@retry(wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(15), before_sleep=before_retry_sleep)
def compute_embedding(text):
    """Embed a single piece of text with the configured Azure OpenAI deployment.

    Retries with exponential back-off (up to 15 attempts) on failure, e.g.
    when rate limited. Returns the embedding vector for *text*.
    """
    refresh_openai_token()
    response = openai.Embedding.create(engine=args.openaideployment, input=text)
    return response["data"][0]["embedding"]
236
249
250
@retry(wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(15), before_sleep=before_retry_sleep)
def compute_embedding_in_batch(texts):
    """Embed a list of texts in one Azure OpenAI request.

    Retries with exponential back-off (up to 15 attempts) on failure.
    Returns the embedding vectors in the same order as *texts*.
    """
    refresh_openai_token()
    response = openai.Embedding.create(engine=args.openaideployment, input=texts)
    return [item.embedding for item in response.data]
255
+
237
256
def create_search_index ():
238
257
if args .verbose : print (f"Ensuring search index { args .index } exists" )
239
258
index_client = SearchIndexClient (endpoint = f"https://{ args .searchservice } .search.windows.net/" ,
@@ -271,6 +290,35 @@ def create_search_index():
271
290
else :
272
291
if args .verbose : print (f"Search index { args .index } already exists" )
273
292
293
def update_embeddings_in_batch(sections):
    """Compute embeddings for *sections* in batches and yield each section
    with an added "embedding" key, preserving the original order.

    Sections accumulate into a batch until adding another would exceed the
    model's token limit or max batch size (see SUPPORTED_BATCH_AOAI_MODEL);
    the batch is then sent via compute_embedding_in_batch and a new batch is
    started with the current section.
    """
    model_limits = SUPPORTED_BATCH_AOAI_MODEL[args.openaimodelname]
    batch_queue = []
    copy_s = []
    batch_response = {}
    token_count = 0
    for s in sections:
        token_count += calculate_tokens_emb_aoai(s["content"])
        if token_count <= model_limits['token_limit'] and len(batch_queue) < model_limits['max_batch_size']:
            batch_queue.append(s)
            copy_s.append(s)
        else:
            # Current batch is full: embed it, then start a new batch with s.
            # Guard against an empty queue (first section alone over the token
            # limit) so we never call the embeddings API with an empty list.
            if batch_queue:
                emb_responses = compute_embedding_in_batch([item["content"] for item in batch_queue])
                if args.verbose: print(f"Batch Completed. Batch size {len(batch_queue)} Token count {token_count}")
                for emb, item in zip(emb_responses, batch_queue):
                    batch_response[item["id"]] = emb
            batch_queue = [s]
            # Bug fix: the section that triggered the flush must also be
            # recorded for yielding -- previously it was embedded but never
            # emitted, silently dropping it from the index.
            copy_s.append(s)
            token_count = calculate_tokens_emb_aoai(s["content"])

    # Flush the final partial batch.
    if batch_queue:
        emb_responses = compute_embedding_in_batch([item["content"] for item in batch_queue])
        if args.verbose: print(f"Batch Completed. Batch size {len(batch_queue)} Token count {token_count}")
        for emb, item in zip(emb_responses, batch_queue):
            batch_response[item["id"]] = emb

    for s in copy_s:
        s["embedding"] = batch_response[s["id"]]
        yield s
321
+
274
322
def index_sections (filename , sections ):
275
323
if args .verbose : print (f"Indexing sections from '{ filename } ' into search index '{ args .index } '" )
276
324
search_client = SearchClient (endpoint = f"https://{ args .searchservice } .search.windows.net/" ,
@@ -314,7 +362,7 @@ def refresh_openai_token():
314
362
openai .api_key = token_cred .get_token ("https://cognitiveservices.azure.com/.default" ).token
315
363
open_ai_token_cache [CACHE_KEY_CREATED_TIME ] = time .time ()
316
364
317
- def read_files (path_pattern : str , use_vectors : bool ):
365
+ def read_files (path_pattern : str , use_vectors : bool , vectors_batch_support : bool ):
318
366
"""
319
367
Recursively read directory structure under `path_pattern`
320
368
and execute indexing for the individual files
@@ -326,13 +374,16 @@ def read_files(path_pattern: str, use_vectors: bool):
326
374
remove_from_index (filename )
327
375
else :
328
376
if os .path .isdir (filename ):
329
- read_files (filename + "/*" , use_vectors )
377
+ read_files (filename + "/*" , use_vectors , vectors_batch_support )
330
378
continue
331
379
try :
332
380
if not args .skipblobs :
333
381
upload_blobs (filename )
334
382
page_map = get_document_text (filename )
335
- sections = create_sections (os .path .basename (filename ), page_map , use_vectors )
383
+ sections = create_sections (os .path .basename (filename ), page_map , use_vectors and not vectors_batch_support )
384
+ print (use_vectors and vectors_batch_support )
385
+ if use_vectors and vectors_batch_support :
386
+ sections = update_embeddings_in_batch (sections )
336
387
index_sections (os .path .basename (filename ), sections )
337
388
except Exception as e :
338
389
print (f"\t Got an error while reading { filename } -> { e } --> skipping file" )
@@ -355,7 +406,9 @@ def read_files(path_pattern: str, use_vectors: bool):
355
406
parser .add_argument ("--searchkey" , required = False , help = "Optional. Use this Azure Cognitive Search account key instead of the current user identity to login (use az login to set current user for Azure)" )
356
407
parser .add_argument ("--openaiservice" , help = "Name of the Azure OpenAI service used to compute embeddings" )
357
408
parser .add_argument ("--openaideployment" , help = "Name of the Azure OpenAI model deployment for an embedding model ('text-embedding-ada-002' recommended)" )
409
+ parser .add_argument ("--openaimodelname" , help = "Name of the Azure OpenAI embedding model ('text-embedding-ada-002' recommended)" )
358
410
parser .add_argument ("--novectors" , action = "store_true" , help = "Don't compute embeddings for the sections (e.g. don't call the OpenAI embeddings API during indexing)" )
411
+ parser .add_argument ("--disablebatchvectors" , action = "store_true" , help = "Don't compute embeddings in batch for the sections" )
359
412
parser .add_argument ("--openaikey" , required = False , help = "Optional. Use this Azure OpenAI account key instead of the current user identity to login (use az login to set current user for Azure)" )
360
413
parser .add_argument ("--remove" , action = "store_true" , help = "Remove references to this document from blob storage and the search index" )
361
414
parser .add_argument ("--removeall" , action = "store_true" , help = "Remove all blobs from blob storage and documents from the search index" )
@@ -370,6 +423,7 @@ def read_files(path_pattern: str, use_vectors: bool):
370
423
default_creds = azd_credential if args .searchkey is None or args .storagekey is None else None
371
424
search_creds = default_creds if args .searchkey is None else AzureKeyCredential (args .searchkey )
372
425
use_vectors = not args .novectors
426
+ compute_vectors_in_batch = not args .disablebatchvectors and args .openaimodelname in SUPPORTED_BATCH_AOAI_MODEL
373
427
374
428
if not args .skipblobs :
375
429
storage_creds = default_creds if args .storagekey is None else args .storagekey
@@ -402,4 +456,4 @@ def read_files(path_pattern: str, use_vectors: bool):
402
456
create_search_index ()
403
457
404
458
print ("Processing files..." )
405
- read_files (args .files , use_vectors )
459
+ read_files (args .files , use_vectors , compute_vectors_in_batch )
0 commit comments