import os
from pathlib import Path
from typing import Optional, TypedDict

from dotenv import load_dotenv

# load endpoints, deployment names, and the API version from a local .env file
load_dotenv()

from promptflow.core import Prompty, AzureOpenAIModelConfiguration
from promptflow.tracing import trace
from openai import AzureOpenAI

# <get_documents>
@trace
def get_documents(search_query: str, num_docs: int = 3):
    from azure.identity import DefaultAzureCredential, get_bearer_token_provider
    from azure.search.documents import SearchClient
    from azure.search.documents.models import VectorizedQuery

    token_provider = get_bearer_token_provider(
        DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )

    index_name = os.getenv("AZUREAI_SEARCH_INDEX_NAME")
    # retrieve documents relevant to the user's question from Azure AI Search
    search_client = SearchClient(
        endpoint=os.getenv("AZURE_SEARCH_ENDPOINT"),
        credential=DefaultAzureCredential(),
        index_name=index_name,
    )

    aoai_client = AzureOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        azure_ad_token_provider=token_provider,
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    )

    # generate a vector embedding of the user's question
    embedding = aoai_client.embeddings.create(
        input=search_query, model=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
    )
    embedding_to_query = embedding.data[0].embedding

    context = ""
    # use the vector embedding to do a vector search on the index
    vector_query = VectorizedQuery(
        vector=embedding_to_query, k_nearest_neighbors=num_docs, fields="contentVector"
    )
    # wrap the search call so the query and its results are captured in the trace
    results = trace(search_client.search)(
        search_text="", vector_queries=[vector_query], select=["id", "content"]
    )

    for result in results:
        context += f"\n>>> From: {result['id']}\n{result['content']}"

    return context


# </get_documents>
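# Example of calling the retriever on its own (the query string is illustrative,
# and assumes the AZURE_* settings referenced above are present in .env):
#   context = get_documents("waterproof tents", num_docs=3)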


class ChatResponse(TypedDict):
    context: str  # concatenated document snippets used to ground the reply
    reply: str


def get_chat_response(
    chat_input: str, chat_history: Optional[list] = None
) -> ChatResponse:
    # avoid the mutable-default-argument pitfall; treat a missing history as empty
    chat_history = chat_history or []

    model_config = AzureOpenAIModelConfiguration(
        azure_deployment=os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT"),
        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    )

    search_query = chat_input

    # only extract intent when there is chat history to condition on
    if chat_history:
        # extract the current query intent, given the chat_history
        path_to_prompty = f"{Path(__file__).parent.absolute().as_posix()}/queryIntent.prompty"  # pass an absolute file path to prompty
        intent_prompty = Prompty.load(
            path_to_prompty,
            model={
                "configuration": model_config,
                "parameters": {
                    "max_tokens": 256,
                },
            },
        )
        search_query = intent_prompty(query=chat_input, chat_history=chat_history)

    # retrieve documents and context relevant to the chat_history and current user query (chat_input)
    documents = get_documents(search_query, 3)

    # send the query plus document context to chat completion for a grounded response
    path_to_prompty = f"{Path(__file__).parent.absolute().as_posix()}/chat.prompty"
    chat_prompty = Prompty.load(
        path_to_prompty,
        model={
            "configuration": model_config,
            "parameters": {"max_tokens": 256, "temperature": 0.2},
        },
    )
    result = chat_prompty(
        chat_history=chat_history, chat_input=chat_input, documents=documents
    )

    return ChatResponse(reply=result, context=documents)
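

# A minimal sketch of running the flow end to end, assuming the .env file above
# is populated; the questions and the role/content shape of chat_history (which
# is what queryIntent.prompty and chat.prompty consume) are illustrative
# assumptions, not part of the flow itself.
if __name__ == "__main__":
    first_question = "Which tents can fit four people?"
    response = get_chat_response(first_question)
    print(response["reply"])

    # carry the first turn forward so the intent prompty can rewrite the follow-up
    chat_history = [
        {"role": "user", "content": first_question},
        {"role": "assistant", "content": response["reply"]},
    ]
    follow_up = get_chat_response("Which of those is the lightest?", chat_history)
    print(follow_up["reply"])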