1010
1111from __future__ import annotations
1212
13- from dataclasses import dataclass , field
13+ from dataclasses import dataclass , field , fields
1414from enum import Enum
15- from typing import TYPE_CHECKING
15+ from typing import TYPE_CHECKING , Any , Mapping
1616
1717# Only import what's needed to create the DatabaseType enum
1818from .db .providers import get_supported_db_types as _get_supported_db_types
@@ -161,12 +161,6 @@ class ConnectionConfig:
161161 database : str = ""
162162 username : str = ""
163163 password : str | None = None
164- # SQL Server specific fields
165- auth_type : str = "sql"
166- driver : str = field (default_factory = _get_default_driver )
167- trusted_connection : bool = False
168- # SQLite specific fields
169- file_path : str = ""
170164 # SSH tunnel fields
171165 ssh_enabled : bool = False
172166 ssh_host : str = ""
@@ -175,35 +169,55 @@ class ConnectionConfig:
175169 ssh_auth_type : str = "key" # "key" or "password"
176170 ssh_password : str | None = None
177171 ssh_key_path : str = ""
178- # Supabase specific fields
179- supabase_region : str = ""
180- supabase_project_id : str = ""
181- # Oracle specific fields
182- oracle_role : str = "normal" # "normal", "sysdba", "sysoper"
183172 # Source tracking (e.g., "docker" for auto-detected containers)
184173 source : str | None = None
185174 # Original connection URL if created from URL
186175 connection_url : str | None = None
187176 # Extra options from URL query parameters (e.g., sslmode=require)
188177 extra_options : dict [str , str ] = field (default_factory = dict )
189-
190- def __post_init__ (self ) -> None :
191- """Handle backwards compatibility with old configs."""
192- # Old configs without db_type are SQL Server
193- if not hasattr (self , "db_type" ) or not self .db_type :
194- self .db_type = "mssql"
195-
196- # Apply default port for server-based DBs if missing (lazy import)
197- if not getattr (self , "port" , None ):
198- from .db .providers import get_default_port
199- default_port = get_default_port (self .db_type )
200- if default_port :
201- self .port = default_port
202-
203- # Handle old SQL Server auth compatibility
204- if self .db_type == "mssql" :
205- if self .auth_type == "windows" and not self .trusted_connection and self .username :
206- self .auth_type = "sql"
178+ # Provider-specific options (auth_type, driver, file_path, etc.)
179+ options : dict [str , Any ] = field (default_factory = dict )
180+
181+ @classmethod
182+ def from_dict (cls , data : Mapping [str , Any ]) -> ConnectionConfig :
183+ """Create a ConnectionConfig from a dict, with legacy key support."""
184+ payload = dict (data )
185+
186+ if "host" in payload and "server" not in payload :
187+ payload ["server" ] = payload .pop ("host" )
188+
189+ db_type = payload .get ("db_type" )
190+ if not isinstance (db_type , str ) or not db_type :
191+ payload ["db_type" ] = "mssql"
192+
193+ raw_options = payload .pop ("options" , None )
194+ options : dict [str , Any ] = {}
195+ if isinstance (raw_options , dict ):
196+ options .update (raw_options )
197+
198+ base_fields = {f .name for f in fields (cls )}
199+ for key in list (payload .keys ()):
200+ if key in base_fields :
201+ continue
202+ if key not in options :
203+ options [key ] = payload .pop (key )
204+ else :
205+ payload .pop (key )
206+
207+ payload ["options" ] = options
208+ return cls (** payload )
209+
210+ def get_option (self , name : str , default : Any | None = None ) -> Any :
211+ return self .options .get (name , default )
212+
213+ def set_option (self , name : str , value : Any ) -> None :
214+ self .options [name ] = value
215+
216+ def get_field_value (self , name : str , default : Any = "" ) -> Any :
217+ if name in self .__dataclass_fields__ :
218+ value = getattr (self , name )
219+ return value if value is not None else default
220+ return self .options .get (name , default )
207221
208222 def get_db_type (self ) -> DatabaseType :
209223 """Get the DatabaseType enum value."""
@@ -212,71 +226,6 @@ def get_db_type(self) -> DatabaseType:
212226 except ValueError :
213227 return DatabaseType .MSSQL # type: ignore[attr-defined, no-any-return]
214228
215- def get_auth_type (self ) -> AuthType :
216- """Get the AuthType enum value."""
217- try :
218- return AuthType (self .auth_type )
219- except ValueError :
220- return AuthType .SQL_SERVER
221-
222- def get_connection_string (self ) -> str :
223- """Build the connection string for SQL Server.
224-
225- .. deprecated::
226- This method is deprecated. Connection string building is now
227- handled internally by SQLServerAdapter._build_connection_string().
228- Use SQLServerAdapter.connect() directly instead.
229- """
230- import warnings
231-
232- warnings .warn (
233- "ConnectionConfig.get_connection_string() is deprecated. "
234- "Connection string building is now handled internally by SQLServerAdapter." ,
235- DeprecationWarning ,
236- stacklevel = 2 ,
237- )
238-
239- if self .db_type != "mssql" :
240- raise ValueError ("get_connection_string() is only for SQL Server connections" )
241-
242- server_with_port = self .server
243- if self .port and self .port != "1433" :
244- server_with_port = f"{ self .server } ,{ self .port } "
245-
246- base = (
247- f"DRIVER={{{ self .driver } }};"
248- f"SERVER={ server_with_port } ;"
249- f"DATABASE={ self .database or 'master' } ;"
250- f"TrustServerCertificate=yes;"
251- )
252-
253- auth = self .get_auth_type ()
254-
255- if auth == AuthType .WINDOWS :
256- return base + "Trusted_Connection=yes;"
257- elif auth == AuthType .SQL_SERVER :
258- return base + f"UID={ self .username } ;PWD={ self .password } ;"
259- elif auth == AuthType .AD_PASSWORD :
260- return base + f"Authentication=ActiveDirectoryPassword;" f"UID={ self .username } ;PWD={ self .password } ;"
261- elif auth == AuthType .AD_INTERACTIVE :
262- return base + f"Authentication=ActiveDirectoryInteractive;" f"UID={ self .username } ;"
263- elif auth == AuthType .AD_INTEGRATED :
264- return base + "Authentication=ActiveDirectoryIntegrated;"
265-
266- return base + "Trusted_Connection=yes;"
267-
268- def get_display_info (self ) -> str :
269- """Get a display string for the connection."""
270- from .db .providers import is_file_based
271- if is_file_based (self .db_type ):
272- return self .file_path or self .name
273-
274- if self .db_type == "supabase" :
275- return f"{ self .name } ({ self .supabase_region } )"
276-
277- db_part = f"@{ self .database } " if self .database else ""
278- return f"{ self .name } { db_part } "
279-
280229 def get_source_emoji (self ) -> str :
281230 """Get emoji indicator for connection source (e.g., '🐳 ' for docker)."""
282231 return get_source_emoji (self .source )
0 commit comments