Skip to content

Commit adc9f9f

Browse files
committed
add o3-deep-research and o4-mini-deep-research
1 parent 4c9f6e2 commit adc9f9f

File tree

11 files changed

+370
-54
lines changed

11 files changed

+370
-54
lines changed

configs/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,13 @@
6161
analyzer_model_id = "o3",
6262
predict_model_id = "veo3-predict",
6363
fetch_model_id = "veo3-fetch",
64+
)
65+
66+
# Tool config: reads attached task files so their text can be fed to the agent.
file_reader_tool_config = dict(type="file_reader_tool")
69+
70+
# Tool config: OpenAI Deep Research, backed by the o3-deep-research model.
oai_deep_research_tool_config = dict(
    type="oai_deep_research_tool",
    model_id="o3-deep-research",
)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
_base_ = './base.py'

# ---- General run settings ----
tag = "oai_deep_research-o3"   # experiment tag used to label this run
concurrency = 4                # how many tasks run concurrently per batch
workdir = "workdir"
log_path = "log.txt"
save_path = "dra.jsonl"        # answers are appended here as JSON lines
use_local_proxy = True  # True for local proxy, False for public proxy

use_hierarchical_agent = False

# ---- Dataset ----
dataset = dict(
    type="gaia_dataset",
    name="2023_all",
    path="data/GAIA",
    split="test",
)

# ---- Tool ----
oai_deep_research_tool_config = dict(
    type="oai_deep_research_tool",
    model_id="o3-deep-research",
)

# ---- Agent ----
oai_deep_research_agent_config = dict(
    type="general_agent",
    name="oai_deep_research_agent",
    model_id="gpt-4.1",
    description="A general agent that can perform deep research using openai's deep research capabilities.",
    max_steps=20,
    template_path="src/agent/general_agent/prompts/general_agent.yaml",
    provide_run_summary=True,
    tools=["oai_deep_research_tool", "deep_analyzer_tool"],
    mcp_tools=[],
)

# The runner reads this top-level name to pick the agent to build.
agent_config = oai_deep_research_agent_config

examples/run_oai_deep_research.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import warnings
2+
warnings.simplefilter("ignore", DeprecationWarning)
3+
4+
import os
5+
import sys
6+
from pathlib import Path
7+
import pandas as pd
8+
from typing import List
9+
import json
10+
from datetime import datetime
11+
import asyncio
12+
import threading
13+
import argparse
14+
from mmengine import DictAction
15+
16+
root = str(Path(__file__).resolve().parents[1])
17+
sys.path.append(root)
18+
19+
from src.logger import logger
20+
from src.config import config
21+
from src.models import model_manager
22+
from src.metric import question_scorer
23+
from src.agent import create_agent, prepare_response
24+
from src.registry import DATASET
25+
from src.tools import FileReaderTool
26+
27+
# Serializes concurrent appends so JSONL lines from different tasks never interleave.
append_answer_lock = threading.Lock()

def append_answer(entry: dict, jsonl_file: str) -> None:
    """Append one result record as a JSON line to *jsonl_file*.

    Creates the parent directory if needed. The whole write is guarded by a
    module-level lock because multiple tasks may report concurrently.

    Args:
        entry: The result record to serialize (must be JSON-serializable).
        jsonl_file: Path to the JSONL answers file.
    """
    path = Path(jsonl_file)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Hold the lock for the open+write so each record lands as one atomic line.
    with append_answer_lock, open(path, "a", encoding="utf-8") as fp:
        fp.write(json.dumps(entry) + "\n")
    # NOTE: the previous `assert os.path.exists(...)` was removed — it was
    # redundant (append-mode open would have raised) and asserts are stripped
    # under `python -O`.
    print("Answer exported to file:", path.resolve())
36+
37+
def filter_answers(answers_file):
    """Rewrite *answers_file* in place, keeping only usable previous answers.

    A row survives when its prediction is present (not the agent's explicit
    "Unable to determine" give-up marker) and — for validation rows carrying a
    true answer — when the prediction actually scores as correct. Test rows
    (true_answer == "?") are kept on any non-empty prediction, since they
    cannot be scored.

    Args:
        answers_file: Path to a JSONL file with 'prediction' and 'true_answer'
            columns.
    """
    answer_df = pd.read_json(answers_file, lines=True)

    kept_rows = []
    for _, row in answer_df.iterrows():
        prediction = row['prediction']
        truth = row['true_answer']

        # "Unable to determine" is the agent's explicit give-up marker.
        if str(prediction) == "Unable to determine":
            prediction = None
        if prediction is None:
            continue

        if truth == "?":
            # Test split: no ground truth available, keep any real prediction.
            kept_rows.append(row)
        elif question_scorer(str(prediction), truth):
            # Validation split: only keep answers that score as correct.
            kept_rows.append(row)

    filtered_df = pd.DataFrame(kept_rows)
    filtered_df.to_json(answers_file, lines=True, orient='records')

    logger.info(f"Previous answers filtered! {len(answer_df)} -> {len(filtered_df)}")
67+
68+
def get_tasks_to_run(answers_file, dataset) -> List[dict]:
    """Return dataset records that do not yet have a usable answer on disk.

    Previously saved answers in *answers_file* are first filtered (bad or
    incorrect answers are dropped), then any task whose ``task_id`` already
    appears in the file is skipped. Any error while loading the records falls
    back to running everything from scratch.

    Args:
        answers_file: Path to the JSONL answers file (may not exist yet).
        dataset: Dataset object exposing a pandas DataFrame via ``.data``.

    Returns:
        List of record dicts still to be answered.
    """
    data = dataset.data

    logger.info(f"Loading answers from {answers_file}...")
    done_questions = []
    try:
        if os.path.exists(answers_file):
            logger.info("Filtering answers starting.")
            filter_answers(answers_file)
            logger.info("Filtering answers ending.")

            df = pd.read_json(answers_file, lines=True)
            if "task_id" not in df.columns:
                logger.warning(f"Answers file {answers_file} does not contain 'task_id' column. Please check the file format.")
                return []
            done_questions = df["task_id"].tolist()
            logger.info(f"Found {len(done_questions)} previous results!")
    except Exception as e:
        # BUG FIX: the original passed `e` as a stray positional argument with
        # no format placeholder, so the error was never rendered in the log.
        logger.warning(f"Error when loading records: {e}")
        logger.warning("No usable records! ▶️ Starting new.")
        done_questions = []
    # Set gives O(1) membership checks while scanning the whole dataset.
    done = set(done_questions)
    return [line for line in data.to_dict(orient="records") if line["task_id"] not in done]
92+
93+
async def answer_single_question(config, example):
    """Run the configured agent on one GAIA example and append the result.

    Any exception is caught and recorded in the output row rather than
    propagated, so one failing task does not abort the whole batch.

    Args:
        config: Global run config (provides ``agent_config`` and ``save_path``).
        example: One dataset record with 'question', 'task_id', 'task',
            'true_answer' and optional 'file_name' keys.
    """
    # BUG FIX: initialize these before the try-block so the except/record
    # paths below can reference them even when the failure happens very early
    # (e.g. in create_agent) — the original code raised NameError there.
    augmented_question = example["question"]
    start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    try:
        agent = await create_agent(config)
        logger.visualize_agent_tree(agent)

        logger.info(f"Task Id: {example['task_id']}, Final Answer: {example['true_answer']}")

        file_reader_tool = FileReaderTool(text_limit=50000)

        # If the example ships an attachment, inline its text into the prompt.
        if example["file_name"]:
            prompt_use_files = "\n\nTo solve the task above, you will have to use these attached files:\n"
            file_description = f" - Attached file: {example['file_name']}"
            file_text = await file_reader_tool.forward(file_path=example["file_name"])
            if file_text.error:
                logger.warning(f"Error reading file {example['file_name']}: {file_text.error}")
                file_text = "Unable to read the file."
            else:
                file_text = file_text.output
            file_description += f"\n{file_text}"
            prompt_use_files += file_description
            augmented_question += prompt_use_files

        start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Run agent 🚀
        final_result = await agent.run(task=augmented_question)

        agent_memory = await agent.write_memory_to_messages(summary_mode=True)

        # Reformulate the raw agent output into a final answer.
        final_result = await prepare_response(augmented_question,
                                              agent_memory,
                                              reformulation_model=model_manager.registed_models["gpt-4.1"])

        output = str(final_result)
        # Drop the (large) model input messages before serializing the steps.
        for memory_step in agent.memory.steps:
            memory_step.model_input_messages = None
        intermediate_steps = [str(step) for step in agent.memory.steps]

        # Check for parsing errors which indicate the LLM failed to follow the required format
        parsing_error = any("AgentParsingError" in step for step in intermediate_steps)

        # check if iteration limit exceeded
        iteration_limit_exceeded = "Agent stopped due to iteration limit or time limit." in output
        raised_exception = False
        exception = None
    except Exception as e:
        # BUG FIX: the original logged with stray positional args (no format
        # placeholders), so the question/error were never rendered.
        logger.info(f"Error on {augmented_question}: {e}")
        output = None
        intermediate_steps = []
        parsing_error = False
        iteration_limit_exceeded = False
        exception = e
        raised_exception = True
    end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    annotated_example = {
        "agent_name": config.agent_config.name,
        "question": example["question"],
        "augmented_question": augmented_question,
        "prediction": output,
        "intermediate_steps": intermediate_steps,
        "parsing_error": parsing_error,
        "iteration_limit_exceeded": iteration_limit_exceeded,
        "agent_error": str(exception) if raised_exception else None,
        "start_time": start_time,
        "end_time": end_time,
        "task": example["task"],
        "task_id": example["task_id"],
        "true_answer": example["true_answer"],
    }
    append_answer(annotated_example, config.save_path)
165+
166+
def parse_args():
    """Build and parse the command-line arguments for this runner."""
    parser = argparse.ArgumentParser(description='main')
    default_config = os.path.join(root, "configs", "config_oai_deep_research.py")
    parser.add_argument("--config", default=default_config, help="config file path")
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
             'in xxx=yyy format will be merged into config file. If the value to '
             'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
             'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
             'Note that the quotation marks are necessary and that no white space '
             'is allowed.')
    return parser.parse_args()
182+
183+
async def main():
    """Entry point: load config, models and dataset, then answer tasks in batches."""
    # Parse command line arguments
    args = parse_args()

    # Initialize the configuration
    config.init_config(args.config, args)

    # Initialize the logger
    logger.init_logger(log_path=config.log_path)
    logger.info(f"| Logger initialized at: {config.log_path}")
    logger.info(f"| Config:\n{config.pretty_text}")

    # Register models. BUG FIX: honor the config's `use_local_proxy` flag
    # instead of the hard-coded True the original passed.
    model_manager.init_models(use_local_proxy=getattr(config, "use_local_proxy", True))
    logger.info("| Registed models: %s", ", ".join(model_manager.registed_models.keys()))

    # Load dataset
    dataset = DATASET.build(config.dataset)
    logger.info(f"| Loaded dataset: {len(dataset)} examples.")

    # Load previously answered tasks and keep only the remaining ones.
    # BUG FIX: removed committed debug scaffolding — the original dropped the
    # last task, ran one hard-coded task_id and then called exit(), which made
    # the batch loop below unreachable.
    tasks_to_run = get_tasks_to_run(config.save_path, dataset)
    logger.info(f"| Loaded {len(tasks_to_run)} tasks to run.")

    # Run tasks in fixed-size concurrent batches.
    batch_size = getattr(config, "concurrency", 4)
    for i in range(0, len(tasks_to_run), batch_size):
        # Slicing clamps automatically; no min() needed.
        batch = tasks_to_run[i:i + batch_size]
        await asyncio.gather(*[answer_single_question(config, task) for task in batch])
        logger.info(f"| Batch {i // batch_size + 1} done.")

if __name__ == '__main__':
    asyncio.run(main())

src/mcp/local/mcp_tools_registry.json

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -456,32 +456,6 @@
456456
"usage_count": 2,
457457
"last_used": "2025-08-01T17:01:50.743059"
458458
},
459-
{
460-
"name": "extract_colored_numbers_from_image",
461-
"description": "Extracts numbers from an image and categorizes them into two lists based on their color: red or green.",
462-
"function": null,
463-
"metadata": {
464-
"name": "extract_colored_numbers_from_image",
465-
"description": "Extracts numbers from an image and categorizes them into two lists based on their color: red or green.",
466-
"requires": "cv2, pytesseract, numpy, typing",
467-
"args": [
468-
"image_path (str): Path to the input image file.",
469-
"red_color_lower_bound (tuple): Lower bound for red color in HSV format (H, S, V).",
470-
"red_color_upper_bound (tuple): Upper bound for red color in HSV format (H, S, V).",
471-
"green_color_lower_bound (tuple): Lower bound for green color in HSV format (H, S, V).",
472-
"green_color_upper_bound (tuple): Upper bound for green color in HSV format (H, S, V).",
473-
"ocr_config (str): Configuration string for Tesseract OCR.",
474-
"min_contour_area (int): The minimum area (in pixels) for a contour to be considered a number."
475-
],
476-
"returns": [
477-
"result (dict): A dictionary with 'red_numbers' and 'green_numbers' as keys, containing their respective lists of extracted integers."
478-
]
479-
},
480-
"script_content": "```python\n# MCP Name: extract_colored_numbers_from_image\n# Description: Extracts numbers from an image and categorizes them into two lists based on their color: red or green.\n# Arguments:\n# image_path (str): Path to the input image file.\n# red_color_lower_bound (tuple): Lower bound for red color in HSV format (H, S, V).\n# red_color_upper_bound (tuple): Upper bound for red color in HSV format (H, S, V).\n# green_color_lower_bound (tuple): Lower bound for green color in HSV format (H, S, V).\n# green_color_upper_bound (tuple): Upper bound for green color in HSV format (H, S, V).\n# ocr_config (str): Configuration string for Tesseract OCR.\n# min_contour_area (int): The minimum area (in pixels) for a contour to be considered a number.\n# Returns:\n# result (dict): A dictionary with 'red_numbers' and 'green_numbers' as keys, containing their respective lists of extracted integers.\n# Requires: cv2, pytesseract, numpy, typing\n\nimport cv2\nimport pytesseract\nimport numpy as np\nfrom typing import Tuple, List, Dict\n\ndef extract_colored_numbers_from_image(\n image_path: str,\n red_color_lower_bound: Tuple[int, int, int],\n red_color_upper_bound: Tuple[int, int, int],\n green_color_lower_bound: Tuple[int, int, int],\n green_color_upper_bound: Tuple[int, int, int],\n ocr_config: str,\n min_contour_area: int\n) -> Dict[str, List[int]]:\n \"\"\"\n Extracts numbers from an image and categorizes them into two lists based on their color: red or green.\n\n This function reads an image, converts it to the HSV color space, and then creates binary masks\n for the specified red and green color ranges. 
It finds contours in these masks, filters them by area,\n and then performs Optical Character Recognition (OCR) on each valid contour to extract the numbers.\n\n Args:\n image_path (str): Path to the input image file.\n red_color_lower_bound (Tuple[int, int, int]): Lower bound for red color in HSV format (H, S, V).\n red_color_upper_bound (Tuple[int, int, int]): Upper bound for red color in HSV format (H, S, V).\n green_color_lower_bound (Tuple[int, int, int]): Lower bound for green color in HSV format (H, S, V).\n green_color_upper_bound (Tuple[int, int, int]): Upper bound for green color in HSV format (H, S, V).\n ocr_config (str): Configuration string for Tesseract OCR (e.g., '--psm 10 -c tessedit_char_whitelist=0123456789').\n min_contour_area (int): The minimum area (in pixels) for a contour to be considered a number, used to filter out noise.\n\n Returns:\n Dict[str, List[int]]: A dictionary with two keys, 'red_numbers' and 'green_numbers', each containing a list of the integers extracted for that color.\n \"\"\"\n try:\n image = cv2.imread(image_path)\n if image is None:\n raise FileNotFoundError(f\"Image not found at path: {image_path}\")\n\n hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)\n\n # Create masks for red and green colors\n red_mask = cv2.inRange(hsv_image, red_color_lower_bound, red_color_upper_bound)\n green_mask = cv2.inRange(hsv_image, green_color_lower_bound, green_color_upper_bound)\n\n def _extract_from_mask(mask: np.ndarray) -> List[int]:\n \"\"\"Helper function to extract numbers from a given color mask.\"\"\"\n numbers = []\n contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n\n # Sort contours top-to-bottom, then left-to-right for consistent order\n if contours:\n bounding_boxes = [cv2.boundingRect(c) for c in contours]\n (contours, _) = zip(*sorted(zip(contours, bounding_boxes),\n key=lambda b: (b[1][1], b[1][0])))\n\n for contour in contours:\n if cv2.contourArea(contour) < min_contour_area:\n 
continue\n\n x, y, w, h = cv2.boundingRect(contour)\n \n # Add padding to ROI to prevent cropping number edges\n padding = 5\n roi = image[max(0, y - padding):min(image.shape[0], y + h + padding), \n max(0, x - padding):min(image.shape[1], x + w + padding)]\n\n if roi.size == 0:\n continue\n \n # Pre-process ROI for better OCR accuracy\n gray_roi = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)\n # Apply thresholding to get a clear black and white image\n _, thresh_roi = cv2.threshold(gray_roi, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)\n\n text = pytesseract.image_to_string(thresh_roi, config=ocr_config).strip()\n \n cleaned_text = ''.join(filter(str.isdigit, text))\n if cleaned_text:\n numbers.append(int(cleaned_text))\n return numbers\n\n red_numbers = _extract_from_mask(red_mask)\n green_numbers = _extract_from_mask(green_mask)\n\n return {\n \"red_numbers\": red_numbers,\n \"green_numbers\": green_numbers\n }\n except Exception as e:\n # In a real application, you might want to log the error.\n # For this tool, returning a descriptive error string is sufficient.\n return {\"error\": f\"An error occurred: {str(e)}\"}\n```",
481-
"created_at": "2025-08-01T18:06:32.493499",
482-
"usage_count": 4,
483-
"last_used": "2025-08-02T06:49:08.127031"
484-
},
485459
{
486460
"name": "calculate_deviation_average",
487461
"description": "Takes two lists of numbers (red and green). It calculates the population standard deviation for the red numbers and the sample standard deviation for the green numbers using Python's 'statistics' module. It then returns the average of these two deviation values, rounded to three decimal points.",

src/mcp/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ async def register_tools(script_info_path):
6868
logger.info(f"Script info file not found: {script_info_path}")
6969
except json.JSONDecodeError:
7070
logger.error(f"Error decoding JSON from script info file: {script_info_path}")
71+
except Exception as e:
72+
logger.error(f"An unexpected error occurred while registering tools: {e}")
7173

7274
logger.info("All tools registered successfully.")
7375

src/models/models.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,14 @@ def _register_openai_models(self, use_local_proxy: bool = False):
154154
# deep research
155155
model_name = "o3-deep-research"
156156
model_id = "o3-deep-research"
157-
client = AsyncOpenAI(
157+
158+
model = RestfulResponseModel(
159+
api_base=self._check_local_api_base(local_api_base_name="SKYWORK_SHUBIAOBIAO_API_BASE",
160+
remote_api_base_name="OPENAI_API_BASE"),
158161
api_key=api_key,
159-
base_url=self._check_local_api_base(local_api_base_name="SKYWORK_API_BASE",
160-
remote_api_base_name="SKYWORK_API_BASE"),
161-
http_client=ASYNC_HTTP_CLIENT,
162-
)
163-
model = LiteLLMModel(
162+
api_type="responses",
164163
model_id=model_id,
165-
http_client=client,
164+
http_client=HTTP_CLIENT,
166165
custom_role_conversions=custom_role_conversions,
167166
)
168167
self.registed_models[model_name] = model

0 commit comments

Comments
 (0)