35
35
from prepdocslib .strategy import DocumentAction , SearchInfo , Strategy
36
36
from prepdocslib .textparser import TextParser
37
37
from prepdocslib .textsplitter import SentenceTextSplitter , SimpleTextSplitter
38
+ from enum import Enum
38
39
39
40
logger = logging .getLogger ("scripts" )
40
41
@@ -126,15 +127,23 @@ def setup_list_file_strategy(
126
127
return list_file_strategy
127
128
128
129
130
+ class OpenAIHost (str , Enum ):
131
+ OPENAI = "openai"
132
+ AZURE = "azure"
133
+ AZURE_CUSTOM = "azure_custom"
134
+ LOCAL = "local"
135
+
136
+
129
137
def setup_embeddings_service (
130
138
azure_credential : AsyncTokenCredential ,
131
- openai_host : str ,
132
- openai_model_name : str ,
133
- openai_service : Union [str , None ],
134
- openai_custom_url : Union [str , None ],
135
- openai_deployment : Union [str , None ],
136
- openai_dimensions : int ,
137
- openai_api_version : str ,
139
+ openai_host : OpenAIHost ,
140
+ emb_model_name : str ,
141
+ emb_model_dimensions : int ,
142
+ azure_openai_service : Union [str , None ],
143
+ azure_openai_custom_url : Union [str , None ],
144
+ azure_openai_deployment : Union [str , None ],
145
+ azure_openai_key : Union [str , None ],
146
+ azure_openai_api_version : str ,
138
147
openai_key : Union [str , None ],
139
148
openai_org : Union [str , None ],
140
149
disable_vectors : bool = False ,
@@ -144,31 +153,83 @@ def setup_embeddings_service(
144
153
logger .info ("Not setting up embeddings service" )
145
154
return None
146
155
147
- if openai_host != "openai" :
156
+ if openai_host in [ OpenAIHost . AZURE , OpenAIHost . AZURE_CUSTOM ] :
148
157
azure_open_ai_credential : Union [AsyncTokenCredential , AzureKeyCredential ] = (
149
- azure_credential if openai_key is None else AzureKeyCredential (openai_key )
158
+ azure_credential if azure_openai_key is None else AzureKeyCredential (azure_openai_key )
150
159
)
151
160
return AzureOpenAIEmbeddingService (
152
- open_ai_service = openai_service ,
153
- open_ai_custom_url = openai_custom_url ,
154
- open_ai_deployment = openai_deployment ,
155
- open_ai_model_name = openai_model_name ,
156
- open_ai_dimensions = openai_dimensions ,
157
- open_ai_api_version = openai_api_version ,
161
+ open_ai_service = azure_openai_service ,
162
+ open_ai_custom_url = azure_openai_custom_url ,
163
+ open_ai_deployment = azure_openai_deployment ,
164
+ open_ai_model_name = emb_model_name ,
165
+ open_ai_dimensions = emb_model_dimensions ,
166
+ open_ai_api_version = azure_openai_api_version ,
158
167
credential = azure_open_ai_credential ,
159
168
disable_batch = disable_batch_vectors ,
160
169
)
161
170
else :
162
171
if openai_key is None :
163
172
raise ValueError ("OpenAI key is required when using the non-Azure OpenAI API" )
164
173
return OpenAIEmbeddingService (
165
- open_ai_model_name = openai_model_name ,
166
- open_ai_dimensions = openai_dimensions ,
174
+ open_ai_model_name = emb_model_name ,
175
+ open_ai_dimensions = emb_model_dimensions ,
167
176
credential = openai_key ,
168
177
organization = openai_org ,
169
178
disable_batch = disable_batch_vectors ,
170
179
)
171
180
181
+ def setup_openai_client (
182
+ openai_host : OpenAIHost ,
183
+ azure_openai_api_key : Union [str , None ] = None ,
184
+ azure_openai_api_version : Union [str , None ] = None ,
185
+ azure_openai_service : Union [str , None ] = None ,
186
+ azure_openai_custom_url : Union [str , None ] = None ,
187
+ azure_credential : AsyncTokenCredential = None ,
188
+ openai_api_key : Union [str , None ] = None ,
189
+ openai_organization : Union [str , None ] = None ,
190
+ ):
191
+ if openai_host not in OpenAIHost :
192
+ raise ValueError (f"Invalid OPENAI_HOST value: { openai_host } . Must be one of { [h .value for h in OpenAIHost ]} ." )
193
+
194
+ if openai_host in [OpenAIHost .AZURE , OpenAIHost .AZURE_CUSTOM ]:
195
+ if openai_host == OpenAIHost .AZURE_CUSTOM :
196
+ logger .info ("OPENAI_HOST is azure_custom, setting up Azure OpenAI custom client" )
197
+ if not azure_openai_custom_url :
198
+ raise ValueError ("AZURE_OPENAI_CUSTOM_URL must be set when OPENAI_HOST is azure_custom" )
199
+ endpoint = azure_openai_custom_url
200
+ else :
201
+ logger .info ("OPENAI_HOST is azure, setting up Azure OpenAI client" )
202
+ if not azure_openai_service :
203
+ raise ValueError ("AZURE_OPENAI_SERVICE must be set when OPENAI_HOST is azure" )
204
+ endpoint = f"https://{ azure_openai_service } .openai.azure.com"
205
+ if azure_openai_api_key :
206
+ logger .info ("AZURE_OPENAI_API_KEY_OVERRIDE found, using as api_key for Azure OpenAI client" )
207
+ openai_client = AsyncAzureOpenAI (
208
+ api_version = azure_openai_api_version , azure_endpoint = endpoint , api_key = azure_openai_api_key
209
+ )
210
+ else :
211
+ logger .info ("Using Azure credential (passwordless authentication) for Azure OpenAI client" )
212
+ token_provider = get_bearer_token_provider (azure_credential , "https://cognitiveservices.azure.com/.default" )
213
+ openai_client = AsyncAzureOpenAI (
214
+ api_version = azure_openai_api_version ,
215
+ azure_endpoint = endpoint ,
216
+ azure_ad_token_provider = token_provider ,
217
+ )
218
+ elif openai_host == OpenAIHost .LOCAL :
219
+ logger .info ("OPENAI_HOST is local, setting up local OpenAI client for OPENAI_BASE_URL with no key" )
220
+ openai_client = AsyncOpenAI (
221
+ base_url = os .environ ["OPENAI_BASE_URL" ],
222
+ api_key = "no-key-required" ,
223
+ )
224
+ else :
225
+ logger .info (
226
+ "OPENAI_HOST is not azure, setting up OpenAI client using OPENAI_API_KEY and OPENAI_ORGANIZATION environment variables"
227
+ )
228
+ openai_client = AsyncOpenAI (
229
+ api_key = openai_api_key ,
230
+ organization = openai_organization ,
231
+ )
232
+ return openai_client
172
233
173
234
def setup_file_processors (
174
235
azure_credential : AsyncTokenCredential ,
@@ -194,7 +255,7 @@ def setup_file_processors(
194
255
doc_int_parser = DocumentAnalysisParser (
195
256
endpoint = f"https://{ document_intelligence_service } .cognitiveservices.azure.com/" ,
196
257
credential = documentintelligence_creds ,
197
- media_description_strategy = "openai" if use_multimodal else "contentunderstanding" if use_content_understanding else "none" ,
258
+ media_description_strategy = MediaDescriptionStrategy . OPENAI if use_multimodal else MediaDescriptionStrategy . CONTENTUNDERSTANDING if use_content_understanding else MediaDescriptionStrategy . NONE ,
198
259
openai_client = openai_client ,
199
260
openai_model = openai_model ,
200
261
openai_deployment = openai_deployment ,
@@ -323,7 +384,7 @@ async def main(strategy: Strategy, setup_index: bool = True):
323
384
args = parser .parse_args ()
324
385
325
386
if args .verbose :
326
- logging .basicConfig (format = "%(message)s" , datefmt = "[%X]" , handlers = [RichHandler (rich_tracebacks = True )])
387
+ logging .basicConfig (format = "%(message)s" , datefmt = "[%X]" , handlers = [RichHandler (rich_tracebacks = True )], level = logging . WARNING )
327
388
# We only set the level to INFO for our logger,
328
389
# to avoid seeing the noisy INFO level logs from the Azure SDKs
329
390
logger .setLevel (logging .DEBUG )
@@ -397,31 +458,38 @@ async def main(strategy: Strategy, setup_index: bool = True):
397
458
datalake_key = clean_key_if_exists (args .datalakekey ),
398
459
)
399
460
400
- openai_host = os .environ ["OPENAI_HOST" ]
401
- openai_key = None
402
- if os .getenv ("AZURE_OPENAI_API_KEY_OVERRIDE" ):
403
- openai_key = os .getenv ("AZURE_OPENAI_API_KEY_OVERRIDE" )
404
- elif not openai_host .startswith ("azure" ) and os .getenv ("OPENAI_API_KEY" ):
405
- openai_key = os .getenv ("OPENAI_API_KEY" )
406
-
407
- openai_dimensions = 1536
461
+ openai_host = OpenAIHost (os .environ ["OPENAI_HOST" ])
462
+ # https://learn.microsoft.com/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
463
+ azure_openai_api_version = os .getenv ("AZURE_OPENAI_API_VERSION" ) or "2024-06-01"
464
+ emb_model_dimensions = 1536
408
465
if os .getenv ("AZURE_OPENAI_EMB_DIMENSIONS" ):
409
- openai_dimensions = int (os .environ ["AZURE_OPENAI_EMB_DIMENSIONS" ])
466
+ emb_model_dimensions = int (os .environ ["AZURE_OPENAI_EMB_DIMENSIONS" ])
410
467
openai_embeddings_service = setup_embeddings_service (
411
468
azure_credential = azd_credential ,
412
469
openai_host = openai_host ,
413
- openai_model_name = os .environ ["AZURE_OPENAI_EMB_MODEL_NAME" ],
414
- openai_service = os . getenv ( "AZURE_OPENAI_SERVICE" ) ,
415
- openai_custom_url = os .getenv ("AZURE_OPENAI_CUSTOM_URL " ),
416
- openai_deployment = os .getenv ("AZURE_OPENAI_EMB_DEPLOYMENT " ),
417
- # https://learn.microsoft.com/azure/ai-services/openai/api-version-deprecation#latest-ga-api-release
418
- openai_api_version = os . getenv ( "AZURE_OPENAI_API_VERSION" ) or "2024-06-01" ,
419
- openai_dimensions = openai_dimensions ,
420
- openai_key = clean_key_if_exists (openai_key ),
470
+ emb_model_name = os .environ ["AZURE_OPENAI_EMB_MODEL_NAME" ],
471
+ emb_model_dimensions = emb_model_dimensions ,
472
+ azure_openai_service = os .getenv ("AZURE_OPENAI_SERVICE " ),
473
+ azure_openai_custom_url = os .getenv ("AZURE_OPENAI_CUSTOM_URL " ),
474
+ azure_openai_deployment = os . getenv ( "AZURE_OPENAI_EMB_DEPLOYMENT" ),
475
+ azure_openai_api_version = azure_openai_api_version ,
476
+ azure_openai_key = os . getenv ( "AZURE_OPENAI_API_KEY_OVERRIDE" ) ,
477
+ openai_key = clean_key_if_exists (os . getenv ( "OPENAI_API_KEY" ) ),
421
478
openai_org = os .getenv ("OPENAI_ORGANIZATION" ),
422
479
disable_vectors = dont_use_vectors ,
423
480
disable_batch_vectors = args .disablebatchvectors ,
424
481
)
482
+ openai_client = setup_openai_client (
483
+ openai_host = openai_host ,
484
+ azure_openai_api_version = azure_openai_api_version ,
485
+ azure_openai_service = os .getenv ("AZURE_OPENAI_SERVICE" ),
486
+ azure_openai_custom_url = os .getenv ("AZURE_OPENAI_CUSTOM_URL" ),
487
+ azure_openai_api_key = os .getenv ("AZURE_OPENAI_API_KEY_OVERRIDE" ),
488
+ azure_credential = azd_credential ,
489
+ openai_api_key = clean_key_if_exists (os .getenv ("OPENAI_API_KEY" )),
490
+ openai_organization = os .getenv ("OPENAI_ORGANIZATION" ),
491
+ )
492
+
425
493
426
494
ingestion_strategy : Strategy
427
495
if use_int_vectorization :
@@ -452,6 +520,9 @@ async def main(strategy: Strategy, setup_index: bool = True):
452
520
use_content_understanding = use_content_understanding ,
453
521
use_multimodal = use_multimodal ,
454
522
content_understanding_endpoint = os .getenv ("AZURE_CONTENTUNDERSTANDING_ENDPOINT" ),
523
+ openai_client = openai_client ,
524
+ openai_model = os .getenv ("AZURE_OPENAI_CHATGPT_MODEL" ),
525
+ openai_deployment = os .getenv ("AZURE_OPENAI_CHATGPT_DEPLOYMENT" ) if openai_host == OpenAIHost .AZURE else None ,
455
526
)
456
527
457
528
image_embeddings_service = setup_image_embeddings_service (
0 commit comments