1010
1111from __future__ import annotations
1212
13- from dataclasses import dataclass , field
13+ from dataclasses import dataclass , field , fields
1414from enum import Enum
15- from typing import TYPE_CHECKING
15+ from typing import TYPE_CHECKING , Any , Mapping
1616
1717# Only import what's needed to create the DatabaseType enum
1818from .db .providers import get_supported_db_types as _get_supported_db_types
@@ -160,12 +160,6 @@ class ConnectionConfig:
160160 database : str = ""
161161 username : str = ""
162162 password : str | None = None
163- # SQL Server specific fields
164- auth_type : str = "sql"
165- driver : str = field (default_factory = _get_default_driver )
166- trusted_connection : bool = False
167- # SQLite specific fields
168- file_path : str = ""
169163 # SSH tunnel fields
170164 ssh_enabled : bool = False
171165 ssh_host : str = ""
@@ -174,35 +168,55 @@ class ConnectionConfig:
174168 ssh_auth_type : str = "key" # "key" or "password"
175169 ssh_password : str | None = None
176170 ssh_key_path : str = ""
177- # Supabase specific fields
178- supabase_region : str = ""
179- supabase_project_id : str = ""
180- # Oracle specific fields
181- oracle_role : str = "normal" # "normal", "sysdba", "sysoper"
182171 # Source tracking (e.g., "docker" for auto-detected containers)
183172 source : str | None = None
184173 # Original connection URL if created from URL
185174 connection_url : str | None = None
186175 # Extra options from URL query parameters (e.g., sslmode=require)
187176 extra_options : dict [str , str ] = field (default_factory = dict )
188-
189- def __post_init__ (self ) -> None :
190- """Handle backwards compatibility with old configs."""
191- # Old configs without db_type are SQL Server
192- if not hasattr (self , "db_type" ) or not self .db_type :
193- self .db_type = "mssql"
194-
195- # Apply default port for server-based DBs if missing (lazy import)
196- if not getattr (self , "port" , None ):
197- from .db .providers import get_default_port
198- default_port = get_default_port (self .db_type )
199- if default_port :
200- self .port = default_port
201-
202- # Handle old SQL Server auth compatibility
203- if self .db_type == "mssql" :
204- if self .auth_type == "windows" and not self .trusted_connection and self .username :
205- self .auth_type = "sql"
177+ # Provider-specific options (auth_type, driver, file_path, etc.)
178+ options : dict [str , Any ] = field (default_factory = dict )
179+
180+ @classmethod
181+ def from_dict (cls , data : Mapping [str , Any ]) -> ConnectionConfig :
182+ """Create a ConnectionConfig from a dict, with legacy key support."""
183+ payload = dict (data )
184+
185+ if "host" in payload and "server" not in payload :
186+ payload ["server" ] = payload .pop ("host" )
187+
188+ db_type = payload .get ("db_type" )
189+ if not isinstance (db_type , str ) or not db_type :
190+ payload ["db_type" ] = "mssql"
191+
192+ raw_options = payload .pop ("options" , None )
193+ options : dict [str , Any ] = {}
194+ if isinstance (raw_options , dict ):
195+ options .update (raw_options )
196+
197+ base_fields = {f .name for f in fields (cls )}
198+ for key in list (payload .keys ()):
199+ if key in base_fields :
200+ continue
201+ if key not in options :
202+ options [key ] = payload .pop (key )
203+ else :
204+ payload .pop (key )
205+
206+ payload ["options" ] = options
207+ return cls (** payload )
208+
209+ def get_option (self , name : str , default : Any | None = None ) -> Any :
210+ return self .options .get (name , default )
211+
212+ def set_option (self , name : str , value : Any ) -> None :
213+ self .options [name ] = value
214+
215+ def get_field_value (self , name : str , default : Any = "" ) -> Any :
216+ if name in self .__dataclass_fields__ :
217+ value = getattr (self , name )
218+ return value if value is not None else default
219+ return self .options .get (name , default )
206220
207221 def get_db_type (self ) -> DatabaseType :
208222 """Get the DatabaseType enum value."""
@@ -211,71 +225,6 @@ def get_db_type(self) -> DatabaseType:
211225 except ValueError :
212226 return DatabaseType .MSSQL # type: ignore[attr-defined, no-any-return]
213227
214- def get_auth_type (self ) -> AuthType :
215- """Get the AuthType enum value."""
216- try :
217- return AuthType (self .auth_type )
218- except ValueError :
219- return AuthType .SQL_SERVER
220-
221- def get_connection_string (self ) -> str :
222- """Build the connection string for SQL Server.
223-
224- .. deprecated::
225- This method is deprecated. Connection string building is now
226- handled internally by SQLServerAdapter._build_connection_string().
227- Use SQLServerAdapter.connect() directly instead.
228- """
229- import warnings
230-
231- warnings .warn (
232- "ConnectionConfig.get_connection_string() is deprecated. "
233- "Connection string building is now handled internally by SQLServerAdapter." ,
234- DeprecationWarning ,
235- stacklevel = 2 ,
236- )
237-
238- if self .db_type != "mssql" :
239- raise ValueError ("get_connection_string() is only for SQL Server connections" )
240-
241- server_with_port = self .server
242- if self .port and self .port != "1433" :
243- server_with_port = f"{ self .server } ,{ self .port } "
244-
245- base = (
246- f"DRIVER={{{ self .driver } }};"
247- f"SERVER={ server_with_port } ;"
248- f"DATABASE={ self .database or 'master' } ;"
249- f"TrustServerCertificate=yes;"
250- )
251-
252- auth = self .get_auth_type ()
253-
254- if auth == AuthType .WINDOWS :
255- return base + "Trusted_Connection=yes;"
256- elif auth == AuthType .SQL_SERVER :
257- return base + f"UID={ self .username } ;PWD={ self .password } ;"
258- elif auth == AuthType .AD_PASSWORD :
259- return base + f"Authentication=ActiveDirectoryPassword;" f"UID={ self .username } ;PWD={ self .password } ;"
260- elif auth == AuthType .AD_INTERACTIVE :
261- return base + f"Authentication=ActiveDirectoryInteractive;" f"UID={ self .username } ;"
262- elif auth == AuthType .AD_INTEGRATED :
263- return base + "Authentication=ActiveDirectoryIntegrated;"
264-
265- return base + "Trusted_Connection=yes;"
266-
267- def get_display_info (self ) -> str :
268- """Get a display string for the connection."""
269- from .db .providers import is_file_based
270- if is_file_based (self .db_type ):
271- return self .file_path or self .name
272-
273- if self .db_type == "supabase" :
274- return f"{ self .name } ({ self .supabase_region } )"
275-
276- db_part = f"@{ self .database } " if self .database else ""
277- return f"{ self .name } { db_part } "
278-
279228 def get_source_emoji (self ) -> str :
280229 """Get emoji indicator for connection source (e.g., '🐳 ' for docker)."""
281230 return get_source_emoji (self .source )
0 commit comments