|
13 | 13 | from dataclasses import dataclass |
14 | 14 | from functools import partial |
15 | 15 | from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union |
| 16 | +from urllib.parse import urlparse |
16 | 17 |
|
17 | 18 | import fitz |
18 | 19 | import markdown |
19 | 20 | import requests |
20 | 21 | import tiktoken |
21 | 22 | from azure.ai.documentintelligence import DocumentIntelligenceClient |
22 | 23 | from azure.ai.documentintelligence.models import AnalyzeDocumentRequest |
| 24 | +from azure.ai.inference import EmbeddingsClient |
23 | 25 | from azure.core.credentials import AzureKeyCredential |
| 26 | +from azure.identity import AzureCliCredential |
| 27 | +from azure.keyvault.secrets import SecretClient |
24 | 28 | from azure.storage.blob import ContainerClient |
25 | 29 | from bs4 import BeautifulSoup |
26 | 30 | from dotenv import load_dotenv |
27 | 31 | from langchain.text_splitter import (MarkdownTextSplitter, |
28 | 32 | PythonCodeTextSplitter, |
29 | 33 | RecursiveCharacterTextSplitter, |
30 | 34 | TextSplitter) |
31 | | -from openai import AzureOpenAI |
32 | 35 | from tqdm import tqdm |
33 | 36 |
|
34 | 37 | # Configure environment variables |
35 | 38 | load_dotenv() # take environment variables from .env. |
36 | 39 |
|
| 40 | +# Key Vault name - replaced during deployment |
| 41 | +key_vault_name = 'kv_to-be-replaced' |
| 42 | + |
| 43 | + |
def get_secrets_from_kv(secret_name: str) -> str:
    """Retrieve a secret value from Azure Key Vault.

    The vault URL is built from the module-level ``key_vault_name`` and the
    request is authenticated with ``AzureCliCredential``.

    Args:
        secret_name: Name of the secret.

    Returns:
        The secret value.

    Raises:
        ValueError: If ``secret_name`` is empty, or if the retrieved secret
            has no value.
    """
    # Fail fast before creating a credential or making a network call.
    if not secret_name:
        raise ValueError("secret_name must be a non-empty string")

    kv_credential = AzureCliCredential()
    secret_client = SecretClient(
        vault_url=f"https://{key_vault_name}.vault.azure.net/",
        credential=kv_credential
    )
    secret_value = secret_client.get_secret(secret_name).value

    # The SDK exposes ``value`` as optional; raise here so callers never
    # receive None from a function annotated as returning str.
    if secret_value is None:
        raise ValueError(f"Secret '{secret_name}' has no value in Key Vault")
    return secret_value
| 59 | + |
| 60 | + |
37 | 61 | FILE_FORMAT_DICT = { |
38 | 62 | "md": "markdown", |
39 | 63 | "txt": "text", |
@@ -825,45 +849,27 @@ def get_payload_and_headers_cohere(text, aad_token) -> Tuple[Dict, Dict]: |
def get_embedding(
    text, embedding_model_endpoint=None, embedding_model_key=None, azure_credential=None
):
    """Return the embedding for ``text`` from the AI project's inference endpoint.

    The AI project endpoint is read from the ``AZURE-AI-AGENT-ENDPOINT`` Key
    Vault secret; its host is reused to build the ``/models`` inference
    endpoint, and the ``text-embedding-ada-002`` model is requested.

    Args:
        text: Input to embed; sent to the service as a single-item list.
        embedding_model_endpoint: Not used by this implementation; kept so
            existing callers that pass it keep working.
        embedding_model_key: Not used by this implementation; kept so
            existing callers that pass it keep working.
        azure_credential: Credential handed to the embeddings client. When
            None, an ``AzureCliCredential`` is created.

    Returns:
        The ``embedding`` field of the first item in the response.

    Raises:
        ValueError: If no host can be extracted from the endpoint secret.
        Exception: If creating the client or requesting the embedding fails;
            the original error is chained as the cause.
    """
    # Get AI Project endpoint from Key Vault
    ai_project_endpoint = get_secrets_from_kv("AZURE-AI-AGENT-ENDPOINT")

    # urlparse only fills netloc when the value has a scheme ("https://...");
    # without this check a malformed secret would build "https:///models".
    host = urlparse(ai_project_endpoint).netloc
    if not host:
        raise ValueError(
            "AZURE-AI-AGENT-ENDPOINT must be an absolute URL that includes a host"
        )

    # Construct inference endpoint: https://aif-xyz.services.ai.azure.com/models
    inference_endpoint = f"https://{host}/models"
    embedding_model = "text-embedding-ada-002"

    try:
        credential = azure_credential if azure_credential is not None else AzureCliCredential()
        embeddings_client = EmbeddingsClient(
            endpoint=inference_endpoint,
            credential=credential,
            credential_scopes=["https://cognitiveservices.azure.com/.default"]
        )

        response = embeddings_client.embed(model=embedding_model, input=[text])
        return response.data[0].embedding

    except Exception as e:
        # Chain explicitly so the underlying SDK error stays in the traceback.
        raise Exception(
            f"Error getting embeddings with endpoint={inference_endpoint} with error={e}"
        ) from e
868 | 874 |
|
869 | 875 |
|
|
0 commit comments