Skip to content

Commit 604c181

Browse files
refactor: Replcae OpenAI API calls with SDK API call
2 parents 30271c4 + cdab8dc commit 604c181

File tree

6 files changed

+64
-52
lines changed

6 files changed

+64
-52
lines changed

infra/main.bicep

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,11 @@ module keyvault 'br/public:avm/res/key-vault/vault:0.12.1' = {
11051105
{name: 'AZURE-OPENAI-PREVIEW-API-VERSION', value: azureOpenaiAPIVersion}
11061106
{name: 'AZURE-OPEN-AI-DEPLOYMENT-MODEL', value: gptModelName}
11071107
{name: 'TENANT-ID', value: subscription().tenantId}
1108+
{
1109+
name: 'AZURE-AI-AGENT-ENDPOINT'
1110+
value: aiFoundryAiProjectEndpoint
1111+
}
1112+
11081113
]
11091114
}
11101115
dependsOn:[

infra/scripts/index_scripts/02_process_data.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from azure.keyvault.secrets import SecretClient
2-
from openai import AzureOpenAI
2+
from azure.ai.inference import EmbeddingsClient
33
import re
44
import time
55
import pypdf
66
from io import BytesIO
7+
from urllib.parse import urlparse
78
from azure.search.documents import SearchClient
89
from azure.storage.filedatalake import DataLakeServiceClient
910
from azure.search.documents.indexes import SearchIndexClient
10-
from azure.identity import (AzureCliCredential, get_bearer_token_provider)
11+
from azure.identity import AzureCliCredential
1112

1213

1314
key_vault_name = 'kv_to-be-replaced'
@@ -36,9 +37,7 @@ def get_secrets_from_kv(secret_name: str) -> str:
3637

3738
# Retrieve secrets from Key Vault
3839
search_endpoint = get_secrets_from_kv("AZURE-SEARCH-ENDPOINT")
39-
openai_api_base = get_secrets_from_kv("AZURE-OPENAI-ENDPOINT")
40-
openai_api_version = get_secrets_from_kv("AZURE-OPENAI-PREVIEW-API-VERSION")
41-
deployment = get_secrets_from_kv("AZURE-OPEN-AI-DEPLOYMENT-MODEL")
40+
ai_project_endpoint = get_secrets_from_kv("AZURE-AI-AGENT-ENDPOINT")
4241
account_name = get_secrets_from_kv("ADLS-ACCOUNT-NAME")
4342
print("Secrets retrieved from Key Vault.")
4443

@@ -58,18 +57,19 @@ def get_secrets_from_kv(secret_name: str) -> str:
5857

5958

6059
# Function: Get Embeddings
61-
def get_embeddings(text: str, openai_api_base, openai_api_version):
62-
model_id = "text-embedding-ada-002"
63-
ad_token_provider = get_bearer_token_provider(
64-
credential, "https://cognitiveservices.azure.com/.default"
65-
)
66-
client = AzureOpenAI(
67-
api_version=openai_api_version,
68-
azure_endpoint=openai_api_base,
69-
azure_ad_token_provider=ad_token_provider
60+
def get_embeddings(text: str, ai_project_endpoint: str):
61+
embedding_model = "text-embedding-ada-002"
62+
# Construct inference endpoint with /models path
63+
inference_endpoint = f"https://{urlparse(ai_project_endpoint).netloc}/models"
64+
65+
embeddings_client = EmbeddingsClient(
66+
endpoint=inference_endpoint,
67+
credential=credential,
68+
credential_scopes=["https://cognitiveservices.azure.com/.default"]
7069
)
7170

72-
embedding = client.embeddings.create(input=text, model=model_id).data[0].embedding
71+
response = embeddings_client.embed(model=embedding_model, input=[text])
72+
embedding = response.data[0].embedding
7373
return embedding
7474

7575

@@ -126,12 +126,12 @@ def prepare_search_doc(content, document_id):
126126
chunk_id = f"{document_id}_{str(idx).zfill(2)}"
127127

128128
try:
129-
v_contentVector = get_embeddings(str(chunk), openai_api_base, openai_api_version)
129+
v_contentVector = get_embeddings(str(chunk), ai_project_endpoint)
130130
except Exception as e:
131131
print(f"Error occurred: {e}. Retrying after 30 seconds...")
132132
time.sleep(30)
133133
try:
134-
v_contentVector = get_embeddings(str(chunk), openai_api_base, openai_api_version)
134+
v_contentVector = get_embeddings(str(chunk), ai_project_endpoint)
135135
except Exception as e:
136136
print(f"Retry failed: {e}. Setting v_contentVector to an empty list.")
137137
v_contentVector = []

infra/scripts/index_scripts/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
azure-storage-file-datalake==12.20.0
2-
openai==1.84.0
2+
azure-ai-inference==1.0.0b9
33
pypdf==5.6.0
44
# pyodbc
55
tiktoken==0.9.0

infra/scripts/run_create_index_scripts.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ fi
119119
#Replace key vault name
120120
sed -i "s/kv_to-be-replaced/${keyvaultName}/g" "infra/scripts/index_scripts/01_create_search_index.py"
121121
sed -i "s/kv_to-be-replaced/${keyvaultName}/g" "infra/scripts/index_scripts/02_process_data.py"
122+
sed -i "s/kv_to-be-replaced/${keyvaultName}/g" "scripts/data_utils.py"
122123
if [ -n "$managedIdentityClientId" ]; then
123124
sed -i "s/mici_to-be-replaced/${managedIdentityClientId}/g" "infra/scripts/index_scripts/01_create_search_index.py"
124125
sed -i "s/mici_to-be-replaced/${managedIdentityClientId}/g" "infra/scripts/index_scripts/02_process_data.py"
@@ -181,6 +182,7 @@ fi
181182
# revert the key vault name and managed identity client id in the python files
182183
sed -i "s/${keyvaultName}/kv_to-be-replaced/g" "infra/scripts/index_scripts/01_create_search_index.py"
183184
sed -i "s/${keyvaultName}/kv_to-be-replaced/g" "infra/scripts/index_scripts/02_process_data.py"
185+
sed -i "s/${keyvaultName}/kv_to-be-replaced/g" "scripts/data_utils.py"
184186
if [ -n "$managedIdentityClientId" ]; then
185187
sed -i "s/${managedIdentityClientId}/mici_to-be-replaced/g" "infra/scripts/index_scripts/01_create_search_index.py"
186188
sed -i "s/${managedIdentityClientId}/mici_to-be-replaced/g" "infra/scripts/index_scripts/02_process_data.py"

scripts/data_utils.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,51 @@
1313
from dataclasses import dataclass
1414
from functools import partial
1515
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
16+
from urllib.parse import urlparse
1617

1718
import fitz
1819
import markdown
1920
import requests
2021
import tiktoken
2122
from azure.ai.documentintelligence import DocumentIntelligenceClient
2223
from azure.ai.documentintelligence.models import AnalyzeDocumentRequest
24+
from azure.ai.inference import EmbeddingsClient
2325
from azure.core.credentials import AzureKeyCredential
26+
from azure.identity import AzureCliCredential
27+
from azure.keyvault.secrets import SecretClient
2428
from azure.storage.blob import ContainerClient
2529
from bs4 import BeautifulSoup
2630
from dotenv import load_dotenv
2731
from langchain.text_splitter import (MarkdownTextSplitter,
2832
PythonCodeTextSplitter,
2933
RecursiveCharacterTextSplitter,
3034
TextSplitter)
31-
from openai import AzureOpenAI
3235
from tqdm import tqdm
3336

3437
# Configure environment variables
3538
load_dotenv() # take environment variables from .env.
3639

40+
# Key Vault name - replaced during deployment
41+
key_vault_name = 'kv_to-be-replaced'
42+
43+
44+
def get_secrets_from_kv(secret_name: str) -> str:
45+
"""Retrieves a secret value from Azure Key Vault.
46+
47+
Args:
48+
secret_name: Name of the secret
49+
50+
Returns:
51+
The secret value
52+
"""
53+
kv_credential = AzureCliCredential()
54+
secret_client = SecretClient(
55+
vault_url=f"https://{key_vault_name}.vault.azure.net/",
56+
credential=kv_credential
57+
)
58+
return secret_client.get_secret(secret_name).value
59+
60+
3761
FILE_FORMAT_DICT = {
3862
"md": "markdown",
3963
"txt": "text",
@@ -825,45 +849,27 @@ def get_payload_and_headers_cohere(text, aad_token) -> Tuple[Dict, Dict]:
825849
def get_embedding(
826850
text, embedding_model_endpoint=None, embedding_model_key=None, azure_credential=None
827851
):
828-
endpoint = (
829-
embedding_model_endpoint
830-
if embedding_model_endpoint
831-
else os.environ.get("EMBEDDING_MODEL_ENDPOINT")
832-
)
852+
# Get AI Project endpoint from Key Vault
853+
ai_project_endpoint = get_secrets_from_kv("AZURE-AI-AGENT-ENDPOINT")
833854

834-
FLAG_EMBEDDING_MODEL = os.getenv("FLAG_EMBEDDING_MODEL", "AOAI")
835-
836-
if azure_credential is None and (endpoint is None):
837-
raise Exception(
838-
"EMBEDDING_MODEL_ENDPOINT and EMBEDDING_MODEL_KEY are required for embedding"
839-
)
855+
# Construct inference endpoint: https://aif-xyz.services.ai.azure.com/models
856+
inference_endpoint = f"https://{urlparse(ai_project_endpoint).netloc}/models"
857+
embedding_model = "text-embedding-ada-002"
840858

841859
try:
842-
if FLAG_EMBEDDING_MODEL == "AOAI":
843-
deployment_id = "embedding"
844-
api_version = "2024-02-01"
845-
846-
if azure_credential is not None:
847-
api_key = azure_credential.get_token(
848-
"https://cognitiveservices.azure.com/.default"
849-
).token
850-
else:
851-
api_key = (
852-
embedding_model_key
853-
if embedding_model_key
854-
else os.getenv("AZURE_OPENAI_API_KEY")
855-
)
856-
857-
client = AzureOpenAI(
858-
api_version=api_version, azure_endpoint=endpoint, api_key=api_key
859-
)
860-
embeddings = client.embeddings.create(model=deployment_id, input=text)
860+
credential = azure_credential if azure_credential is not None else AzureCliCredential()
861+
embeddings_client = EmbeddingsClient(
862+
endpoint=inference_endpoint,
863+
credential=credential,
864+
credential_scopes=["https://cognitiveservices.azure.com/.default"]
865+
)
861866

862-
return embeddings.model_dump()["data"][0]["embedding"]
867+
response = embeddings_client.embed(model=embedding_model, input=[text])
868+
return response.data[0].embedding
863869

864870
except Exception as e:
865871
raise Exception(
866-
f"Error getting embeddings with endpoint={endpoint} with error={e}"
872+
f"Error getting embeddings with endpoint={inference_endpoint} with error={e}"
867873
)
868874

869875

src/requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
azure-identity==1.25.0
22
# Flask[async]==2.3.2
3-
openai==2.0.1
43
azure-search-documents==11.7.0b1
54
azure-storage-blob==12.26.0
65
python-dotenv==1.1.1

0 commit comments

Comments
 (0)