Skip to content

Commit a215efb

Browse files
committed
Refactor to make adapter file describes themselves entirely
2 parents 0ac89b2 + 756e675 commit a215efb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1271
-695
lines changed

sqlit/cli.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,11 +387,10 @@ def _build_temp_connection(args: argparse.Namespace) -> ConnectionConfig | None:
387387
"""Build a temporary connection config from CLI args, if provided."""
388388
db_type = getattr(args, "db_type", None)
389389
file_path = getattr(args, "file_path", None)
390-
if not db_type and file_path:
391-
db_type = "sqlite"
392-
setattr(args, "db_type", db_type)
393390
if not db_type:
394-
if any(getattr(args, name, None) for name in ("file_path", "server", "host", "database")):
391+
if file_path:
392+
raise ValueError("--db-type is required when using --file-path")
393+
if any(getattr(args, name, None) for name in ("server", "host", "database")):
395394
raise ValueError("--db-type is required for temporary connections")
396395
return None
397396

sqlit/cli_helpers.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,24 @@
33
from __future__ import annotations
44

55
import argparse
6+
from dataclasses import fields
7+
from functools import lru_cache
68
from typing import Any, Iterable
79

810
from .config import ConnectionConfig
911
from .db.schema import ConnectionSchema, FieldType
1012

11-
CONNECTION_ARG_NAMES = {
12-
"name",
13-
"server",
14-
"host",
15-
"port",
16-
"database",
17-
"username",
18-
"password",
19-
"file_path",
20-
"auth_type",
21-
"supabase_region",
22-
"supabase_project_id",
23-
"ssh_enabled",
24-
"ssh_host",
25-
"ssh_port",
26-
"ssh_username",
27-
"ssh_auth_type",
28-
"ssh_key_path",
29-
"ssh_password",
30-
"driver",
31-
"oracle_role",
32-
}
13+
@lru_cache(maxsize=1)
14+
def _get_connection_arg_names() -> set[str]:
15+
from .db.schema import get_all_schemas
16+
17+
names = {
18+
field.name
19+
for schema in get_all_schemas().values()
20+
for field in schema.fields
21+
}
22+
names.add("host")
23+
return names
3324

3425

3526
def add_schema_arguments(
@@ -85,6 +76,8 @@ def build_connection_config_from_args(
8576
strict: bool = True,
8677
) -> ConnectionConfig:
8778
"""Build a ConnectionConfig from CLI args based on a provider schema."""
79+
from .db.providers import normalize_connection_config
80+
8881
raw_values = _extract_raw_values(schema, args)
8982

9083
missing = _find_missing_required_fields(schema, raw_values)
@@ -99,10 +92,12 @@ def build_connection_config_from_args(
9992
raise ValueError(f"Unexpected fields for {schema.display_name}: {extras_args}")
10093

10194
config_name = name or default_name or f"Temp {schema.display_name}"
102-
config_values = {
95+
config_values: dict[str, Any] = {
10396
"name": config_name,
10497
"db_type": schema.db_type,
10598
}
99+
options: dict[str, Any] = {}
100+
base_fields = {f.name for f in fields(ConnectionConfig)}
106101

107102
# Fields where None means "not set" vs "" means "explicitly empty"
108103
nullable_fields = {"password", "ssh_password"}
@@ -113,21 +108,27 @@ def build_connection_config_from_args(
113108
value = ""
114109
if field.name == "ssh_enabled":
115110
if isinstance(value, bool):
116-
config_values[field.name] = value
111+
normalized = value
117112
else:
118-
config_values[field.name] = str(value).lower() == "enabled"
119-
else:
113+
normalized = str(value).lower() == "enabled"
114+
config_values[field.name] = normalized
115+
continue
116+
117+
if field.name in base_fields:
120118
config_values[field.name] = value
119+
else:
120+
options[field.name] = value
121121

122122
if "port" in config_values and not config_values["port"]:
123123
config_values["port"] = schema.default_port or ""
124124

125125
if schema.has_advanced_auth:
126-
auth_type = config_values.get("auth_type") or "sql"
127-
config_values["auth_type"] = auth_type
128-
config_values["trusted_connection"] = auth_type == "windows"
126+
auth_type = options.get("auth_type") or "sql"
127+
options["auth_type"] = auth_type
128+
options["trusted_connection"] = auth_type == "windows"
129129

130-
return ConnectionConfig(**config_values)
130+
config_values["options"] = options
131+
return normalize_connection_config(ConnectionConfig(**config_values))
131132

132133

133134
def _extract_raw_values(schema: ConnectionSchema, args: Any) -> dict[str, Any]:
@@ -158,7 +159,7 @@ def _find_missing_required_fields(schema: ConnectionSchema, raw_values: dict[str
158159
def _find_unexpected_fields(schema: ConnectionSchema, args: Any) -> list[str]:
159160
allowed = {field.name for field in schema.fields}
160161
extras: list[str] = []
161-
for field in CONNECTION_ARG_NAMES:
162+
for field in _get_connection_arg_names():
162163
if field in allowed or field == "name":
163164
continue
164165
value = getattr(args, field, None)

sqlit/commands.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
load_connections,
2121
save_connections,
2222
)
23-
from .db.providers import get_connection_schema, has_advanced_auth, is_file_based
23+
from .db.providers import get_adapter_class, get_connection_schema, has_advanced_auth, is_file_based
2424
from .services import ConnectionSession, QueryResult, QueryService
2525
from .services.credentials import (
2626
ALLOW_PLAINTEXT_CREDENTIALS_SETTING,
@@ -101,12 +101,16 @@ def cmd_connection_list(args: Any) -> int:
101101
for conn in connections:
102102
db_type_label = labels.get(conn.get_db_type(), conn.db_type)
103103
if is_file_based(conn.db_type):
104-
conn_info = conn.file_path[:38] + ".." if len(conn.file_path) > 40 else conn.file_path
104+
file_path = str(conn.get_option("file_path", ""))
105+
conn_info = file_path[:38] + ".." if len(file_path) > 40 else file_path
105106
auth_label = "N/A"
106107
elif has_advanced_auth(conn.db_type):
107108
conn_info = f"{conn.server}@{conn.database}" if conn.database else conn.server
108109
conn_info = conn_info[:38] + ".." if len(conn_info) > 40 else conn_info
109-
auth_label = AUTH_TYPE_LABELS.get(conn.get_auth_type(), conn.auth_type)
110+
auth_value = str(conn.get_option("auth_type", ""))
111+
adapter_class = get_adapter_class(conn.db_type)
112+
auth_type = adapter_class().get_auth_type(conn)
113+
auth_label = AUTH_TYPE_LABELS.get(auth_type, auth_value) if auth_type else auth_value
110114
else:
111115
# Server-based databases with simple auth
112116
conn_info = f"{conn.server}@{conn.database}" if conn.database else conn.server
@@ -226,8 +230,8 @@ def cmd_connection_edit(args: Any) -> int:
226230
if args.auth_type:
227231
try:
228232
auth_type = AuthType(args.auth_type)
229-
conn.auth_type = auth_type.value
230-
conn.trusted_connection = auth_type == AuthType.WINDOWS
233+
conn.set_option("auth_type", auth_type.value)
234+
conn.set_option("trusted_connection", auth_type == AuthType.WINDOWS)
231235
except ValueError:
232236
valid_types = ", ".join(t.value for t in AuthType)
233237
print(f"Error: Invalid auth type '{args.auth_type}'. Valid types: {valid_types}")
@@ -239,7 +243,7 @@ def cmd_connection_edit(args: Any) -> int:
239243

240244
file_path = getattr(args, "file_path", None)
241245
if file_path is not None:
242-
conn.file_path = file_path
246+
conn.set_option("file_path", file_path)
243247

244248
if (conn.password or conn.ssh_password) and not is_keyring_usable():
245249
if not _maybe_prompt_plaintext_credentials():
@@ -371,8 +375,9 @@ def cmd_query(
371375
print(f"Error: Connection '{args.connection}' not found.")
372376
return 1
373377

374-
if args.database and config.db_type == "mssql":
375-
config = replace(config, database=args.database)
378+
if args.database:
379+
adapter_class = get_adapter_class(config.db_type)
380+
config = adapter_class().apply_database_override(config, args.database)
376381

377382
config = _prompt_for_password(config)
378383

sqlit/config.py

Lines changed: 45 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
from __future__ import annotations
1212

13-
from dataclasses import dataclass, field
13+
from dataclasses import dataclass, field, fields
1414
from enum import Enum
15-
from typing import TYPE_CHECKING
15+
from typing import TYPE_CHECKING, Any, Mapping
1616

1717
# Only import what's needed to create the DatabaseType enum
1818
from .db.providers import get_supported_db_types as _get_supported_db_types
@@ -161,12 +161,6 @@ class ConnectionConfig:
161161
database: str = ""
162162
username: str = ""
163163
password: str | None = None
164-
# SQL Server specific fields
165-
auth_type: str = "sql"
166-
driver: str = field(default_factory=_get_default_driver)
167-
trusted_connection: bool = False
168-
# SQLite specific fields
169-
file_path: str = ""
170164
# SSH tunnel fields
171165
ssh_enabled: bool = False
172166
ssh_host: str = ""
@@ -175,35 +169,55 @@ class ConnectionConfig:
175169
ssh_auth_type: str = "key" # "key" or "password"
176170
ssh_password: str | None = None
177171
ssh_key_path: str = ""
178-
# Supabase specific fields
179-
supabase_region: str = ""
180-
supabase_project_id: str = ""
181-
# Oracle specific fields
182-
oracle_role: str = "normal" # "normal", "sysdba", "sysoper"
183172
# Source tracking (e.g., "docker" for auto-detected containers)
184173
source: str | None = None
185174
# Original connection URL if created from URL
186175
connection_url: str | None = None
187176
# Extra options from URL query parameters (e.g., sslmode=require)
188177
extra_options: dict[str, str] = field(default_factory=dict)
189-
190-
def __post_init__(self) -> None:
191-
"""Handle backwards compatibility with old configs."""
192-
# Old configs without db_type are SQL Server
193-
if not hasattr(self, "db_type") or not self.db_type:
194-
self.db_type = "mssql"
195-
196-
# Apply default port for server-based DBs if missing (lazy import)
197-
if not getattr(self, "port", None):
198-
from .db.providers import get_default_port
199-
default_port = get_default_port(self.db_type)
200-
if default_port:
201-
self.port = default_port
202-
203-
# Handle old SQL Server auth compatibility
204-
if self.db_type == "mssql":
205-
if self.auth_type == "windows" and not self.trusted_connection and self.username:
206-
self.auth_type = "sql"
178+
# Provider-specific options (auth_type, driver, file_path, etc.)
179+
options: dict[str, Any] = field(default_factory=dict)
180+
181+
@classmethod
182+
def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig:
183+
"""Create a ConnectionConfig from a dict, with legacy key support."""
184+
payload = dict(data)
185+
186+
if "host" in payload and "server" not in payload:
187+
payload["server"] = payload.pop("host")
188+
189+
db_type = payload.get("db_type")
190+
if not isinstance(db_type, str) or not db_type:
191+
payload["db_type"] = "mssql"
192+
193+
raw_options = payload.pop("options", None)
194+
options: dict[str, Any] = {}
195+
if isinstance(raw_options, dict):
196+
options.update(raw_options)
197+
198+
base_fields = {f.name for f in fields(cls)}
199+
for key in list(payload.keys()):
200+
if key in base_fields:
201+
continue
202+
if key not in options:
203+
options[key] = payload.pop(key)
204+
else:
205+
payload.pop(key)
206+
207+
payload["options"] = options
208+
return cls(**payload)
209+
210+
def get_option(self, name: str, default: Any | None = None) -> Any:
211+
return self.options.get(name, default)
212+
213+
def set_option(self, name: str, value: Any) -> None:
214+
self.options[name] = value
215+
216+
def get_field_value(self, name: str, default: Any = "") -> Any:
217+
if name in self.__dataclass_fields__:
218+
value = getattr(self, name)
219+
return value if value is not None else default
220+
return self.options.get(name, default)
207221

208222
def get_db_type(self) -> DatabaseType:
209223
"""Get the DatabaseType enum value."""
@@ -212,71 +226,6 @@ def get_db_type(self) -> DatabaseType:
212226
except ValueError:
213227
return DatabaseType.MSSQL # type: ignore[attr-defined, no-any-return]
214228

215-
def get_auth_type(self) -> AuthType:
216-
"""Get the AuthType enum value."""
217-
try:
218-
return AuthType(self.auth_type)
219-
except ValueError:
220-
return AuthType.SQL_SERVER
221-
222-
def get_connection_string(self) -> str:
223-
"""Build the connection string for SQL Server.
224-
225-
.. deprecated::
226-
This method is deprecated. Connection string building is now
227-
handled internally by SQLServerAdapter._build_connection_string().
228-
Use SQLServerAdapter.connect() directly instead.
229-
"""
230-
import warnings
231-
232-
warnings.warn(
233-
"ConnectionConfig.get_connection_string() is deprecated. "
234-
"Connection string building is now handled internally by SQLServerAdapter.",
235-
DeprecationWarning,
236-
stacklevel=2,
237-
)
238-
239-
if self.db_type != "mssql":
240-
raise ValueError("get_connection_string() is only for SQL Server connections")
241-
242-
server_with_port = self.server
243-
if self.port and self.port != "1433":
244-
server_with_port = f"{self.server},{self.port}"
245-
246-
base = (
247-
f"DRIVER={{{self.driver}}};"
248-
f"SERVER={server_with_port};"
249-
f"DATABASE={self.database or 'master'};"
250-
f"TrustServerCertificate=yes;"
251-
)
252-
253-
auth = self.get_auth_type()
254-
255-
if auth == AuthType.WINDOWS:
256-
return base + "Trusted_Connection=yes;"
257-
elif auth == AuthType.SQL_SERVER:
258-
return base + f"UID={self.username};PWD={self.password};"
259-
elif auth == AuthType.AD_PASSWORD:
260-
return base + f"Authentication=ActiveDirectoryPassword;" f"UID={self.username};PWD={self.password};"
261-
elif auth == AuthType.AD_INTERACTIVE:
262-
return base + f"Authentication=ActiveDirectoryInteractive;" f"UID={self.username};"
263-
elif auth == AuthType.AD_INTEGRATED:
264-
return base + "Authentication=ActiveDirectoryIntegrated;"
265-
266-
return base + "Trusted_Connection=yes;"
267-
268-
def get_display_info(self) -> str:
269-
"""Get a display string for the connection."""
270-
from .db.providers import is_file_based
271-
if is_file_based(self.db_type):
272-
return self.file_path or self.name
273-
274-
if self.db_type == "supabase":
275-
return f"{self.name} ({self.supabase_region})"
276-
277-
db_part = f"@{self.database}" if self.database else ""
278-
return f"{self.name}{db_part}"
279-
280229
def get_source_emoji(self) -> str:
281230
"""Get emoji indicator for connection source (e.g., '🐳 ' for docker)."""
282231
return get_source_emoji(self.source)

0 commit comments

Comments
 (0)