Skip to content

Commit 15af3a8

Browse files
authored
Improve uses of get, fix temperature bug (#1225)
* Improve uses of get, fix temperature bug
* Mention temperature in comments/docs
1 parent 270d869 commit 15af3a8

File tree

9 files changed

+20
-20
lines changed

9 files changed

+20
-20
lines changed

app/backend/approaches/approach.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(
9696
self.openai_host = openai_host
9797

9898
def build_filter(self, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> Optional[str]:
99-
exclude_category = overrides.get("exclude_category") or None
99+
exclude_category = overrides.get("exclude_category")
100100
security_filter = self.auth_helper.build_security_filters(overrides, auth_claims)
101101
filters = []
102102
if exclude_category:

app/backend/approaches/chatreadretrieveread.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ async def run_until_final_call(
129129
messages=messages, # type: ignore
130130
# Azure Open AI takes the deployment name as the model name
131131
model=self.chatgpt_deployment if self.chatgpt_deployment else self.chatgpt_model,
132-
temperature=0.0,
132+
temperature=0.0, # Minimize creativity for search query generation
133133
max_tokens=100, # Setting too low risks malformed JSON, setting too high may affect performance
134134
n=1,
135135
tools=tools,
@@ -196,7 +196,7 @@ async def run_until_final_call(
196196
# Azure Open AI takes the deployment name as the model name
197197
model=self.chatgpt_deployment if self.chatgpt_deployment else self.chatgpt_model,
198198
messages=messages,
199-
temperature=overrides.get("temperature") or 0.7,
199+
temperature=overrides.get("temperature", 0.7),
200200
max_tokens=response_token_limit,
201201
n=1,
202202
stream=should_stream,

app/backend/approaches/chatreadretrievereadvision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async def run_until_final_call(
110110
chat_completion: ChatCompletion = await self.openai_client.chat.completions.create(
111111
model=self.gpt4v_deployment if self.gpt4v_deployment else self.gpt4v_model,
112112
messages=messages,
113-
temperature=overrides.get("temperature") or 0.0,
113+
temperature=0.0, # Minimize creativity for search query generation
114114
max_tokens=100,
115115
n=1,
116116
)
@@ -194,7 +194,7 @@ async def run_until_final_call(
194194
chat_coroutine = self.openai_client.chat.completions.create(
195195
model=self.gpt4v_deployment if self.gpt4v_deployment else self.gpt4v_model,
196196
messages=messages,
197-
temperature=overrides.get("temperature") or 0.7,
197+
temperature=overrides.get("temperature", 0.7),
198198
max_tokens=response_token_limit,
199199
n=1,
200200
stream=should_stream,

app/backend/approaches/retrievethenread.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def run(
9999

100100
user_content = [q]
101101

102-
template = overrides.get("prompt_template") or self.system_chat_template
102+
template = overrides.get("prompt_template", self.system_chat_template)
103103
model = self.chatgpt_model
104104
message_builder = MessageBuilder(template, model)
105105

@@ -118,7 +118,7 @@ async def run(
118118
# Azure Open AI takes the deployment name as the model name
119119
model=self.chatgpt_deployment if self.chatgpt_deployment else self.chatgpt_model,
120120
messages=message_builder.messages,
121-
temperature=overrides.get("temperature") or 0.3,
121+
temperature=overrides.get("temperature", 0.3),
122122
max_tokens=1024,
123123
n=1,
124124
)

app/backend/approaches/retrievethenreadvision.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ async def run(
112112
image_list: list[ChatCompletionContentPartImageParam] = []
113113
user_content: list[ChatCompletionContentPartParam] = [{"text": q, "type": "text"}]
114114

115-
template = overrides.get("prompt_template") or (self.system_chat_template_gpt4v)
115+
template = overrides.get("prompt_template", self.system_chat_template_gpt4v)
116116
model = self.gpt4v_model
117117
message_builder = MessageBuilder(template, model)
118118

@@ -137,7 +137,7 @@ async def run(
137137
await self.openai_client.chat.completions.create(
138138
model=self.gpt4v_deployment if self.gpt4v_deployment else self.gpt4v_model,
139139
messages=message_builder.messages,
140-
temperature=overrides.get("temperature") or 0.3,
140+
temperature=overrides.get("temperature", 0.3),
141141
max_tokens=1024,
142142
n=1,
143143
)

app/backend/core/authentication.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def get_auth_setup_for_client(self) -> dict[str, Any]:
106106
@staticmethod
107107
def get_token_auth_header(headers: dict) -> str:
108108
# Obtains the Access Token from the Authorization Header
109-
auth = headers.get("Authorization", None)
109+
auth = headers.get("Authorization")
110110
if auth:
111111
parts = auth.split()
112112

@@ -122,7 +122,7 @@ def get_token_auth_header(headers: dict) -> str:
122122

123123
# App services built-in authentication passes the access token directly as a header
124124
# To learn more, please visit https://learn.microsoft.com/azure/app-service/configure-authentication-oauth-tokens
125-
token = headers.get("x-ms-token-aad-access-token", None)
125+
token = headers.get("x-ms-token-aad-access-token")
126126
if token:
127127
return token
128128

@@ -141,10 +141,10 @@ def build_security_filters(self, overrides: dict[str, Any], auth_claims: dict[st
141141
)
142142

143143
oid_security_filter = (
144-
"oids/any(g:search.in(g, '{}'))".format(auth_claims.get("oid") or "") if use_oid_security_filter else None
144+
"oids/any(g:search.in(g, '{}'))".format(auth_claims.get("oid", "")) if use_oid_security_filter else None
145145
)
146146
groups_security_filter = (
147-
"groups/any(g:search.in(g, '{}'))".format(", ".join(auth_claims.get("groups") or []))
147+
"groups/any(g:search.in(g, '{}'))".format(", ".join(auth_claims.get("groups", [])))
148148
if use_groups_security_filter
149149
else None
150150
)
@@ -212,7 +212,7 @@ async def get_auth_claims_if_enabled(self, headers: dict) -> dict[str, Any]:
212212
# Read the claims from the response. The oid and groups claims are used for security filtering
213213
# https://learn.microsoft.com/azure/active-directory/develop/id-token-claims-reference
214214
id_token_claims = graph_resource_access_token["id_token_claims"]
215-
auth_claims = {"oid": id_token_claims["oid"], "groups": id_token_claims.get("groups") or []}
215+
auth_claims = {"oid": id_token_claims["oid"], "groups": id_token_claims.get("groups", [])}
216216

217217
# A groups claim may have been omitted either because it was not added in the application manifest for the API application,
218218
# or a groups overage claim may have been emitted.

app/backend/core/modelhelper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,4 @@ def get_oai_chatmodel_tiktok(aoaimodel: str) -> str:
5656
raise ValueError(message)
5757
if aoaimodel not in AOAI_2_OAI and aoaimodel not in MODELS_2_TOKEN_LIMITS:
5858
raise ValueError(message)
59-
return AOAI_2_OAI.get(aoaimodel) or aoaimodel
59+
return AOAI_2_OAI.get(aoaimodel, aoaimodel)

docs/customization.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ Typically, the primary backend code you'll want to customize is the `app/backend
3737

3838
The chat tab uses the approach programmed in [chatreadretrieveread.py](https://github.com/Azure-Samples/azure-search-openai-demo/blob/main/app/backend/approaches/chatreadretrieveread.py).
3939

40-
1. It uses the OpenAI ChatCompletion API to turn the user question into a good search query.
40+
1. It calls the OpenAI ChatCompletion API (with a temperature of 0) to turn the user question into a good search query.
4141
2. It queries Azure AI Search for search results for that query (optionally using the vector embeddings for that query).
42-
3. It then combines the search results and original user question, and asks OpenAI ChatCompletion API to answer the question based on the sources. It includes the last 4K of message history as well (or however many tokens are allowed by the deployed model).
42+
3. It then combines the search results and original user question, and calls the OpenAI ChatCompletion API (with a temperature of 0.7) to answer the question based on the sources. It includes the last 4K of message history as well (or however many tokens are allowed by the deployed model).
4343

4444
The `system_message_chat_conversation` variable is currently tailored to the sample data since it starts with "Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook." Change that to match your data.
4545

@@ -56,7 +56,7 @@ If you followed the instructions in [docs/gpt4v.md](docs/gpt4v.md) to enable the
5656
The ask tab uses the approach programmed in [retrievethenread.py](https://github.com/Azure-Samples/azure-search-openai-demo/blob/main/app/backend/approaches/retrievethenread.py).
5757

5858
1. It queries Azure AI Search for search results for the user question (optionally using the vector embeddings for that question).
59-
2. It then combines the search results and user question, and asks OpenAI ChatCompletion API to answer the question based on the sources.
59+
2. It then combines the search results and user question, and calls the OpenAI ChatCompletion API (with a temperature of 0.3) to answer the question based on the sources.
6060

6161
The `system_chat_template` variable is currently tailored to the sample data since it starts with "You are an intelligent assistant helping Contoso Inc employees with their healthcare plan questions and employee handbook questions." Change that to match your data.
6262

scripts/prepdocslib/filestrategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ async def run(self, search_info: SearchInfo):
5959
async for file in files:
6060
try:
6161
key = file.file_extension()
62-
processor = self.file_processors[key]
63-
if not processor:
62+
processor = self.file_processors.get(key)
63+
if processor is None:
6464
# skip file if no parser is found
6565
if search_info.verbose:
6666
print(f"Skipping '{file.filename()}'.")

0 commit comments

Comments (0)