Skip to content

Commit 7d4d136

Browse files
authored
feat: planning task and megamind agent (#679)
1 parent 45baf9a commit 7d4d136

File tree

13 files changed

+1298
-55
lines changed

13 files changed

+1298
-55
lines changed

src/rai_bench/rai_bench/examples/manipulation_o3de.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
model_name=args.model_name,
3232
vendor=args.vendor,
3333
)
34-
3534
run_benchmark(
3635
llm=llm,
3736
out_dir=experiment_dir,
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import logging
15+
import uuid
16+
from datetime import datetime
17+
from pathlib import Path
18+
19+
from rai.agents.langchain.core import (
20+
Executor,
21+
create_megamind,
22+
get_initial_megamind_state,
23+
)
24+
25+
from rai_bench import (
26+
define_benchmark_logger,
27+
)
28+
from rai_bench.tool_calling_agent.benchmark import ToolCallingAgentBenchmark
29+
from rai_bench.tool_calling_agent.interfaces import TaskArgs
30+
from rai_bench.tool_calling_agent.tasks.warehouse import SortingTask
31+
from rai_bench.utils import get_llm_for_benchmark
32+
33+
if __name__ == "__main__":
34+
now = datetime.now()
35+
out_dir = f"src/rai_bench/rai_bench/experiments/tool_calling/{now.strftime('%Y-%m-%d_%H-%M-%S')}"
36+
experiment_dir = Path(out_dir)
37+
experiment_dir.mkdir(parents=True, exist_ok=True)
38+
bench_logger = define_benchmark_logger(out_dir=experiment_dir, level=logging.DEBUG)
39+
40+
task = SortingTask(task_args=TaskArgs(extra_tool_calls=50))
41+
task.set_logger(bench_logger)
42+
43+
supervisor_name = "gpt-4o"
44+
45+
executor_name = "gpt-4o-mini"
46+
model_name = f"supervisor-{supervisor_name}_executor-{executor_name}"
47+
supervisor_llm = get_llm_for_benchmark(model_name=supervisor_name, vendor="openai")
48+
executor_llm = get_llm_for_benchmark(
49+
model_name=executor_name,
50+
vendor="openai",
51+
)
52+
53+
benchmark = ToolCallingAgentBenchmark(
54+
tasks=[task],
55+
logger=bench_logger,
56+
model_name=model_name,
57+
results_dir=experiment_dir,
58+
)
59+
manipulation_system_prompt = """You are a manipulation specialist robot agent.
60+
Your role is to handle object manipulation tasks including picking up and droping objects using provided tools.
61+
62+
Ask the VLM for objects detection and positions before perfomring any manipulation action.
63+
If VLM doesn't see objects that are objectives of the task, return this information, without proceeding"""
64+
65+
navigation_system_prompt = """You are a navigation specialist robot agent.
66+
Your role is to handle navigation tasks in space using provided tools.
67+
68+
After performing navigation action, always check your current position to ensure success"""
69+
70+
executors = [
71+
Executor(
72+
name="manipulation",
73+
llm=executor_llm,
74+
tools=task.manipulation_tools(),
75+
system_prompt=manipulation_system_prompt,
76+
),
77+
Executor(
78+
name="navigation",
79+
llm=executor_llm,
80+
tools=task.navigation_tools(),
81+
system_prompt=navigation_system_prompt,
82+
),
83+
]
84+
agent = create_megamind(
85+
megamind_llm=supervisor_llm,
86+
megamind_system_prompt=task.get_system_prompt(),
87+
executors=executors,
88+
task_planning_prompt=task.get_planning_prompt(),
89+
)
90+
91+
experiment_id = uuid.uuid4()
92+
benchmark.run_next(
93+
agent=agent,
94+
initial_state=get_initial_megamind_state(task=task.get_prompt()),
95+
experiment_id=experiment_id,
96+
)
97+
98+
bench_logger.info("===============================================================")
99+
bench_logger.info("ALL SCENARIOS DONE. BENCHMARK COMPLETED!")
100+
bench_logger.info("===============================================================")

src/rai_bench/rai_bench/tool_calling_agent/benchmark.py

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
from typing import Iterator, List, Optional, Sequence, Tuple
2020

2121
from langchain_core.language_models import BaseChatModel
22-
from langchain_core.messages import BaseMessage
22+
from langchain_core.messages import AIMessage, BaseMessage
2323
from langchain_core.runnables.config import RunnableConfig
2424
from langgraph.errors import GraphRecursionError
2525
from langgraph.graph.state import CompiledStateGraph
2626
from rai.agents.langchain.core import (
2727
create_conversational_agent,
2828
)
29-
from rai.messages import HumanMultimodalMessage
29+
from rai.agents.langchain.core.react_agent import ReActAgentState
3030

3131
from rai_bench.agents import create_multimodal_to_tool_agent
3232
from rai_bench.base_benchmark import BaseBenchmark, TimeoutException
@@ -38,9 +38,6 @@
3838
TaskResult,
3939
ToolCallingAgentRunSummary,
4040
)
41-
from rai_bench.tool_calling_agent.tasks.spatial import (
42-
SpatialReasoningAgentTask,
43-
)
4441
from rai_bench.utils import get_llm_model_name
4542

4643

@@ -67,7 +64,12 @@ def __init__(
6764
self.tasks_results: List[TaskResult] = []
6865
self.csv_initialize(self.results_filename, TaskResult)
6966

70-
def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
67+
def run_next(
68+
self,
69+
agent: CompiledStateGraph,
70+
initial_state: ReActAgentState,
71+
experiment_id: uuid.UUID,
72+
) -> None:
7173
"""Runs the next task of the benchmark.
7274
7375
Parameters
@@ -87,14 +89,16 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
8789
)
8890
callbacks = self.score_tracing_handler.get_callbacks()
8991
run_id = uuid.uuid4()
90-
# NOTE (jmatejcz) recursion limit calculated as all_nodes_num -> one pass though whole node
91-
# plus (task.max_tool_calls_number-1 because the first pass is already added in)
92-
# times number of nodes - 2 because we dont cout start and end node
93-
# this can be to much for larger graphs that dont use all nodes on extra calls
94-
# in such ase adjust this value
95-
recurssion_limit = len(agent.get_graph().nodes) + (
96-
task.max_tool_calls_number - 1
97-
) * (len(agent.get_graph().nodes) - 2)
92+
# NOTE (jmatejcz) recursion limit calculated as (all_nodes_num - 2) * required tool calls
93+
# -2 because we don't want to include START and END node
94+
# then we add numer of additional calls that can be made
95+
# and +2 as we have to pass once though START and END
96+
97+
recurssion_limit = (
98+
(len(agent.get_graph().nodes) - 2) * task.required_calls
99+
+ task.additional_calls
100+
+ 2
101+
)
98102
config: RunnableConfig = {
99103
"run_id": run_id,
100104
"callbacks": callbacks,
@@ -113,40 +117,27 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
113117
messages: List[BaseMessage] = []
114118
prev_count: int = 0
115119
try:
116-
with self.time_limit(20 * task.max_tool_calls_number):
117-
if isinstance(task, SpatialReasoningAgentTask):
118-
for state in agent.stream(
119-
{
120-
"messages": [
121-
HumanMultimodalMessage(
122-
content=task.get_prompt(), images=task.get_images()
123-
)
124-
]
125-
},
126-
config=config,
127-
):
128-
node = next(iter(state))
129-
all_messages = state[node]["messages"]
130-
for new_msg in all_messages[prev_count:]:
131-
messages.append(new_msg)
132-
prev_count = len(messages)
133-
else:
134-
for state in agent.stream(
135-
{
136-
"messages": [
137-
HumanMultimodalMessage(content=task.get_prompt())
138-
]
139-
},
140-
config=config,
141-
):
142-
node = next(iter(state))
120+
with self.time_limit(200 * task.max_tool_calls_number):
121+
for state in agent.stream(
122+
initial_state,
123+
config=config,
124+
):
125+
node = next(iter(state))
126+
if "messages" in state[node]:
143127
all_messages = state[node]["messages"]
144128
for new_msg in all_messages[prev_count:]:
145129
messages.append(new_msg)
130+
if isinstance(new_msg, AIMessage):
131+
self.logger.debug(
132+
f"Message from node '{node}': {new_msg.content}, tool_calls: {new_msg.tool_calls}"
133+
)
146134
prev_count = len(messages)
147135
except TimeoutException as e:
148136
self.logger.error(msg=f"Task timeout: {e}")
149137
except GraphRecursionError as e:
138+
tool_calls = task.get_tool_calls_from_messages(messages=messages)
139+
score = task.validate(tool_calls=tool_calls)
140+
score = 0.0
150141
self.logger.error(msg=f"Reached recursion limit {e}")
151142

152143
tool_calls = task.get_tool_calls_from_messages(messages=messages)

src/rai_bench/rai_bench/tool_calling_agent/interfaces.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,14 @@ def max_tool_calls_number(self) -> int:
585585
+ self.extra_tool_calls
586586
)
587587

588+
@property
589+
def additional_calls(self) -> int:
590+
"""number of additional calls that can be done to still pass task.
591+
Includes extra tool calls params.
592+
and optional tool calls number which depends on task.
593+
"""
594+
return self.optional_tool_calls_number + self.extra_tool_calls
595+
588596
@property
589597
def required_calls(self) -> int:
590598
"""Minimal number of calls required to complete task"""

0 commit comments

Comments
 (0)