Skip to content

Commit 5fb9e09

Browse files
committed
Tool Calling: update react agent kwargs
1 parent 6eb158a commit 5fb9e09

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

ai_data_science_team/agents/data_loader_tools_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def make_data_loader_tools_agent(
196196
197197
Returns:
198198
--------
199-
Data Loader Agent
199+
app : langchain.graphs.CompiledStateGraph
200200
An agent that can interact with data loading tools.
201201
"""
202202

ai_data_science_team/ml_agents/mlflow_tools_agent.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
from typing import Any, Optional, Annotated, Sequence
2+
from typing import Any, Optional, Annotated, Sequence, Dict
33
import operator
44

55
import pandas as pd
@@ -63,8 +63,10 @@ class MLflowToolsAgent(BaseAgent):
6363
The tracking URI for MLflow. Defaults to None.
6464
mlflow_registry_uri : str, optional
6565
The registry URI for MLflow. Defaults to None.
66-
**react_agent_kwargs : dict, optional
67-
Additional keyword arguments to pass to the agent's react agent.
66+
react_agent_kwargs : dict
67+
Additional keyword arguments to pass to the create_react_agent function.
68+
invoke_react_agent_kwargs : dict
69+
Additional keyword arguments to pass to the invoke method of the react agent.
6870
6971
Methods:
7072
--------
@@ -114,13 +116,15 @@ def __init__(
114116
model: Any,
115117
mlflow_tracking_uri: Optional[str]=None,
116118
mlflow_registry_uri: Optional[str]=None,
117-
**react_agent_kwargs,
119+
create_react_agent_kwargs: Optional[Dict]={},
120+
invoke_react_agent_kwargs: Optional[Dict]={},
118121
):
119122
self._params = {
120123
"model": model,
121124
"mlflow_tracking_uri": mlflow_tracking_uri,
122125
"mlflow_registry_uri": mlflow_registry_uri,
123-
**react_agent_kwargs,
126+
"create_react_agent_kwargs": create_react_agent_kwargs,
127+
"invoke_react_agent_kwargs": invoke_react_agent_kwargs,
124128
}
125129
self._compiled_graph = self._make_compiled_graph()
126130
self.response = None
@@ -185,8 +189,6 @@ def invoke_agent(
185189
The user instructions to pass to the agent.
186190
data_raw : pd.DataFrame, optional
187191
The raw data to pass to the agent. Used for prediction and tool calls where data is required.
188-
kwargs : dict, optional
189-
Additional keyword arguments to pass to the agents invoke method.
190192
191193
"""
192194
response = self._compiled_graph.invoke(
@@ -234,10 +236,30 @@ def make_mlflow_tools_agent(
234236
model: Any,
235237
mlflow_tracking_uri: str=None,
236238
mlflow_registry_uri: str=None,
237-
**react_agent_kwargs,
239+
create_react_agent_kwargs: Optional[Dict]={},
240+
invoke_react_agent_kwargs: Optional[Dict]={},
238241
):
239242
"""
240243
MLflow Tool Calling Agent
244+
245+
Parameters:
246+
----------
247+
model : Any
248+
The language model used to generate the agent.
249+
mlflow_tracking_uri : str, optional
250+
The tracking URI for MLflow. Defaults to None.
251+
mlflow_registry_uri : str, optional
252+
The registry URI for MLflow. Defaults to None.
253+
create_react_agent_kwargs : dict, optional
254+
Additional keyword arguments to pass to the agent's create_react_agent method.
255+
invoke_react_agent_kwargs : dict, optional
256+
Additional keyword arguments to pass to the agent's invoke method.
257+
258+
Returns
259+
-------
260+
app : langchain.graphs.CompiledStateGraph
261+
A compiled state graph for the MLflow Tool Calling Agent.
262+
241263
"""
242264

243265
try:
@@ -274,14 +296,15 @@ def mflfow_tools_agent(state):
274296
model,
275297
tools=tool_node,
276298
state_schema=GraphState,
277-
**react_agent_kwargs,
299+
**create_react_agent_kwargs,
278300
)
279301

280302
response = mlflow_agent.invoke(
281303
{
282304
"messages": [("user", state["user_instructions"])],
283305
"data_raw": state["data_raw"],
284306
},
307+
invoke_react_agent_kwargs,
285308
)
286309

287310
print(" * POST-PROCESS RESULTS")

0 commit comments

Comments
 (0)