@@ -743,16 +743,11 @@ def get_embedding(text, embedding_model_endpoint=None, embedding_model_key=None,
743743 FLAG_COHERE = os .getenv ("FLAG_COHERE" , "ENGLISH" )
744744 FLAG_AOAI = os .getenv ("FLAG_AOAI" , "V3" )
745745
746- if azure_credential is None and (endpoint is None or key is None ):
746+ if azure_credential is None and (endpoint is None ):
747747 raise Exception ("EMBEDDING_MODEL_ENDPOINT and EMBEDDING_MODEL_KEY are required for embedding" )
748748
749749 try :
750750 if FLAG_EMBEDDING_MODEL == "AOAI" :
751- # endpoint_parts = endpoint.split("/openai/deployments/")
752- # base_url = endpoint_parts[0]
753- # deployment_id = endpoint_parts[1].split("/embeddings")[0]
754- # api_version = endpoint_parts[1].split("api-version=")[1].split("&")[0]
755-
756751 deployment_id = "embedding"
757752 api_version = "2024-02-01"
758753
@@ -761,7 +756,7 @@ def get_embedding(text, embedding_model_endpoint=None, embedding_model_key=None,
761756 else :
762757 api_key = embedding_model_key if embedding_model_key else os .getenv ("AZURE_OPENAI_API_KEY" )
763758
764- client = AzureOpenAI (api_version = api_version , azure_endpoint = "https://cog-generic-accelerator-dev.openai.azure.com/" , api_key = api_key )
759+ client = AzureOpenAI (api_version = api_version , azure_endpoint = endpoint , api_key = api_key )
765760 embeddings = client .embeddings .create (model = deployment_id , input = text )
766761
767762 return embeddings .model_dump ()['data' ][0 ]['embedding' ]
0 commit comments