19
19
from typing import Iterator , List , Optional , Sequence , Tuple
20
20
21
21
from langchain_core .language_models import BaseChatModel
22
- from langchain_core .messages import BaseMessage
22
+ from langchain_core .messages import AIMessage , BaseMessage
23
23
from langchain_core .runnables .config import RunnableConfig
24
24
from langgraph .errors import GraphRecursionError
25
25
from langgraph .graph .state import CompiledStateGraph
26
26
from rai .agents .langchain .core import (
27
27
create_conversational_agent ,
28
28
)
29
- from rai .messages import HumanMultimodalMessage
29
+ from rai .agents . langchain . core . react_agent import ReActAgentState
30
30
31
31
from rai_bench .agents import create_multimodal_to_tool_agent
32
32
from rai_bench .base_benchmark import BaseBenchmark , TimeoutException
38
38
TaskResult ,
39
39
ToolCallingAgentRunSummary ,
40
40
)
41
- from rai_bench .tool_calling_agent .tasks .spatial import (
42
- SpatialReasoningAgentTask ,
43
- )
44
41
from rai_bench .utils import get_llm_model_name
45
42
46
43
@@ -67,7 +64,12 @@ def __init__(
67
64
self .tasks_results : List [TaskResult ] = []
68
65
self .csv_initialize (self .results_filename , TaskResult )
69
66
70
- def run_next (self , agent : CompiledStateGraph , experiment_id : uuid .UUID ) -> None :
67
+ def run_next (
68
+ self ,
69
+ agent : CompiledStateGraph ,
70
+ initial_state : ReActAgentState ,
71
+ experiment_id : uuid .UUID ,
72
+ ) -> None :
71
73
"""Runs the next task of the benchmark.
72
74
73
75
Parameters
@@ -87,14 +89,16 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
87
89
)
88
90
callbacks = self .score_tracing_handler .get_callbacks ()
89
91
run_id = uuid .uuid4 ()
90
- # NOTE (jmatejcz) recursion limit calculated as all_nodes_num -> one pass though whole node
91
- # plus (task.max_tool_calls_number-1 because the first pass is already added in)
92
- # times number of nodes - 2 because we dont cout start and end node
93
- # this can be to much for larger graphs that dont use all nodes on extra calls
94
- # in such ase adjust this value
95
- recurssion_limit = len (agent .get_graph ().nodes ) + (
96
- task .max_tool_calls_number - 1
97
- ) * (len (agent .get_graph ().nodes ) - 2 )
92
+ # NOTE (jmatejcz) recursion limit calculated as (all_nodes_num - 2) * required tool calls
93
+ # -2 because we don't want to include the START and END nodes
94
+ # then we add the number of additional calls that can be made
95
+ # and +2 as we have to pass once through START and END
96
+
97
+ recurssion_limit = (
98
+ (len (agent .get_graph ().nodes ) - 2 ) * task .required_calls
99
+ + task .additional_calls
100
+ + 2
101
+ )
98
102
config : RunnableConfig = {
99
103
"run_id" : run_id ,
100
104
"callbacks" : callbacks ,
@@ -113,40 +117,27 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
113
117
messages : List [BaseMessage ] = []
114
118
prev_count : int = 0
115
119
try :
116
- with self .time_limit (20 * task .max_tool_calls_number ):
117
- if isinstance (task , SpatialReasoningAgentTask ):
118
- for state in agent .stream (
119
- {
120
- "messages" : [
121
- HumanMultimodalMessage (
122
- content = task .get_prompt (), images = task .get_images ()
123
- )
124
- ]
125
- },
126
- config = config ,
127
- ):
128
- node = next (iter (state ))
129
- all_messages = state [node ]["messages" ]
130
- for new_msg in all_messages [prev_count :]:
131
- messages .append (new_msg )
132
- prev_count = len (messages )
133
- else :
134
- for state in agent .stream (
135
- {
136
- "messages" : [
137
- HumanMultimodalMessage (content = task .get_prompt ())
138
- ]
139
- },
140
- config = config ,
141
- ):
142
- node = next (iter (state ))
120
+ with self .time_limit (200 * task .max_tool_calls_number ):
121
+ for state in agent .stream (
122
+ initial_state ,
123
+ config = config ,
124
+ ):
125
+ node = next (iter (state ))
126
+ if "messages" in state [node ]:
143
127
all_messages = state [node ]["messages" ]
144
128
for new_msg in all_messages [prev_count :]:
145
129
messages .append (new_msg )
130
+ if isinstance (new_msg , AIMessage ):
131
+ self .logger .debug (
132
+ f"Message from node '{ node } ': { new_msg .content } , tool_calls: { new_msg .tool_calls } "
133
+ )
146
134
prev_count = len (messages )
147
135
except TimeoutException as e :
148
136
self .logger .error (msg = f"Task timeout: { e } " )
149
137
except GraphRecursionError as e :
138
+ tool_calls = task .get_tool_calls_from_messages (messages = messages )
139
+ score = task .validate (tool_calls = tool_calls )
140
+ score = 0.0
150
141
self .logger .error (msg = f"Reached recursion limit { e } " )
151
142
152
143
tool_calls = task .get_tool_calls_from_messages (messages = messages )
0 commit comments