Command use with modified community tools #3714
-
Hi, I am trying to return a command from a modified tool in the community package (QuerySQLDatabaseTool) but struggle to understand how to integrate it. This is my agent: from typing import Annotated, Any
from typing_extensions import TypedDict
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.language_models import BaseChatModel
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from src.tools.sql_tools import get_custom_sql_database_tools
class State(TypedDict):
messages: Annotated[list, add_messages]
sql_query: str | None
sql_result: str | None
python_result: str | None
class ReactPythonGraph(BaseGraph):
def __init__(
self,
config: ReactPythonGraphConfig,
chat_model: BaseChatModel,
tools: list[BaseTool] | None = None,
system_message: SystemMessage | None = None,
initial_state: dict | None = None,
*args,
**kwargs,
):
if not tools:
tools = get_custom_sql_database_tools(
config.db_config, chat_model=chat_model
)
tools.append(get_python_tools())
if not system_message:
system_message = SystemMessage(content=config.system_message)
self.chat_model = chat_model
self.chat_node = ChatNode(
chat_model=chat_model.bind_tools(tools), system_message=system_message
)
# Extract binded tools from the ChatNode that is instantiated in the graph
self.binded_tools = self.chat_node.binded_tools
self.system_message = system_message
initial_state = initial_state or {"messages": [], "sql_query": None}
def init_node(_: Any) -> State:
return State({"messages": [], "sql_query": initial_state.get("sql_query")})
graph_builder = StateGraph(State)
graph_builder.add_node("init", init_node)
graph_builder.add_node("chat", self.chat_node)
graph_builder.add_node("tools", ToolNode(tools=tools))
graph_builder.add_edge("init", "chat")
graph_builder.add_edge("tools", "chat")
graph_builder.add_conditional_edges("chat", tools_condition)
graph_builder.set_entry_point("init")
state_graph = graph_builder.compile()
super().__init__(config=config, state_graph=state_graph, *args, **kwargs) And here is my sql tool which I would like to return a command to update the state variable sql_query: from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import (
InfoSQLDatabaseTool,
BaseSQLDatabaseTool,
QuerySQLDatabaseTool,
QuerySQLCheckerTool,
)
class QuerySQLDatabaseToolWithHeaders(QuerySQLDatabaseTool):
@tool
def execute_sql(
self, query: str, tool_call_id: Annotated[str, InjectedToolCallId]
) -> Union[str, Command]:
"""
Executes a SQL query with a TOP clause to limit results and updates the state.
Args:
query (str): The SQL query to execute.
tool_call_id (str): The unique identifier for the tool call.
Returns:
Union[str, Command]: A Command object to update the state or an error message.
"""
return self._run(query, tool_call_id)
def _run(
self,
query: str,
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Union[str, Command]:
"""
Executes a SQL query with a TOP clause to limit results and updates the state.
Args:
query (str): The SQL query to execute.
tool_call_id (str): The unique identifier for the tool call.
config (RunnableConfig): The runtime configuration.
Returns:
Union[str, Command]: A Command object to update the state or an error message.
"""
print("Called the tool")
try:
# Check for existing TOP clause
top_match = re.search(r"\bTOP\s+(\d+)", query, re.IGNORECASE)
if top_match:
current_top = int(top_match.group(1))
if current_top <= 100000:
# Keep the existing TOP if it's 100k or lower
query_with_top = query
else:
# Replace the existing TOP with 100k if it's higher
query_with_top = re.sub(
r"\bTOP\s+\d+", "TOP 100000", query, flags=re.IGNORECASE
)
else:
# Add TOP 100k if no TOP clause exists
# We need to insert TOP 100k after the first SELECT
query_with_top = re.sub(
r"\bSELECT\b",
"SELECT TOP 100000",
query,
count=1,
flags=re.IGNORECASE,
)
with self.db._engine.connect() as connection:
result = connection.execute(text(query_with_top))
headers = list(result.keys())
rows = list(result.fetchall())
if not rows:
return "Error: No data returned from the query."
# Create the Command to update the state
print("type query: ", type(query_with_top))
return Command(
update={
"sql_query": query_with_top, # Store query in state
"sql_result": [headers] + [list(row) for row in rows],
"messages": [
ToolMessage(
"Query executed successfully.",
tool_call_id=tool_call_id,
)
],
}
)
except Exception as e:
return f"Error executing query: {str(e)}"
def get_tools(db, llm) -> List[BaseTool]:
"""Get the tools for working with views."""
# Tool 3
query_sql_database_tool_description = (
"Input to this tool is a detailed and correct SQL query, output is a "
"result from the database including headers. If the query is not correct, an error message "
"will be returned. If an error is returned, rewrite the query, check the "
"query, and try again. If you encounter an issue with Unknown column "
f"'xxxx' in 'field list', use {view_info_tool.name} "
"to query the correct table fields."
)
query_sql_database_tool = QuerySQLDatabaseToolWithHeaders(
db=db, description=query_sql_database_tool_description
)
return [
query_sql_database_tool,
#.... more tools
]
def get_custom_sql_database_tools(
database_config: DatabaseConfig, chat_model: BaseChatModel
) -> list[BaseTool]:
database_engine = SQLDatabase.from_uri(database_config.get_uri())
sql_tools = get_tools(database_engine, chat_model)
return sql_tools I also tried to annotated the _run() method with @tool but then face the issue that custom _run method is not being called but the _run() from the QuerySQLDatabaseTool. Therefore, I thought having a wrapper tool that calls the _run() would be the way to go but that gives me a pydantic error that a non-annotated attribute was detected: "execute_sql = StructuredTool(name='execute_sql', description='Executes a SQL query with a TOP clause to limit results and updates the state.\n\nArgs:\n query (str): The SQL query to execute.\n tool_call_id (str): The unique identifier for the tool call.\n\nReturns:\n Union[str, Command]: A Command object to update the state or an error message.', args_schema=<class 'langchain_core.utils.pydantic.execute_sql'>, func=<function QuerySQLDatabaseToolWithHeaders.execute_sql at 0x7fd56da71b40>). All model fields require a type annotation; if execute_sql is not meant to be a field, you may be able to resolve this error by annotating it as a ClassVar or updating model_config['ignored_types'] I also tried to add the ClassVar without success and when returning the Command without declaring the _run() as @tool the tool_call_id is not being injected and the tool call returns in an error and the tool_call_id seems to be a mandatory field to pass for the command My langchain version is 0.3.20 Thanks a lot for your help! |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
In case someone is facing this in the future: my solution was to not use the community tool here the code. def create_query_sql_database_tool(db, description: str):
"""
Factory function to create a SQL query execution tool with an internal DB connection.
This ensures that the LLM does not need to pass the `db` object directly.
"""
@tool
async def query_sql_database(
query: str,
variable_name: str,
state: Annotated[dict, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId],
config: RunnableConfig = None, # Optional, keeps tool compatible with LangGraph
) -> Union[str, Dict[str, Any]]:
"""
Executes a SQL query with a TOP clause to limit results.
## Behavior:
- **Ensures a `TOP 100000` clause is present** in queries for performance optimization.
- **Returns the modified query** so the LLM understands the enforced TOP limit.
- **Stores results in `sql_result`** to be used in follow-up processing.
"""
try:
if state.get("sql_execution_count", 0) >= MAX_QUERIES_IN_SINGLE_INVOKATION:
return (
f"Execution limit reached: You've exceeded the maximum of {MAX_QUERIES_IN_SINGLE_INVOKATION} SQL executions. "
"Please inform the user that for each invokation there is a maximum of queries to run."
)
query_with_top = ensure_sql_execution_stays_feasible(query=query)
# Execute query
with db._engine.connect() as connection:
result = connection.execute(text(query_with_top))
headers = list(result.keys())
rows = list(result.fetchall())
# df = pd.read_sql(text(query_with_top), con=connection)
if not rows:
return "Error: No data returned from the query."
# Check if result size is too large
if len(rows) > 75:
return (
f"Error: Query result is too large ({len(rows)} rows). "
"Please use the query_sql_database_and_upload_to_s3 tool instead."
)
# print("Result ", [headers] + [list(row) for row in rows[:10]])
# Preserve previous results and query
sql_result_merged = state.get("sql_result", {}) or {}
sql_result_merged[variable_name] = [headers] + [
list(row) for row in rows
]
sql_query_merged = state.get("sql_query", {}) or {}
sql_query_merged[variable_name] = query_with_top
return Command(
update={
# update the state keys
"sql_query": sql_query_merged,
"sql_result": sql_result_merged,
"sql_execution_count": (state.get("sql_execution_count") or 0) + 1,
# "sql_result": {variable_name: df},
# update the message history
"messages": [
ToolMessage(
f"Successfully executed SQL, value stored in state['sql_result'].{variable_name}. SQL Result: {[headers] + [list(row) for row in rows]}",
tool_call_id=tool_call_id,
)
],
}
)
except Exception as e:
return f"Error executing query: {str(e)}"
# Assign the description to the tool before returning it
query_sql_database.description = description
return query_sql_database # Return the configured tool |
Beta Was this translation helpful? Give feedback.
In case someone is facing this in the future: my solution was to not use the community tool here the code.