
Commit 96a8417

Add an optional seed parameter (#1814)
* Add seed parameter
* Assert seed is None in tests
* Add tests
* Add tests snapshots
1 parent c273398 commit 96a8417
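
The pattern is the same in all four backend approach classes below: read an optional seed out of the request overrides, then forward it to every chat.completions.create call. A minimal sketch of that flow, assuming an AsyncOpenAI client; the `answer` helper and the hard-coded model name are illustrative assumptions, not code from this commit:

    # Sketch of the seed pass-through pattern this commit applies.
    # `client`, `answer`, and the model name are assumptions for illustration.
    from typing import Any, Optional

    from openai import AsyncOpenAI

    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set in the environment

    async def answer(messages: list[dict[str, Any]], overrides: dict[str, Any]) -> str:
        seed: Optional[int] = overrides.get("seed", None)  # None when no seed was requested
        chat_completion = await client.chat.completions.create(
            model="gpt-35-turbo",
            messages=messages,
            temperature=overrides.get("temperature", 0.3),
            n=1,
            seed=seed,  # the API treats None as "unseeded"
        )
        return chat_completion.choices[0].message.content or ""

Seeded sampling is best-effort: repeated calls with the same seed and parameters usually, but not always, return the same output.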

12 files changed, +243 -2 lines changed

app/backend/approaches/chatreadretrieveread.py

Lines changed: 3 additions & 0 deletions
@@ -88,6 +88,7 @@ async def run_until_final_call(
         auth_claims: dict[str, Any],
         should_stream: bool = False,
     ) -> tuple[dict[str, Any], Coroutine[Any, Any, Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]]]:
+        seed = overrides.get("seed", None)
         use_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
         use_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
         use_semantic_ranker = True if overrides.get("semantic_ranker") else False
@@ -142,6 +143,7 @@ async def run_until_final_call(
             max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, setting too high may affect performance
             n=1,
             tools=tools,
+            seed=seed,
         )

         query_text = self.get_search_query(chat_completion, original_user_query)
@@ -237,5 +239,6 @@ async def run_until_final_call(
             max_tokens=response_token_limit,
             n=1,
             stream=should_stream,
+            seed=seed,
         )
         return (extra_info, chat_coroutine)

app/backend/approaches/chatreadretrievereadvision.py

Lines changed: 3 additions & 0 deletions
@@ -89,6 +89,7 @@ async def run_until_final_call(
         auth_claims: dict[str, Any],
         should_stream: bool = False,
     ) -> tuple[dict[str, Any], Coroutine[Any, Any, Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]]]:
+        seed = overrides.get("seed", None)
         use_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
         use_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
         use_semantic_ranker = True if overrides.get("semantic_ranker") else False
@@ -128,6 +129,7 @@ async def run_until_final_call(
             temperature=0.0,  # Minimize creativity for search query generation
             max_tokens=query_response_token_limit,
             n=1,
+            seed=seed,
         )

         query_text = self.get_search_query(chat_completion, original_user_query)
@@ -241,5 +243,6 @@ async def run_until_final_call(
             max_tokens=response_token_limit,
             n=1,
             stream=should_stream,
+            seed=seed,
         )
         return (extra_info, chat_coroutine)

app/backend/approaches/retrievethenread.py

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,7 @@ async def run(
         if not isinstance(q, str):
             raise ValueError("The most recent message content must be a string.")
         overrides = context.get("overrides", {})
+        seed = overrides.get("seed", None)
         auth_claims = context.get("auth_claims", {})
         use_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
         use_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
@@ -131,6 +132,7 @@ async def run(
                 temperature=overrides.get("temperature", 0.3),
                 max_tokens=response_token_limit,
                 n=1,
+                seed=seed,
             )
         ).model_dump()

app/backend/approaches/retrievethenreadvision.py

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,7 @@ async def run(
             raise ValueError("The most recent message content must be a string.")

         overrides = context.get("overrides", {})
+        seed = overrides.get("seed", None)
         auth_claims = context.get("auth_claims", {})
         use_text_search = overrides.get("retrieval_mode") in ["text", "hybrid", None]
         use_vector_search = overrides.get("retrieval_mode") in ["vectors", "hybrid", None]
@@ -148,6 +149,7 @@ async def run(
                 temperature=overrides.get("temperature", 0.3),
                 max_tokens=response_token_limit,
                 n=1,
+                seed=seed,
             )
         ).model_dump()

app/frontend/src/api/models.ts

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ export type ChatAppRequestOverrides = {
     semantic_ranker?: boolean;
     semantic_captions?: boolean;
     exclude_category?: string;
+    seed?: number;
     top?: number;
     temperature?: number;
     minimum_search_score?: number;

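With `seed?: number` added to `ChatAppRequestOverrides`, any caller of the chat API can opt in per request. A hedged sketch of such a request; the local host, port, and `/chat` route are assumptions about a dev deployment, not part of this diff, while the response shape matches the test snapshot below:

    import requests  # assumes the requests package is installed

    response = requests.post(
        "http://localhost:50505/chat",  # assumed local dev endpoint
        json={
            "messages": [{"role": "user", "content": "What is included in my plan?"}],
            "context": {
                "overrides": {
                    "retrieval_mode": "hybrid",
                    "seed": 42,  # optional; omit the key to leave sampling unseeded
                }
            },
        },
    )
    print(response.json()["message"]["content"])
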
app/frontend/src/i18n/tooltips.ts

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ export const toolTipText = {
         "Overrides the prompt used to generate the answer based on the question and search results. To append to existing prompt instead of replace whole prompt, start your prompt with '>>>'.",
     temperature:
         "Sets the temperature of the request to the LLM that generates the answer. Higher temperatures result in more creative responses, but they may be less grounded.",
+    seed: "Sets a seed to improve the reproducibility of the model's responses. The seed can be any integer.",
     searchScore:
         "Sets a minimum score for search results coming back from Azure AI search. The score range depends on whether you're using hybrid (default), vectors only, or text only.",
     rerankerScore:

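As the new tooltip notes, a seed improves reproducibility rather than guaranteeing it. A quick way to see the effect, reusing the illustrative `answer` helper sketched above:

    import asyncio

    async def check_reproducibility() -> None:
        messages = [{"role": "user", "content": "Name one employee benefit."}]
        first = await answer(messages, {"seed": 42})
        second = await answer(messages, {"seed": 42})
        print("identical:", first == second)  # usually True with the same seed, not guaranteed

    asyncio.run(check_reproducibility())
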
app/frontend/src/pages/ask/Ask.tsx

Lines changed: 22 additions & 1 deletion
@@ -27,6 +27,7 @@ export function Component(): JSX.Element {
     const [promptTemplatePrefix, setPromptTemplatePrefix] = useState<string>("");
     const [promptTemplateSuffix, setPromptTemplateSuffix] = useState<string>("");
     const [temperature, setTemperature] = useState<number>(0.3);
+    const [seed, setSeed] = useState<number | null>(null);
     const [minimumRerankerScore, setMinimumRerankerScore] = useState<number>(0);
     const [minimumSearchScore, setMinimumSearchScore] = useState<number>(0);
     const [retrievalMode, setRetrievalMode] = useState<RetrievalMode>(RetrievalMode.Hybrid);
@@ -124,7 +125,8 @@ export function Component(): JSX.Element {
                     use_groups_security_filter: useGroupsSecurityFilter,
                     vector_fields: vectorFieldList,
                     use_gpt4v: useGPT4V,
-                    gpt4v_input: gpt4vInput
+                    gpt4v_input: gpt4vInput,
+                    ...(seed !== null ? { seed: seed } : {})
                 }
             },
             // AI Chat Protocol: Client must pass on any session state received from the server
@@ -148,6 +150,10 @@ export function Component(): JSX.Element {
         setTemperature(parseFloat(newValue || "0"));
     };

+    const onSeedChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
+        setSeed(parseInt(newValue || ""));
+    };
+
     const onMinimumSearchScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
         setMinimumSearchScore(parseFloat(newValue || "0"));
     };
@@ -206,6 +212,8 @@ export function Component(): JSX.Element {
     const promptTemplateFieldId = useId("promptTemplateField");
     const temperatureId = useId("temperature");
     const temperatureFieldId = useId("temperatureField");
+    const seedId = useId("seed");
+    const seedFieldId = useId("seedField");
     const searchScoreId = useId("searchScore");
     const searchScoreFieldId = useId("searchScoreField");
     const rerankerScoreId = useId("rerankerScore");
@@ -314,6 +322,19 @@ export function Component(): JSX.Element {
                 )}
             />

+            <TextField
+                id={seedFieldId}
+                className={styles.chatSettingsSeparator}
+                label="Seed"
+                type="text"
+                defaultValue={seed?.toString() || ""}
+                onChange={onSeedChange}
+                aria-labelledby={seedId}
+                onRenderLabel={(props: ITextFieldProps | undefined) => (
+                    <HelpCallout labelId={seedId} fieldId={seedFieldId} helpText={toolTipText.seed} label={props?.label} />
+                )}
+            />
+
             <TextField
                 id={searchScoreFieldId}
                 className={styles.chatSettingsSeparator}

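In both pages, the spread `...(seed !== null ? { seed: seed } : {})` adds the `seed` key only when the user actually entered one, so an untouched field sends no key at all and the backend's `overrides.get("seed", None)` default applies. The same conditional-inclusion idea, shown in Python for comparison (illustrative only):

    # Include the "seed" key only when a value is present, mirroring the
    # TypeScript spread in Ask.tsx and Chat.tsx. Purely illustrative.
    seed = 42  # or None when the field was left blank
    overrides = {
        "retrieval_mode": "hybrid",
        **({"seed": seed} if seed is not None else {}),
    }
    print(overrides)  # {'retrieval_mode': 'hybrid', 'seed': 42}; no 'seed' key when seed is None
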
app/frontend/src/pages/chat/Chat.tsx

Lines changed: 22 additions & 1 deletion
@@ -39,6 +39,7 @@ const Chat = () => {
     const [isConfigPanelOpen, setIsConfigPanelOpen] = useState(false);
     const [promptTemplate, setPromptTemplate] = useState<string>("");
     const [temperature, setTemperature] = useState<number>(0.3);
+    const [seed, setSeed] = useState<number | null>(null);
     const [minimumRerankerScore, setMinimumRerankerScore] = useState<number>(0);
     const [minimumSearchScore, setMinimumSearchScore] = useState<number>(0);
     const [retrieveCount, setRetrieveCount] = useState<number>(3);
@@ -173,7 +174,8 @@ const Chat = () => {
                     use_groups_security_filter: useGroupsSecurityFilter,
                     vector_fields: vectorFieldList,
                     use_gpt4v: useGPT4V,
-                    gpt4v_input: gpt4vInput
+                    gpt4v_input: gpt4vInput,
+                    ...(seed !== null ? { seed: seed } : {})
                 }
             },
             // AI Chat Protocol: Client must pass on any session state received from the server
@@ -239,6 +241,10 @@ const Chat = () => {
         setTemperature(parseFloat(newValue || "0"));
     };

+    const onSeedChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
+        setSeed(parseInt(newValue || ""));
+    };
+
     const onMinimumSearchScoreChange = (_ev?: React.SyntheticEvent<HTMLElement, Event>, newValue?: string) => {
         setMinimumSearchScore(parseFloat(newValue || "0"));
     };
@@ -309,6 +315,8 @@ const Chat = () => {
     const promptTemplateFieldId = useId("promptTemplateField");
     const temperatureId = useId("temperature");
     const temperatureFieldId = useId("temperatureField");
+    const seedId = useId("seed");
+    const seedFieldId = useId("seedField");
     const searchScoreId = useId("searchScore");
     const searchScoreFieldId = useId("searchScoreField");
     const rerankerScoreId = useId("rerankerScore");
@@ -478,6 +486,19 @@ const Chat = () => {
                 )}
             />

+            <TextField
+                id={seedFieldId}
+                className={styles.chatSettingsSeparator}
+                label="Seed"
+                type="text"
+                defaultValue={seed?.toString() || ""}
+                onChange={onSeedChange}
+                aria-labelledby={seedId}
+                onRenderLabel={(props: ITextFieldProps | undefined) => (
+                    <HelpCallout labelId={seedId} fieldId={seedFieldId} helpText={toolTipText.seed} label={props?.label} />
+                )}
+            />
+
             <TextField
                 id={searchScoreFieldId}
                 className={styles.chatSettingsSeparator}

tests/conftest.py

Lines changed: 3 additions & 0 deletions
@@ -180,6 +180,9 @@ async def __anext__(self):
             raise StopAsyncIteration

     async def mock_acreate(*args, **kwargs):
+        # The only two possible values for seed:
+        assert kwargs.get("seed") is None or kwargs.get("seed") == 42
+
         messages = kwargs["messages"]
         last_question = messages[-1]["content"]
         if last_question == "Generate search query for: What is the capital of France?":
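
The mock rejects any seed other than None or 42, so every test that opts in must send exactly 42. A sketch of what such a test might look like; the `client` fixture name and route are assumptions, not code from this commit:

    import pytest

    @pytest.mark.asyncio
    async def test_chat_with_seed(client):  # assumed HTTP test-client fixture
        response = await client.post(
            "/chat",
            json={
                "messages": [{"role": "user", "content": "What is the capital of France?"}],
                "context": {"overrides": {"seed": 42}},  # any other non-None seed trips the mock
            },
        )
        assert response.status_code == 200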
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+{
+    "context": {
+        "data_points": {
+            "text": [
+                "Benefit_Options-2.pdf: There is a whistleblower policy."
+            ]
+        },
+        "thoughts": [
+            {
+                "description": [
+                    "{'role': 'system', 'content': \"Below is a history of the conversation so far, and a new question asked by the user that needs to be answered by searching in a knowledge base.\\n You have access to Azure AI Search index with 100's of documents.\\n Generate a search query based on the conversation and the new question.\\n Do not include cited source filenames and document names e.g info.txt or doc.pdf in the search query terms.\\n Do not include any text inside [] or <<>> in the search query terms.\\n Do not include any special characters like '+'.\\n If the question is not in English, translate the question to English before generating the search query.\\n If you cannot generate a search query, return just the number 0.\\n \"}",
+                    "{'role': 'user', 'content': 'How did crypto do last year?'}",
+                    "{'role': 'assistant', 'content': 'Summarize Cryptocurrency Market Dynamics from last year'}",
+                    "{'role': 'user', 'content': 'What are my health plans?'}",
+                    "{'role': 'assistant', 'content': 'Show available health plans'}",
+                    "{'role': 'user', 'content': 'Generate search query for: What is the capital of France?'}"
+                ],
+                "props": {
+                    "model": "gpt-35-turbo"
+                },
+                "title": "Prompt to generate search query"
+            },
+            {
+                "description": "capital of France",
+                "props": {
+                    "filter": null,
+                    "top": 3,
+                    "use_semantic_captions": false,
+                    "use_semantic_ranker": false,
+                    "use_text_search": true,
+                    "use_vector_search": true
+                },
+                "title": "Search using generated search query"
+            },
+            {
+                "description": [
+                    {
+                        "captions": [
+                            {
+                                "additional_properties": {},
+                                "highlights": [],
+                                "text": "Caption: A whistleblower policy."
+                            }
+                        ],
+                        "category": null,
+                        "content": "There is a whistleblower policy.",
+                        "embedding": null,
+                        "groups": null,
+                        "id": "file-Benefit_Options_pdf-42656E656669745F4F7074696F6E732E706466-page-2",
+                        "imageEmbedding": null,
+                        "oids": null,
+                        "reranker_score": 3.4577205181121826,
+                        "score": 0.03279569745063782,
+                        "sourcefile": "Benefit_Options.pdf",
+                        "sourcepage": "Benefit_Options-2.pdf"
+                    }
+                ],
+                "props": null,
+                "title": "Search results"
+            },
+            {
+                "description": [
+                    "{'role': 'system', 'content': \"Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers.\\n Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.\\n For tabular information return it as an html table. Do not return markdown format. If the question is not in English, answer in the language used in the question.\\n Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, for example [info1.txt]. Don't combine sources, list each source separately, for example [info1.txt][info2.pdf].\\n \\n \\n \"}",
+                    "{'role': 'user', 'content': 'What is the capital of France?\\n\\nSources:\\nBenefit_Options-2.pdf: There is a whistleblower policy.'}"
+                ],
+                "props": {
+                    "model": "gpt-35-turbo"
+                },
+                "title": "Prompt to generate answer"
+            }
+        ]
+    },
+    "finish_reason": "stop",
+    "index": 0,
+    "logprobs": null,
+    "message": {
+        "content": "The capital of France is Paris. [Benefit_Options-2.pdf].",
+        "function_call": null,
+        "role": "assistant",
+        "tool_calls": null
+    },
+    "session_state": null
+}
