@@ -169,23 +169,19 @@ def setup_embeddings_service(
169169 logger .info ("Not setting up embeddings service" )
170170 return None
171171
172- azure_endpoint = None
173- azure_deployment = None
174172 if openai_host in [OpenAIHost .AZURE , OpenAIHost .AZURE_CUSTOM ]:
175173 if azure_openai_endpoint is None :
176174 raise ValueError ("Azure OpenAI endpoint must be provided when using Azure OpenAI embeddings" )
177175 if azure_openai_deployment is None :
178176 raise ValueError ("Azure OpenAI deployment must be provided when using Azure OpenAI embeddings" )
179- azure_endpoint = azure_openai_endpoint
180- azure_deployment = azure_openai_deployment
181177
182178 return OpenAIEmbeddings (
183179 open_ai_client = open_ai_client ,
184180 open_ai_model_name = emb_model_name ,
185181 open_ai_dimensions = emb_model_dimensions ,
186182 disable_batch = disable_batch_vectors ,
187- azure_deployment_name = azure_deployment ,
188- azure_endpoint = azure_endpoint ,
183+ azure_deployment_name = azure_openai_deployment ,
184+ azure_endpoint = azure_openai_endpoint ,
189185 )
190186
191187
@@ -197,33 +193,39 @@ def setup_openai_client(
197193 azure_openai_custom_url : Optional [str ] = None ,
198194 openai_api_key : Optional [str ] = None ,
199195 openai_organization : Optional [str ] = None ,
200- ):
196+ ) -> tuple [ AsyncOpenAI , Optional [ str ]] :
201197 if openai_host not in OpenAIHost :
202198 raise ValueError (f"Invalid OPENAI_HOST value: { openai_host } . Must be one of { [h .value for h in OpenAIHost ]} ." )
203199
204200 openai_client : AsyncOpenAI
201+ azure_openai_endpoint : Optional [str ] = None
205202
206203 if openai_host in [OpenAIHost .AZURE , OpenAIHost .AZURE_CUSTOM ]:
204+ base_url : Optional [str ] = None
205+ api_key_or_token : Optional [str | AsyncTokenCredential ] = None
207206 if openai_host == OpenAIHost .AZURE_CUSTOM :
208207 logger .info ("OPENAI_HOST is azure_custom, setting up Azure OpenAI custom client" )
209208 if not azure_openai_custom_url :
210209 raise ValueError ("AZURE_OPENAI_CUSTOM_URL must be set when OPENAI_HOST is azure_custom" )
211- endpoint = azure_openai_custom_url
210+ base_url = azure_openai_custom_url
212211 else :
213212 logger .info ("OPENAI_HOST is azure, setting up Azure OpenAI client" )
214213 if not azure_openai_service :
215214 raise ValueError ("AZURE_OPENAI_SERVICE must be set when OPENAI_HOST is azure" )
216- endpoint = f"https://{ azure_openai_service } .openai.azure.com/openai/v1"
215+ azure_openai_endpoint = "https://{azure_openai_service}.openai.azure.com/"
216+ base_url = f"{ azure_openai_endpoint } /openai/v1"
217217 if azure_openai_api_key :
218218 logger .info ("AZURE_OPENAI_API_KEY_OVERRIDE found, using as api_key for Azure OpenAI client" )
219- openai_client = AsyncOpenAI ( base_url = endpoint , api_key = azure_openai_api_key )
219+ api_key_or_token = azure_openai_api_key
220220 else :
221221 logger .info ("Using Azure credential (passwordless authentication) for Azure OpenAI client" )
222- token_provider = get_bearer_token_provider (azure_credential , "https://cognitiveservices.azure.com/.default" )
223- openai_client = AsyncOpenAI (
224- base_url = endpoint ,
225- api_key = token_provider ,
222+ api_key_or_token = get_bearer_token_provider (
223+ azure_credential , "https://cognitiveservices.azure.com/.default"
226224 )
225+ openai_client = AsyncOpenAI (
226+ base_url = base_url ,
227+ api_key = api_key_or_token ,
228+ )
227229 elif openai_host == OpenAIHost .LOCAL :
228230 logger .info ("OPENAI_HOST is local, setting up local OpenAI client for OPENAI_BASE_URL with no key" )
229231 openai_client = AsyncOpenAI (
@@ -240,7 +242,7 @@ def setup_openai_client(
240242 api_key = openai_api_key ,
241243 organization = openai_organization ,
242244 )
243- return openai_client
245+ return openai_client , azure_openai_endpoint
244246
245247
246248def setup_file_processors (
@@ -349,7 +351,7 @@ async def main(strategy: Strategy, setup_index: bool = True):
349351 await strategy .run ()
350352
351353
352- if __name__ == "__main__" :
354+ if __name__ == "__main__" : # pragma: no cover
353355 parser = argparse .ArgumentParser (
354356 description = "Prepare documents by extracting content from PDFs, splitting content into sections, uploading to blob storage, and indexing in a search index."
355357 )
@@ -500,7 +502,8 @@ async def main(strategy: Strategy, setup_index: bool = True):
500502 emb_model_dimensions = 1536
501503 if os .getenv ("AZURE_OPENAI_EMB_DIMENSIONS" ):
502504 emb_model_dimensions = int (os .environ ["AZURE_OPENAI_EMB_DIMENSIONS" ])
503- openai_client = setup_openai_client (
505+
506+ openai_client , azure_openai_endpoint = setup_openai_client (
504507 openai_host = OPENAI_HOST ,
505508 azure_credential = azd_credential ,
506509 azure_openai_service = os .getenv ("AZURE_OPENAI_SERVICE" ),
@@ -509,17 +512,13 @@ async def main(strategy: Strategy, setup_index: bool = True):
509512 openai_api_key = clean_key_if_exists (os .getenv ("OPENAI_API_KEY" )),
510513 openai_organization = os .getenv ("OPENAI_ORGANIZATION" ),
511514 )
512- azure_embedding_endpoint = os .getenv ("AZURE_OPENAI_ENDPOINT" ) or os .getenv ("AZURE_OPENAI_CUSTOM_URL" )
513- if not azure_embedding_endpoint and OPENAI_HOST == OpenAIHost .AZURE :
514- if service := os .getenv ("AZURE_OPENAI_SERVICE" ):
515- azure_embedding_endpoint = f"https://{ service } .openai.azure.com"
516515 openai_embeddings_service = setup_embeddings_service (
517516 open_ai_client = openai_client ,
518517 openai_host = OPENAI_HOST ,
519518 emb_model_name = os .environ ["AZURE_OPENAI_EMB_MODEL_NAME" ],
520519 emb_model_dimensions = emb_model_dimensions ,
521520 azure_openai_deployment = os .getenv ("AZURE_OPENAI_EMB_DEPLOYMENT" ),
522- azure_openai_endpoint = azure_embedding_endpoint ,
521+ azure_openai_endpoint = azure_openai_endpoint ,
523522 disable_vectors = dont_use_vectors ,
524523 disable_batch_vectors = args .disablebatchvectors ,
525524 )
0 commit comments