Skip to content

Commit 46b7dc8

Browse files
committed
Refactor DBHandler to support dialect-specific required fields and update example config for SQLite
1 parent 3e6c61e commit 46b7dc8

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

besser/agent/db/db_handler.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ class DBHandler:
2020
configuration and opens connections lazily when they are first needed.
2121
"""
2222

23-
_REQUIRED_DB_FIELDS = {'dialect', 'host', 'port', 'database', 'username', 'password'}
23+
_DEFAULT_REQUIRED_DB_FIELDS = {'dialect', 'host', 'port', 'database', 'username', 'password'}
24+
_DIALECT_REQUIRED_DB_FIELDS = {
25+
'sqlite': {'dialect', 'file'},
26+
'postgres': {'dialect', 'host', 'port', 'database', 'username', 'password'}
27+
}
2428
_SQL_COMMAND_PATTERN = re.compile(r'^\s*(?:/\*.*?\*/\s*|--.*?(?:\n|\r\n?)\s*)*([a-zA-Z]+)', re.DOTALL)
2529

2630
def __init__(self, agent: 'Agent'):
@@ -50,24 +54,36 @@ def _extract_db_configs(self, config: dict[str, Any]) -> dict[str, dict[str, Any
5054
return db_configs
5155

5256
def _validate_db_config(self, db_name: str, db_config: dict[str, Any]) -> bool:
53-
missing = sorted(field for field in self._REQUIRED_DB_FIELDS if db_config.get(field) is None)
57+
required_fields = self._required_db_fields(db_config)
58+
missing = sorted(field for field in required_fields if db_config.get(field) is None)
5459
if missing:
60+
dialect = db_config.get('dialect')
5561
logger.error(
56-
"Missing required DB properties for '%s': %s. Expected under 'db.sql.%s'.",
62+
"Missing required DB properties for '%s' (dialect=%s): %s. Expected under 'db.sql.%s'. Required fields: %s.",
5763
db_name,
64+
dialect,
5865
', '.join(missing),
5966
db_name,
67+
', '.join(sorted(required_fields)),
6068
)
6169
return False
6270
return True
6371

72+
def _required_db_fields(self, db_config: dict[str, Any]) -> set[str]:
73+
dialect = db_config.get('dialect')
74+
if dialect is None:
75+
return {'dialect'}
76+
77+
dialect_name = str(dialect).split('+', 1)[0].lower()
78+
return self._DIALECT_REQUIRED_DB_FIELDS.get(dialect_name, self._DEFAULT_REQUIRED_DB_FIELDS)
79+
6480
def _build_db_url(self, db_name: str) -> URL | str | None:
6581
try:
6682
db_config = self._db_configs[db_name]
6783
dialect = str(db_config['dialect'])
6884

6985
if dialect.startswith('sqlite'):
70-
return URL.create(drivername=dialect, database=str(db_config['database']))
86+
return URL.create(drivername=dialect, database=str(db_config['file']))
7187

7288
return URL.create(
7389
drivername=dialect,

besser/agent/test/examples/config.yaml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,8 @@ db:
6060
password: YOUR-DB-PASSWORD
6161
sql:
6262
- db1:
63-
dialect: postgresql
64-
host: YOUR-DB-HOST
65-
port: 5432
66-
database: YOUR-DB-NAME
67-
username: YOUR-DB-USERNAME
68-
password: YOUR-DB-PASSWORD
63+
dialect: sqlite
64+
database: path/to/local.db
6965
- db2:
7066
dialect: postgresql
7167
host: YOUR-DB-HOST

docs/source/wiki/db/db_handler.rst

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,8 @@ You can define multiple DBs in configuration:
3535
username: user
3636
password: pass
3737
- db2:
38-
dialect: postgresql
39-
host: localhost
40-
port: 5432
41-
database: analytics_db
42-
username: user
43-
password: pass
38+
dialect: sqlite
39+
file: path/to/local_database.db
4440
4541
Query a DB in a State Body
4642
--------------------------

0 commit comments

Comments
 (0)