Skip to content

Commit 6d59dc9

Browse files
committed
Create per session ephemeral database
1 parent aaa32a4 commit 6d59dc9

File tree

2 files changed

+51
-49
lines changed

2 files changed

+51
-49
lines changed

samples/adk-sql-agent/sql_agent/agent.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
# limitations under the License.
1414

1515
from google.adk.agents import Agent
16-
from google.adk.agents import BaseAgent
17-
import tempfile
18-
from .tools import create_run_sql_tool
19-
16+
from .tools import run_sql_tool, create_database_tool
2017

2118
import sqlite3
2219

@@ -38,6 +35,9 @@
3835
- Always prefer to insert multiple rows in a single call to the sql_db_query tool, if possible.
3936
- You may request to execute multiple sql_db_query tool calls which will be run in parallel.
4037
38+
You can use run_sql_tool to run sql queries against the current sqlite3 database.
39+
You can use create_database_tool to create a new ephemeral sqlite3 database if one is not found in current state.
40+
You should always check if database for current session exist before running sql queries by calling create_database_tool.
4141
If you make a mistake, try to recover."""
4242

4343
INTRO_TEXT = """\
@@ -61,19 +61,10 @@
6161
---
6262
"""
6363

64-
def get_dbpath(thread_id: str) -> str:
65-
# Ephemeral sqlite database per conversation thread
66-
_, path = tempfile.mkstemp(suffix=".db")
67-
return path
68-
69-
70-
# TODO: how to get the session from within a callback.
71-
dbpath = get_dbpath("default")
72-
7364
root_agent = Agent(
7465
name="weather_time_agent",
7566
model="gemini-2.0-flash",
7667
description=INTRO_TEXT,
7768
instruction=SYSTEM_PROMPT,
78-
tools=[create_run_sql_tool(dbpath)],
69+
tools=[run_sql_tool, create_database_tool],
7970
)

samples/adk-sql-agent/sql_agent/tools.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
import logging
1717
from typing import Any, NotRequired, TypedDict
1818
from opentelemetry import trace
19+
from google.adk.tools import ToolContext
1920
import sqlite3
21+
import tempfile
2022

23+
SESSION_DB_KEY = "session_sqlite_db"
2124

2225
tracer = trace.get_tracer(__name__)
2326
logger = logging.getLogger(__name__)
@@ -30,42 +33,50 @@ class SqlRunResult(TypedDict):
3033
rows: NotRequired[list[tuple[str, ...]]]
3134
"""The rows returned by the SQL query"""
3235

36+
@tracer.start_as_current_span("create_database")
37+
def create_database_tool(tool_context: ToolContext):
38+
"""Creates a temporary file in the /tmp directory to hold an ephemeral
39+
sqlite3 database if a database is not found for the current session.
40+
"""
41+
if not SESSION_DB_KEY in tool_context.state:
42+
_, path = tempfile.mkstemp(suffix=".db")
43+
# No scope prefix in the state data indicates that it will be persisted for
44+
# current session.
45+
# See https://deepwiki.com/google/adk-python/3.4-state-management.
46+
tool_context.state[SESSION_DB_KEY] = path
3347

34-
def create_run_sql_tool(dbpath: str):
35-
@tracer.start_as_current_span("run_sql")
36-
def run_sql(sql_query: str) -> dict[str, Any]:
37-
"""Runs a SQLite query. The SQL query can be DDL or DML. Returns the rows if it's a SELECT query."""
48+
@tracer.start_as_current_span("run_sql")
49+
def run_sql_tool(sql_query: str, tool_context: ToolContext) -> dict[str, Any]:
50+
"""Runs a SQLite query. The SQL query can be DDL or DML. Returns the rows if it's a SELECT query."""
51+
current_session_db_path = tool_context.state[SESSION_DB_KEY]
52+
with sqlite3.connect(current_session_db_path) as db:
53+
try:
54+
cursor = db.cursor()
55+
cursor.execute(sql_query)
56+
rows_list: list[tuple[str, ...]] = []
3857

39-
with sqlite3.connect(dbpath) as db:
40-
try:
41-
cursor = db.cursor()
42-
cursor.execute(sql_query)
43-
rows_list: list[tuple[str, ...]] = []
44-
45-
# Check if the query is one that would return rows (e.g., SELECT)
46-
if cursor.description is not None:
47-
fetched_rows: list[tuple[Any, ...]] = cursor.fetchall()
48-
rows_list = [tuple(str(col) for col in row) for row in fetched_rows]
49-
logger.info("Query returned %s rows", len(rows_list))
50-
else:
51-
# For DDL/DML (like INSERT, UPDATE, DELETE without RETURNING clause)
52-
# cursor.description is None.
53-
# rowcount shows number of affected rows for DML.
54-
logger.info("Query affected %s rows (DDL/DML)", cursor.rowcount)
58+
# Check if the query is one that would return rows (e.g., SELECT)
59+
if cursor.description is not None:
60+
fetched_rows: list[tuple[Any, ...]] = cursor.fetchall()
61+
rows_list = [tuple(str(col) for col in row) for row in fetched_rows]
62+
logger.info("Query returned %s rows", len(rows_list))
63+
else:
64+
# For DDL/DML (like INSERT, UPDATE, DELETE without RETURNING clause)
65+
# cursor.description is None.
66+
# rowcount shows number of affected rows for DML.
67+
logger.info("Query affected %s rows (DDL/DML)", cursor.rowcount)
5568

56-
# DML statements (INSERT, UPDATE, DELETE) require a commit.
57-
# DDL statements are often autocommitted by SQLite, but an explicit commit here ensures DML changes are saved.
58-
db.commit()
59-
return {"rows": rows_list}
69+
# DML statements (INSERT, UPDATE, DELETE) require a commit.
70+
# DDL statements are often autocommitted by SQLite, but an explicit commit here ensures DML changes are saved.
71+
db.commit()
72+
return {"rows": rows_list}
6073

61-
except sqlite3.Error as err:
62-
logger.error(f"SQL Error: {err} for query: {sql_query}")
63-
try:
64-
db.rollback() # Attempt to rollback on error
65-
logger.info("SQL transaction rolled back due to error.")
66-
except sqlite3.Error as rb_err:
67-
# This can happen if the connection is already closed or in an unusable state.
68-
logger.error(f"Failed to rollback transaction: {rb_err}")
69-
return {"error": str(err)}
70-
71-
return run_sql
74+
except sqlite3.Error as err:
75+
logger.error(f"SQL Error: {err} for query: {sql_query}")
76+
try:
77+
db.rollback() # Attempt to rollback on error
78+
logger.info("SQL transaction rolled back due to error.")
79+
except sqlite3.Error as rb_err:
80+
# This can happen if the connection is already closed or in an unusable state.
81+
logger.error(f"Failed to rollback transaction: {rb_err}")
82+
return {"error": str(err)}

0 commit comments

Comments
 (0)