Skip to content

Commit 7e8057e

Browse files
authored
fix: reimplemented search-o1 (#158)
1 parent 72af506 commit 7e8057e

File tree

4 files changed

+187
-36
lines changed

4 files changed

+187
-36
lines changed

examples/search_o1.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ pipeline:
1515
- benchmark.get_data
1616
- retriever.retriever_init
1717
- generation.generation_init
18+
- custom.search_o1_init_list
1819
- prompt.search_o1_init
1920
- generation.generate
2021
- loop:
@@ -29,8 +30,12 @@ pipeline:
2930
- retriever.retriever_search:
3031
input:
3132
query_list: extract_query_list
32-
- prompt.searcho1_reasoning_indocument
33+
- custom.search_o1_reasoning_extract
34+
- custom.search_o1_combine_list
35+
- prompt.search_o1_reasoning_indocument
3336
- generation.generate
37+
- custom.search_o1_extract_final_information
38+
- custom.search_o1_combine_final_information
3439
- prompt.search_o1_insert
3540
- generation.generate
3641
stop: []

servers/custom/src/custom.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,46 @@ def ircot_extract_ans(ans_ls: List[str]) -> Dict[str, List[str]]:
119119
return {"pred_ls": ret}
120120

121121

122+
@app.tool(output="q_ls->total_subq_list,total_reason_list,total_final_info_list")
def search_o1_init_list(q_ls: List[str]) -> Dict[str, List[Any]]:
    """Create the per-question accumulators used by the Search-o1 pipeline.

    Args:
        q_ls: Input questions; only the number of entries is used.

    Returns:
        A dict with the keys ``total_subq_list``, ``total_reason_list`` and
        ``total_final_info_list``. Each value contains one single-element
        ``["<PAD>"]`` list per question.
    """
    accumulator_keys = (
        "total_subq_list",
        "total_reason_list",
        "total_final_info_list",
    )
    # Every bucket is a distinct list object so later appends stay isolated.
    return {key: [["<PAD>"] for _ in q_ls] for key in accumulator_keys}
131+
132+
@app.tool(
    output="total_subq_list, extract_query_list, total_reason_list, extract_reason_list"
    "->total_subq_list, total_reason_list"
)
def search_o1_combine_list(
    total_subq_list: List[List[Any]],
    extract_query_list: List[str],
    total_reason_list: List[List[Any]],
    extract_reason_list: List[str],
) -> Dict[str, List[Any]]:
    """Add this round's sub-query and reasoning text to each question's history.

    A bucket that still holds only the ``"<PAD>"`` placeholder has the
    placeholder replaced by the new item; any other bucket gets the item
    appended. Buckets without a matching new item (the new-item list is
    shorter) are returned unchanged.

    Args:
        total_subq_list: Per-question lists of sub-queries collected so far.
        extract_query_list: One newly extracted sub-query per question.
        total_reason_list: Per-question lists of reasoning text collected so far.
        extract_reason_list: One newly extracted reasoning text per question.

    Returns:
        A dict with the updated ``total_subq_list`` and ``total_reason_list``.
        The input lists are not modified; updated copies are returned.
    """
    pad = "<PAD>"

    def merge_into_buckets(buckets: List[List[Any]], new_items: List[Any]) -> List[List[Any]]:
        # Work on copies so the caller's nested lists are left untouched.
        merged = [list(bucket) for bucket in buckets]
        for item, bucket in zip(new_items, merged):
            if bucket == [pad]:
                bucket[0] = item
            else:
                bucket.append(item)
        return merged

    return {
        "total_subq_list": merge_into_buckets(total_subq_list, extract_query_list),
        "total_reason_list": merge_into_buckets(total_reason_list, extract_reason_list),
    }
161+
122162
@app.tool(output="ans_ls->extract_query_list")
123163
def search_o1_query_extract(ans_ls: List[str]) -> Dict[str, List[str]]:
124164
import re
@@ -139,6 +179,58 @@ def get_query(text):
139179

140180
return {"extract_query_list": query}
141181

182+
@app.tool(output="ans_ls->extract_reason_list")
def search_o1_reasoning_extract(ans_ls: List[str]) -> Dict[str, List[str]]:
    """Keep the text that precedes the first search-query marker in each answer.

    Args:
        ans_ls: Generated answers.

    Returns:
        A dict whose ``extract_reason_list`` holds, for every answer, the
        stripped text before the first ``<|begin_search_query|>``. An answer
        without the marker is returned stripped in full.
    """
    marker = "<|begin_search_query|>"
    # partition() yields the whole text as its head when the marker is absent,
    # so both cases collapse into a single expression.
    reasons = [answer.partition(marker)[0].strip() for answer in ans_ls]
    return {"extract_reason_list": reasons}
197+
198+
@app.tool(output="ans_ls->extract_final_infor_list")
def search_o1_extract_final_information(ans_ls: List[str]) -> Dict[str, List[str]]:
    """Pull the ``**Final Information**`` section out of each answer.

    Args:
        ans_ls: Generated answers.

    Returns:
        A dict whose ``extract_final_infor_list`` holds, for every answer, the
        marker followed by a newline and the stripped text after the first
        occurrence of the marker. An answer without the marker yields ``""``.
    """
    marker = "**Final Information**"

    extracted = []
    for answer in ans_ls:
        _, found, tail = answer.partition(marker)
        # ``found`` is empty when the marker does not occur in the answer.
        extracted.append(marker + "\n" + tail.strip() if found else "")

    return {"extract_final_infor_list": extracted}
212+
213+
@app.tool(output="total_final_info_list, extract_final_infor_list->total_final_info_list")
def search_o1_combine_final_information(
    total_final_info_list: List[List[str]],
    extract_final_infor_list: List[str],
) -> Dict[str, List[Any]]:
    """Add this round's extracted final information to each question's history.

    A bucket that still holds only the ``"<PAD>"`` placeholder has the
    placeholder replaced by the new item; any other bucket gets the item
    appended. Buckets without a matching new item are returned unchanged.

    Args:
        total_final_info_list: Per-question lists of final-information strings
            collected so far.
        extract_final_infor_list: One newly extracted string per question.

    Returns:
        A dict with the updated ``total_final_info_list``. The input list is
        not modified; an updated copy is returned.
    """
    pad = "<PAD>"

    # Work on copies so the caller's nested lists are left untouched.
    merged = [list(bucket) for bucket in total_final_info_list]
    for item, bucket in zip(extract_final_infor_list, merged):
        if bucket == [pad]:
            bucket[0] = item
        else:
            bucket.append(item)

    app.logger.warning(f"len total_final_info_list: {len(merged)}")
    app.logger.warning(f"total_final_info_list: {merged}")

    return {
        "total_final_info_list": merged,
    }
233+
142234
@app.tool(output="temp_psg,ret_psg->ret_psg")
143235
def merge_passages(
144236
temp_psg: List[str | Any],

servers/prompt/src/prompt.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -288,41 +288,71 @@ def search_o1_init(
288288
ret.append(p)
289289
return ret
290290

291-
292291
@app.prompt(
    output="extract_query_list, ret_psg, total_reason_list, searcho1_refine_template -> prompt_ls"
)
def search_o1_reasoning_indocument(
    extract_query_list: List[str],
    ret_psg: List[List[str]],
    total_reason_list: List[List[str]],
    template: str | Path,
) -> List[PromptMessage]:
    """Render one refine prompt per question from query, passages and history.

    Args:
        extract_query_list: Current sub-query for each question.
        ret_psg: Passages for each question; joined with newlines.
        total_reason_list: Reasoning history for each question.
        template: Passed to ``load_prompt_template`` to obtain the template.

    Returns:
        The rendered prompts, one per zipped
        (query, passages, history) triple.
    """
    renderer = load_prompt_template(template)

    prompts = []
    for sub_query, passages, steps in zip(extract_query_list, ret_psg, total_reason_list):
        document = "\n".join(passages)

        # Short histories are used whole; longer ones keep the first step
        # plus the three most recent ones.
        kept_steps = steps[:] if len(steps) <= 3 else [steps[0]] + steps[-3:]

        # Steps are numbered from 1 after the selection above.
        history = "\n\n".join(
            f"Step {number}: {step}" for number, step in enumerate(kept_steps, 1)
        )

        prompts.append(
            renderer.render(
                prev_reasoning=history,
                search_query=sub_query,
                document=document,
            )
        )

    return prompts
return ret
312326

313-
@app.prompt(output="q_ls,total_subq_list,total_final_info_list,searcho1_reasoning_template->prompt_ls")
def search_o1_insert(
    q_ls: List[str],
    total_subq_list: List[List[str]],
    total_final_info_list: List[List[str]],
    template: str | Path,
) -> List[PromptMessage]:
    """Render each question's prompt and append its query/result history.

    Args:
        q_ls: Questions; each is rendered with the template as ``question``.
        total_subq_list: Per-question lists of sub-queries.
        total_final_info_list: Per-question lists of result texts, paired
            position-by-position with the sub-queries.
        template: Passed to ``load_prompt_template`` to obtain the template.

    Returns:
        One prompt per question: the rendered template followed by a
        ``<|begin_search_query|>…<|end_search_query|>`` and
        ``<|begin_search_result|>…<|end_search_result|>`` pair for every
        (sub-query, result) pair.
    """
    renderer = load_prompt_template(template)
    base_prompts = [renderer.render(question=question) for question in q_ls]

    prompts = []
    for base, queries, results in zip(base_prompts, total_subq_list, total_final_info_list):
        # Pairs are concatenated directly after the base prompt, no separator.
        history = "".join(
            "<|begin_search_query|>" + str(query) + "<|end_search_query|>"
            + '\n'
            + "<|begin_search_result|>" + str(result) + "<|end_search_result|>"
            for query, result in zip(queries, results)
        )
        prompts.append(base + history)

    return prompts
324355

325-
326356
# prompt for loop and branch demo
327357
@app.prompt(output="q_ls,ret_psg,gen_subq_template->prompt_ls")
328358
def gen_subq(

servers/router/src/router.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from typing import List, Dict
2-
1+
from typing import List, Dict, Any
32
from ultrarag.server import UltraRAG_MCP_Server
43

54

@@ -105,24 +104,49 @@ def get_eos(text):
105104
return {"ans_ls": ans_ls}
106105

107106

108-
@app.tool(
    output=("ans_ls,q_ls,total_subq_list,total_reason_list,total_final_info_list->ans_ls,q_ls,total_subq_list,total_reason_list,total_final_info_list")
)
def search_o1_check(
    ans_ls: List[str],
    q_ls: List[str],
    total_subq_list: List[List[Any]],
    total_reason_list: List[List[Any]],
    total_final_info_list: List[List[Any]],
) -> Dict[str, List[Dict[str, Any]]]:
    """Tag every per-question value with a ``stop`` / ``retrieve`` state.

    The state is derived from the answer text only: it is ``"retrieve"`` when
    the answer contains ``<|end_search_query|>`` but not ``<|im_end|>``, and
    ``"stop"`` otherwise. The same state is attached to the question and to
    each history list at the same position.

    Args:
        ans_ls: Generated answers.
        q_ls: Questions.
        total_subq_list: Per-question sub-query histories.
        total_reason_list: Per-question reasoning histories.
        total_final_info_list: Per-question final-information histories.

    Returns:
        A dict with the same five keys; each value is a list of
        ``{"data": <value>, "state": <state>}`` dicts.
    """

    def is_finished(text: str) -> bool:
        # An explicit end-of-turn token wins over a pending search query.
        return "<|im_end|>" in text or "<|end_search_query|>" not in text

    field_names = (
        "ans_ls",
        "q_ls",
        "total_subq_list",
        "total_reason_list",
        "total_final_info_list",
    )
    tagged: Dict[str, List[Dict[str, Any]]] = {name: [] for name in field_names}

    rows = zip(ans_ls, q_ls, total_subq_list, total_reason_list, total_final_info_list)
    for row in rows:
        # row[0] is the answer; it alone decides the state for the whole row.
        state = "stop" if is_finished(row[0]) else "retrieve"
        for name, value in zip(field_names, row):
            tagged[name].append({"data": value, "state": state})

    return tagged
126150

127151

128152
@app.tool(output="ans_ls->ans_ls")

0 commit comments

Comments
 (0)