
Commit 85791db

Vector search (#424)
1 parent 6bfb2cc commit 85791db

18 files changed (+297 −92 lines)

README.md

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ It will look like the following:
 1. Run `azd env set AZURE_OPENAI_RESOURCE_GROUP {Name of existing resource group that OpenAI service is provisioned to}`
 1. Run `azd env set AZURE_OPENAI_CHATGPT_DEPLOYMENT {Name of existing ChatGPT deployment}`. Only needed if your ChatGPT deployment is not the default 'chat'.
 1. Run `azd env set AZURE_OPENAI_GPT_DEPLOYMENT {Name of existing GPT deployment}`. Only needed if your GPT deployment is not the default 'davinci'.
+1. Run `azd env set AZURE_OPENAI_EMB_DEPLOYMENT {Name of existing embedding deployment}`. Only needed if your embeddings deployment is not the default 'embedding'.
 1. Run `azd up`

 > NOTE: You can also use existing Search and Storage Accounts. See `./infra/main.parameters.json` for a list of environment variables to pass to `azd env set` to configure those existing resources.

app/backend/app.py

Lines changed: 11 additions & 5 deletions

@@ -22,6 +22,7 @@
 AZURE_OPENAI_GPT_DEPLOYMENT = os.environ.get("AZURE_OPENAI_GPT_DEPLOYMENT") or "davinci"
 AZURE_OPENAI_CHATGPT_DEPLOYMENT = os.environ.get("AZURE_OPENAI_CHATGPT_DEPLOYMENT") or "chat"
 AZURE_OPENAI_CHATGPT_MODEL = os.environ.get("AZURE_OPENAI_CHATGPT_MODEL") or "gpt-35-turbo"
+AZURE_OPENAI_EMB_DEPLOYMENT = os.environ.get("AZURE_OPENAI_EMB_DEPLOYMENT") or "embedding"

 KB_FIELDS_CONTENT = os.environ.get("KB_FIELDS_CONTENT") or "content"
 KB_FIELDS_CATEGORY = os.environ.get("KB_FIELDS_CATEGORY") or "category"
@@ -31,7 +32,7 @@
 # just use 'az login' locally, and managed identity when deployed on Azure). If you need to use keys, use separate AzureKeyCredential instances with the
 # keys for each service
 # If you encounter a blocking error during a DefaultAzureCredential resolution, you can exclude the problematic credential by using a parameter (ex. exclude_shared_token_cache_credential=True)
-azure_credential = DefaultAzureCredential()
+azure_credential = DefaultAzureCredential(exclude_shared_token_cache_credential=True)

 # Used by the OpenAI SDK
 openai.api_type = "azure"
@@ -56,13 +57,18 @@
 # Various approaches to integrate GPT and external knowledge, most applications will use a single one of these patterns
 # or some derivative, here we include several for exploration purposes
 ask_approaches = {
-    "rtr": RetrieveThenReadApproach(search_client, AZURE_OPENAI_CHATGPT_DEPLOYMENT, AZURE_OPENAI_CHATGPT_MODEL, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT),
-    "rrr": ReadRetrieveReadApproach(search_client, AZURE_OPENAI_GPT_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT),
-    "rda": ReadDecomposeAsk(search_client, AZURE_OPENAI_GPT_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT)
+    "rtr": RetrieveThenReadApproach(search_client, AZURE_OPENAI_CHATGPT_DEPLOYMENT, AZURE_OPENAI_CHATGPT_MODEL, AZURE_OPENAI_EMB_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT),
+    "rrr": ReadRetrieveReadApproach(search_client, AZURE_OPENAI_GPT_DEPLOYMENT, AZURE_OPENAI_EMB_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT),
+    "rda": ReadDecomposeAsk(search_client, AZURE_OPENAI_GPT_DEPLOYMENT, AZURE_OPENAI_EMB_DEPLOYMENT, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT)
 }

 chat_approaches = {
-    "rrr": ChatReadRetrieveReadApproach(search_client, AZURE_OPENAI_CHATGPT_DEPLOYMENT, AZURE_OPENAI_CHATGPT_MODEL, KB_FIELDS_SOURCEPAGE, KB_FIELDS_CONTENT)
+    "rrr": ChatReadRetrieveReadApproach(search_client,
+                                        AZURE_OPENAI_CHATGPT_DEPLOYMENT,
+                                        AZURE_OPENAI_CHATGPT_MODEL,
+                                        AZURE_OPENAI_EMB_DEPLOYMENT,
+                                        KB_FIELDS_SOURCEPAGE,
+                                        KB_FIELDS_CONTENT)
 }

 app = Flask(__name__)
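
A minimal sketch (not code from this commit) of how an approach turns the new `AZURE_OPENAI_EMB_DEPLOYMENT` setting into a query embedding; it assumes the `openai` module has already been configured for Azure as in `app.py`, and `embed_query` is a hypothetical helper name:

import os

import openai

# Same default as app.py: this is the Azure OpenAI *deployment* name, not a model name
AZURE_OPENAI_EMB_DEPLOYMENT = os.environ.get("AZURE_OPENAI_EMB_DEPLOYMENT") or "embedding"

def embed_query(query_text: str) -> list:
    # Azure OpenAI deployments are addressed via the `engine` parameter; the
    # response carries one embedding (a list of floats) per input string
    response = openai.Embedding.create(engine=AZURE_OPENAI_EMB_DEPLOYMENT, input=query_text)
    return response["data"][0]["embedding"]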

app/backend/approaches/chatreadretrieveread.py

Lines changed: 45 additions & 24 deletions

@@ -23,24 +23,23 @@ class ChatReadRetrieveReadApproach(Approach):
     """
     system_message_chat_conversation = """Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers.
 Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.
-For tabular information return it as an html table. Do not return markdown format.
+For tabular information return it as an html table. Do not return markdown format. If the question is not in English, answer in the language used in the question.
 Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, e.g. [info1.txt]. Don't combine sources, list each source separately, e.g. [info1.txt][info2.pdf].
 {follow_up_questions_prompt}
 {injected_prompt}
 """
     follow_up_questions_prompt_content = """Generate three very brief follow-up questions that the user would likely ask next about their healthcare plan and employee handbook.
-Use double angle brackets to reference the questions, e.g. <<Are there exclusions for prescriptions?>>.
-Try not to repeat questions that have already been asked.
-Only generate questions and do not generate any text before or after the questions, such as 'Next Questions'"""
+    Use double angle brackets to reference the questions, e.g. <<Are there exclusions for prescriptions?>>.
+    Try not to repeat questions that have already been asked.
+    Only generate questions and do not generate any text before or after the questions, such as 'Next Questions'"""

     query_prompt_template = """Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching in a knowledge base about employee healthcare plans and the employee handbook.
-Generate a search query based on the conversation and the new question.
-Do not include cited source filenames and document names e.g info.txt or doc.pdf in the search query terms.
-Do not include any text inside [] or <<>> in the search query terms.
-Do not include any special characters like '+'.
-If the question is not in English, translate the question to English before generating the search query.
-
-Search Query:
+    Generate a search query based on the conversation and the new question.
+    Do not include cited source filenames and document names e.g info.txt or doc.pdf in the search query terms.
+    Do not include any text inside [] or <<>> in the search query terms.
+    Do not include any special characters like '+'.
+    If the question is not in English, translate the question to English before generating the search query.
+    If you cannot generate a search query, return just the number 0.
 """
     query_prompt_few_shots = [
         {'role' : USER, 'content' : 'What are my health plans?' },
@@ -49,16 +48,19 @@ class ChatReadRetrieveReadApproach(Approach):
         {'role' : ASSISTANT, 'content' : 'Health plan cardio coverage' }
     ]

-    def __init__(self, search_client: SearchClient, chatgpt_deployment: str, chatgpt_model: str, sourcepage_field: str, content_field: str):
+    def __init__(self, search_client: SearchClient, chatgpt_deployment: str, chatgpt_model: str, embedding_deployment: str, sourcepage_field: str, content_field: str):
         self.search_client = search_client
         self.chatgpt_deployment = chatgpt_deployment
         self.chatgpt_model = chatgpt_model
+        self.embedding_deployment = embedding_deployment
         self.sourcepage_field = sourcepage_field
         self.content_field = content_field
         self.chatgpt_token_limit = get_token_limit(chatgpt_model)

     def run(self, history: Sequence[dict[str, str]], overrides: dict[str, Any]) -> Any:
-        use_semantic_captions = True if overrides.get("semantic_captions") else False
+        has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
+        has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
+        use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
         top = overrides.get("top") or 3
         exclude_category = overrides.get("exclude_category") or None
         filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None
@@ -83,20 +85,42 @@ def run(self, history: Sequence[dict[str, str]], overrides: dict[str, Any]) -> Any:
             max_tokens=32,
             n=1)

-        q = chat_completion.choices[0].message.content
+        query_text = chat_completion.choices[0].message.content
+        if query_text.strip() == "0":
+            query_text = history[-1]["user"] # Use the last user input if we failed to generate a better query

         # STEP 2: Retrieve relevant documents from the search index with the GPT optimized query
-        if overrides.get("semantic_ranker"):
-            r = self.search_client.search(q,
+
+        # If retrieval mode includes vectors, compute an embedding for the query
+        if has_vector:
+            query_vector = openai.Embedding.create(engine=self.embedding_deployment, input=query_text)["data"][0]["embedding"]
+        else:
+            query_vector = None
+
+        # Only keep the text query if the retrieval mode uses text, otherwise drop it
+        if not has_text:
+            query_text = None
+
+        # Use semantic L2 reranker if requested and if retrieval mode is text or hybrid (vectors + text)
+        if overrides.get("semantic_ranker") and has_text:
+            r = self.search_client.search(query_text,
                                           filter=filter,
                                           query_type=QueryType.SEMANTIC,
                                           query_language="en-us",
                                           query_speller="lexicon",
                                           semantic_configuration_name="default",
                                           top=top,
-                                          query_caption="extractive|highlight-false" if use_semantic_captions else None)
+                                          query_caption="extractive|highlight-false" if use_semantic_captions else None,
+                                          vector=query_vector,
+                                          top_k=50 if query_vector else None,
+                                          vector_fields="embedding" if query_vector else None)
         else:
-            r = self.search_client.search(q, filter=filter, top=top)
+            r = self.search_client.search(query_text,
+                                          filter=filter,
+                                          top=top,
+                                          vector=query_vector,
+                                          top_k=50 if query_vector else None,
+                                          vector_fields="embedding" if query_vector else None)
         if use_semantic_captions:
             results = [doc[self.sourcepage_field] + ": " + nonewlines(" . ".join([c.text for c in doc['@search.captions']])) for doc in r]
         else:
@@ -116,14 +140,11 @@ def run(self, history: Sequence[dict[str, str]], overrides: dict[str, Any]) -> Any:
         else:
             system_message = prompt_override.format(follow_up_questions_prompt=follow_up_questions_prompt)

-        # latest conversation
-        user_content = history[-1]["user"] + " \nSources:" + content
-
         messages = self.get_messages_from_history(
-            system_message,
+            system_message + "\n\nSources:\n" + content,
             self.chatgpt_model,
             history,
-            user_content,
+            history[-1]["user"],
             max_tokens=self.chatgpt_token_limit)

         chat_completion = openai.ChatCompletion.create(
@@ -138,7 +159,7 @@ def run(self, history: Sequence[dict[str, str]], overrides: dict[str, Any]) -> Any:

         msg_to_display = '\n\n'.join([str(message) for message in messages])

-        return {"data_points": results, "answer": chat_content, "thoughts": f"Searched for:<br>{q}<br><br>Conversations:<br>" + msg_to_display.replace('\n', '<br>')}
+        return {"data_points": results, "answer": chat_content, "thoughts": f"Searched for:<br>{query_text}<br><br>Conversations:<br>" + msg_to_display.replace('\n', '<br>')}

     def get_messages_from_history(self, system_prompt: str, model_id: str, history: Sequence[dict[str, str]], user_conv: str, few_shots = [], max_tokens: int = 4096) -> []:
         message_builder = MessageBuilder(system_prompt, model_id)
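
The `retrieval_mode` override introduced above drives both search branches. A self-contained sketch of that contract (an illustration mirroring the checks in `run`, not code from the commit); note that an absent override, i.e. `None`, behaves like hybrid, so existing callers keep getting text retrieval plus vectors:

def retrieval_flags(overrides: dict) -> tuple:
    # Mirrors the checks in run(): None (no override) falls through to hybrid
    mode = overrides.get("retrieval_mode")
    has_text = mode in ["text", "hybrid", None]
    has_vector = mode in ["vectors", "hybrid", None]
    return has_text, has_vector

assert retrieval_flags({}) == (True, True)                              # default: hybrid
assert retrieval_flags({"retrieval_mode": "text"}) == (True, False)     # keyword-only
assert retrieval_flags({"retrieval_mode": "vectors"}) == (False, True)  # vector-only
assert retrieval_flags({"retrieval_mode": "hybrid"}) == (True, True)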

app/backend/approaches/readdecomposeask.py

Lines changed: 29 additions & 8 deletions

@@ -13,29 +13,50 @@
 from typing import Any, List, Optional

 class ReadDecomposeAsk(Approach):
-    def __init__(self, search_client: SearchClient, openai_deployment: str, sourcepage_field: str, content_field: str):
+    def __init__(self, search_client: SearchClient, openai_deployment: str, embedding_deployment: str, sourcepage_field: str, content_field: str):
         self.search_client = search_client
         self.openai_deployment = openai_deployment
+        self.embedding_deployment = embedding_deployment
         self.sourcepage_field = sourcepage_field
         self.content_field = content_field

-    def search(self, q: str, overrides: dict[str, Any]) -> str:
-        use_semantic_captions = True if overrides.get("semantic_captions") else False
+    def search(self, query_text: str, overrides: dict[str, Any]) -> str:
+        has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
+        has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
+        use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
         top = overrides.get("top") or 3
         exclude_category = overrides.get("exclude_category") or None
         filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None

-        if overrides.get("semantic_ranker"):
-            r = self.search_client.search(q,
+        # If retrieval mode includes vectors, compute an embedding for the query
+        if has_vector:
+            query_vector = openai.Embedding.create(engine=self.embedding_deployment, input=query_text)["data"][0]["embedding"]
+        else:
+            query_vector = None
+
+        # Only keep the text query if the retrieval mode uses text, otherwise drop it
+        if not has_text:
+            query_text = None
+
+        if overrides.get("semantic_ranker") and has_text:
+            r = self.search_client.search(query_text,
                                           filter=filter,
                                           query_type=QueryType.SEMANTIC,
                                           query_language="en-us",
                                           query_speller="lexicon",
                                           semantic_configuration_name="default",
-                                          top = top,
-                                          query_caption="extractive|highlight-false" if use_semantic_captions else None)
+                                          top=top,
+                                          query_caption="extractive|highlight-false" if use_semantic_captions else None,
+                                          vector=query_vector,
+                                          top_k=50 if query_vector else None,
+                                          vector_fields="embedding" if query_vector else None)
         else:
-            r = self.search_client.search(q, filter=filter, top=top)
+            r = self.search_client.search(query_text,
+                                          filter=filter,
+                                          top=top,
+                                          vector=query_vector,
+                                          top_k=50 if query_vector else None,
+                                          vector_fields="embedding" if query_vector else None)
         if use_semantic_captions:
             self.results = [doc[self.sourcepage_field] + ":" + nonewlines(" . ".join([c.text for c in doc['@search.captions'] ])) for doc in r]
         else:

app/backend/approaches/readretrieveread.py

Lines changed: 29 additions & 7 deletions

@@ -44,29 +44,51 @@ class ReadRetrieveReadApproach(Approach):

     CognitiveSearchToolDescription = "useful for searching the Microsoft employee benefits information such as healthcare plans, retirement plans, etc."

-    def __init__(self, search_client: SearchClient, openai_deployment: str, sourcepage_field: str, content_field: str):
+    def __init__(self, search_client: SearchClient, openai_deployment: str, embedding_deployment: str, sourcepage_field: str, content_field: str):
         self.search_client = search_client
         self.openai_deployment = openai_deployment
+        self.embedding_deployment = embedding_deployment
         self.sourcepage_field = sourcepage_field
         self.content_field = content_field

-    def retrieve(self, q: str, overrides: dict[str, Any]) -> Any:
-        use_semantic_captions = True if overrides.get("semantic_captions") else False
+    def retrieve(self, query_text: str, overrides: dict[str, Any]) -> Any:
+        has_text = overrides.get("retrieval_mode") in ["text", "hybrid", None]
+        has_vector = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
+        use_semantic_captions = True if overrides.get("semantic_captions") and has_text else False
         top = overrides.get("top") or 3
         exclude_category = overrides.get("exclude_category") or None
         filter = "category ne '{}'".format(exclude_category.replace("'", "''")) if exclude_category else None

-        if overrides.get("semantic_ranker"):
-            r = self.search_client.search(q,
+        # If retrieval mode includes vectors, compute an embedding for the query
+        if has_vector:
+            query_vector = openai.Embedding.create(engine=self.embedding_deployment, input=query_text)["data"][0]["embedding"]
+        else:
+            query_vector = None
+
+        # Only keep the text query if the retrieval mode uses text, otherwise drop it
+        if not has_text:
+            query_text = None
+
+        # Use semantic ranker if requested and if retrieval mode is text or hybrid (vectors + text)
+        if overrides.get("semantic_ranker") and has_text:
+            r = self.search_client.search(query_text,
                                           filter=filter,
                                           query_type=QueryType.SEMANTIC,
                                           query_language="en-us",
                                           query_speller="lexicon",
                                           semantic_configuration_name="default",
                                           top = top,
-                                          query_caption="extractive|highlight-false" if use_semantic_captions else None)
+                                          query_caption="extractive|highlight-false" if use_semantic_captions else None,
+                                          vector=query_vector,
+                                          top_k=50 if query_vector else None,
+                                          vector_fields="embedding" if query_vector else None)
         else:
-            r = self.search_client.search(q, filter=filter, top=top)
+            r = self.search_client.search(query_text,
+                                          filter=filter,
+                                          top=top,
+                                          vector=query_vector,
+                                          top_k=50 if query_vector else None,
+                                          vector_fields="embedding" if query_vector else None)
         if use_semantic_captions:
             self.results = [doc[self.sourcepage_field] + ":" + nonewlines(" -.- ".join([c.text for c in doc['@search.captions']])) for doc in r]
         else:
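
All three approaches now build the same search request. A condensed sketch of that shared argument shape (an assumption about intent, not code from the commit; `build_search_kwargs` is a hypothetical name):

from typing import Optional

def build_search_kwargs(query_text: Optional[str], query_vector: Optional[list],
                        filter: Optional[str], top: int) -> dict:
    # Text-only: query_vector is None, so the vector arguments stay unset (None).
    # Vector-only: query_text is None and the keyword stage is skipped.
    # Hybrid: both are set; top_k=50 vector candidates feed the merged, top-capped results.
    return {
        "search_text": query_text,
        "filter": filter,
        "top": top,
        "vector": query_vector,
        "top_k": 50 if query_vector else None,
        "vector_fields": "embedding" if query_vector else None,
    }

print(build_search_kwargs("health plan", None, None, 3))        # text-only request
print(build_search_kwargs(None, [0.1, 0.2, 0.3], None, 3))      # vector-only request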
