|
1 | 1 |
|
2 | | -from typing import Any, Optional, Annotated, Sequence |
| 2 | +from typing import Any, Optional, Annotated, Sequence, Dict |
3 | 3 | import operator |
4 | 4 |
|
5 | 5 | import pandas as pd |
@@ -63,8 +63,10 @@ class MLflowToolsAgent(BaseAgent): |
63 | 63 | The tracking URI for MLflow. Defaults to None. |
64 | 64 | mlflow_registry_uri : str, optional |
65 | 65 | The registry URI for MLflow. Defaults to None. |
66 | | - **react_agent_kwargs : dict, optional |
67 | | - Additional keyword arguments to pass to the agent's react agent. |
| 66 | + react_agent_kwargs : dict |
| 67 | + Additional keyword arguments to pass to the create_react_agent function. |
| 68 | + invoke_react_agent_kwargs : dict |
| 69 | + Additional keyword arguments to pass to the invoke method of the react agent. |
68 | 70 | |
69 | 71 | Methods: |
70 | 72 | -------- |
@@ -114,13 +116,15 @@ def __init__( |
114 | 116 | model: Any, |
115 | 117 | mlflow_tracking_uri: Optional[str]=None, |
116 | 118 | mlflow_registry_uri: Optional[str]=None, |
117 | | - **react_agent_kwargs, |
| 119 | + create_react_agent_kwargs: Optional[Dict]={}, |
| 120 | + invoke_react_agent_kwargs: Optional[Dict]={}, |
118 | 121 | ): |
119 | 122 | self._params = { |
120 | 123 | "model": model, |
121 | 124 | "mlflow_tracking_uri": mlflow_tracking_uri, |
122 | 125 | "mlflow_registry_uri": mlflow_registry_uri, |
123 | | - **react_agent_kwargs, |
| 126 | + "create_react_agent_kwargs": create_react_agent_kwargs, |
| 127 | + "invoke_react_agent_kwargs": invoke_react_agent_kwargs, |
124 | 128 | } |
125 | 129 | self._compiled_graph = self._make_compiled_graph() |
126 | 130 | self.response = None |
@@ -185,8 +189,6 @@ def invoke_agent( |
185 | 189 | The user instructions to pass to the agent. |
186 | 190 | data_raw : pd.DataFrame, optional |
187 | 191 | The raw data to pass to the agent. Used for prediction and tool calls where data is required. |
188 | | - kwargs : dict, optional |
189 | | - Additional keyword arguments to pass to the agents invoke method. |
190 | 192 | |
191 | 193 | """ |
192 | 194 | response = self._compiled_graph.invoke( |
@@ -234,10 +236,30 @@ def make_mlflow_tools_agent( |
234 | 236 | model: Any, |
235 | 237 | mlflow_tracking_uri: str=None, |
236 | 238 | mlflow_registry_uri: str=None, |
237 | | - **react_agent_kwargs, |
| 239 | + create_react_agent_kwargs: Optional[Dict]={}, |
| 240 | + invoke_react_agent_kwargs: Optional[Dict]={}, |
238 | 241 | ): |
239 | 242 | """ |
240 | 243 | MLflow Tool Calling Agent |
| 244 | + |
| 245 | + Parameters: |
| 246 | + ---------- |
| 247 | + model : Any |
| 248 | + The language model used to generate the agent. |
| 249 | + mlflow_tracking_uri : str, optional |
| 250 | + The tracking URI for MLflow. Defaults to None. |
| 251 | + mlflow_registry_uri : str, optional |
| 252 | + The registry URI for MLflow. Defaults to None. |
| 253 | + create_react_agent_kwargs : dict, optional |
| 254 | + Additional keyword arguments to pass to the agent's create_react_agent method. |
| 255 | + invoke_react_agent_kwargs : dict, optional |
| 256 | + Additional keyword arguments to pass to the agent's invoke method. |
| 257 | + |
| 258 | + Returns |
| 259 | + ------- |
| 260 | + app : langchain.graphs.CompiledStateGraph |
| 261 | + A compiled state graph for the MLflow Tool Calling Agent. |
| 262 | + |
241 | 263 | """ |
242 | 264 |
|
243 | 265 | try: |
@@ -274,14 +296,15 @@ def mflfow_tools_agent(state): |
274 | 296 | model, |
275 | 297 | tools=tool_node, |
276 | 298 | state_schema=GraphState, |
277 | | - **react_agent_kwargs, |
| 299 | + **create_react_agent_kwargs, |
278 | 300 | ) |
279 | 301 |
|
280 | 302 | response = mlflow_agent.invoke( |
281 | 303 | { |
282 | 304 | "messages": [("user", state["user_instructions"])], |
283 | 305 | "data_raw": state["data_raw"], |
284 | 306 | }, |
| 307 | + invoke_react_agent_kwargs, |
285 | 308 | ) |
286 | 309 |
|
287 | 310 | print(" * POST-PROCESS RESULTS") |
|
0 commit comments