@@ -20,7 +20,11 @@ class DBHandler:
2020 configuration and opens connections lazily when they are first needed.
2121 """
2222
23- _REQUIRED_DB_FIELDS = {'dialect' , 'host' , 'port' , 'database' , 'username' , 'password' }
23+ _DEFAULT_REQUIRED_DB_FIELDS = {'dialect' , 'host' , 'port' , 'database' , 'username' , 'password' }
24+ _DIALECT_REQUIRED_DB_FIELDS = {
25+ 'sqlite' : {'dialect' , 'file' },
26+ 'postgres' : {'dialect' , 'host' , 'port' , 'database' , 'username' , 'password' }
27+ }
2428 _SQL_COMMAND_PATTERN = re .compile (r'^\s*(?:/\*.*?\*/\s*|--.*?(?:\n|\r\n?)\s*)*([a-zA-Z]+)' , re .DOTALL )
2529
2630 def __init__ (self , agent : 'Agent' ):
@@ -50,24 +54,36 @@ def _extract_db_configs(self, config: dict[str, Any]) -> dict[str, dict[str, Any
5054 return db_configs
5155
5256 def _validate_db_config (self , db_name : str , db_config : dict [str , Any ]) -> bool :
53- missing = sorted (field for field in self ._REQUIRED_DB_FIELDS if db_config .get (field ) is None )
57+ required_fields = self ._required_db_fields (db_config )
58+ missing = sorted (field for field in required_fields if db_config .get (field ) is None )
5459 if missing :
60+ dialect = db_config .get ('dialect' )
5561 logger .error (
56- "Missing required DB properties for '%s': %s. Expected under 'db.sql.%s'." ,
62+ "Missing required DB properties for '%s' (dialect=%s) : %s. Expected under 'db.sql.%s'. Required fields: %s ." ,
5763 db_name ,
64+ dialect ,
5865 ', ' .join (missing ),
5966 db_name ,
67+ ', ' .join (sorted (required_fields )),
6068 )
6169 return False
6270 return True
6371
72+ def _required_db_fields (self , db_config : dict [str , Any ]) -> set [str ]:
73+ dialect = db_config .get ('dialect' )
74+ if dialect is None :
75+ return {'dialect' }
76+
77+ dialect_name = str (dialect ).split ('+' , 1 )[0 ].lower ()
78+ return self ._DIALECT_REQUIRED_DB_FIELDS .get (dialect_name , self ._DEFAULT_REQUIRED_DB_FIELDS )
79+
6480 def _build_db_url (self , db_name : str ) -> URL | str | None :
6581 try :
6682 db_config = self ._db_configs [db_name ]
6783 dialect = str (db_config ['dialect' ])
6884
6985 if dialect .startswith ('sqlite' ):
70- return URL .create (drivername = dialect , database = str (db_config ['database ' ]))
86+ return URL .create (drivername = dialect , database = str (db_config ['file ' ]))
7187
7288 return URL .create (
7389 drivername = dialect ,
0 commit comments