Skip to content

Commit 756e675

Browse files
committed
Refactor adapters - make it easier to make new ones
1 parent d5572e9 commit 756e675

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1227
-680
lines changed

sqlit/cli.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,11 +387,10 @@ def _build_temp_connection(args: argparse.Namespace) -> ConnectionConfig | None:
387387
"""Build a temporary connection config from CLI args, if provided."""
388388
db_type = getattr(args, "db_type", None)
389389
file_path = getattr(args, "file_path", None)
390-
if not db_type and file_path:
391-
db_type = "sqlite"
392-
setattr(args, "db_type", db_type)
393390
if not db_type:
394-
if any(getattr(args, name, None) for name in ("file_path", "server", "host", "database")):
391+
if file_path:
392+
raise ValueError("--db-type is required when using --file-path")
393+
if any(getattr(args, name, None) for name in ("server", "host", "database")):
395394
raise ValueError("--db-type is required for temporary connections")
396395
return None
397396

sqlit/cli_helpers.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,24 @@
33
from __future__ import annotations
44

55
import argparse
6+
from dataclasses import fields
7+
from functools import lru_cache
68
from typing import Any, Iterable
79

810
from .config import ConnectionConfig
911
from .db.schema import ConnectionSchema, FieldType
1012

11-
CONNECTION_ARG_NAMES = {
12-
"name",
13-
"server",
14-
"host",
15-
"port",
16-
"database",
17-
"username",
18-
"password",
19-
"file_path",
20-
"auth_type",
21-
"supabase_region",
22-
"supabase_project_id",
23-
"ssh_enabled",
24-
"ssh_host",
25-
"ssh_port",
26-
"ssh_username",
27-
"ssh_auth_type",
28-
"ssh_key_path",
29-
"ssh_password",
30-
"driver",
31-
"oracle_role",
32-
}
13+
@lru_cache(maxsize=1)
14+
def _get_connection_arg_names() -> set[str]:
15+
from .db.schema import get_all_schemas
16+
17+
names = {
18+
field.name
19+
for schema in get_all_schemas().values()
20+
for field in schema.fields
21+
}
22+
names.add("host")
23+
return names
3324

3425

3526
def add_schema_arguments(
@@ -85,6 +76,8 @@ def build_connection_config_from_args(
8576
strict: bool = True,
8677
) -> ConnectionConfig:
8778
"""Build a ConnectionConfig from CLI args based on a provider schema."""
79+
from .db.providers import normalize_connection_config
80+
8881
raw_values = _extract_raw_values(schema, args)
8982

9083
missing = _find_missing_required_fields(schema, raw_values)
@@ -99,10 +92,12 @@ def build_connection_config_from_args(
9992
raise ValueError(f"Unexpected fields for {schema.display_name}: {extras_args}")
10093

10194
config_name = name or default_name or f"Temp {schema.display_name}"
102-
config_values = {
95+
config_values: dict[str, Any] = {
10396
"name": config_name,
10497
"db_type": schema.db_type,
10598
}
99+
options: dict[str, Any] = {}
100+
base_fields = {f.name for f in fields(ConnectionConfig)}
106101

107102
# Fields where None means "not set" vs "" means "explicitly empty"
108103
nullable_fields = {"password", "ssh_password"}
@@ -113,21 +108,27 @@ def build_connection_config_from_args(
113108
value = ""
114109
if field.name == "ssh_enabled":
115110
if isinstance(value, bool):
116-
config_values[field.name] = value
111+
normalized = value
117112
else:
118-
config_values[field.name] = str(value).lower() == "enabled"
119-
else:
113+
normalized = str(value).lower() == "enabled"
114+
config_values[field.name] = normalized
115+
continue
116+
117+
if field.name in base_fields:
120118
config_values[field.name] = value
119+
else:
120+
options[field.name] = value
121121

122122
if "port" in config_values and not config_values["port"]:
123123
config_values["port"] = schema.default_port or ""
124124

125125
if schema.has_advanced_auth:
126-
auth_type = config_values.get("auth_type") or "sql"
127-
config_values["auth_type"] = auth_type
128-
config_values["trusted_connection"] = auth_type == "windows"
126+
auth_type = options.get("auth_type") or "sql"
127+
options["auth_type"] = auth_type
128+
options["trusted_connection"] = auth_type == "windows"
129129

130-
return ConnectionConfig(**config_values)
130+
config_values["options"] = options
131+
return normalize_connection_config(ConnectionConfig(**config_values))
131132

132133

133134
def _extract_raw_values(schema: ConnectionSchema, args: Any) -> dict[str, Any]:
@@ -158,7 +159,7 @@ def _find_missing_required_fields(schema: ConnectionSchema, raw_values: dict[str
158159
def _find_unexpected_fields(schema: ConnectionSchema, args: Any) -> list[str]:
159160
allowed = {field.name for field in schema.fields}
160161
extras: list[str] = []
161-
for field in CONNECTION_ARG_NAMES:
162+
for field in _get_connection_arg_names():
162163
if field in allowed or field == "name":
163164
continue
164165
value = getattr(args, field, None)

sqlit/commands.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
load_connections,
2121
save_connections,
2222
)
23-
from .db.providers import get_connection_schema, has_advanced_auth, is_file_based
23+
from .db.providers import get_adapter_class, get_connection_schema, has_advanced_auth, is_file_based
2424
from .services import ConnectionSession, QueryResult, QueryService
2525
from .services.credentials import (
2626
ALLOW_PLAINTEXT_CREDENTIALS_SETTING,
@@ -101,12 +101,16 @@ def cmd_connection_list(args: Any) -> int:
101101
for conn in connections:
102102
db_type_label = labels.get(conn.get_db_type(), conn.db_type)
103103
if is_file_based(conn.db_type):
104-
conn_info = conn.file_path[:38] + ".." if len(conn.file_path) > 40 else conn.file_path
104+
file_path = str(conn.get_option("file_path", ""))
105+
conn_info = file_path[:38] + ".." if len(file_path) > 40 else file_path
105106
auth_label = "N/A"
106107
elif has_advanced_auth(conn.db_type):
107108
conn_info = f"{conn.server}@{conn.database}" if conn.database else conn.server
108109
conn_info = conn_info[:38] + ".." if len(conn_info) > 40 else conn_info
109-
auth_label = AUTH_TYPE_LABELS.get(conn.get_auth_type(), conn.auth_type)
110+
auth_value = str(conn.get_option("auth_type", ""))
111+
adapter_class = get_adapter_class(conn.db_type)
112+
auth_type = adapter_class().get_auth_type(conn)
113+
auth_label = AUTH_TYPE_LABELS.get(auth_type, auth_value) if auth_type else auth_value
110114
else:
111115
# Server-based databases with simple auth
112116
conn_info = f"{conn.server}@{conn.database}" if conn.database else conn.server
@@ -226,8 +230,8 @@ def cmd_connection_edit(args: Any) -> int:
226230
if args.auth_type:
227231
try:
228232
auth_type = AuthType(args.auth_type)
229-
conn.auth_type = auth_type.value
230-
conn.trusted_connection = auth_type == AuthType.WINDOWS
233+
conn.set_option("auth_type", auth_type.value)
234+
conn.set_option("trusted_connection", auth_type == AuthType.WINDOWS)
231235
except ValueError:
232236
valid_types = ", ".join(t.value for t in AuthType)
233237
print(f"Error: Invalid auth type '{args.auth_type}'. Valid types: {valid_types}")
@@ -239,7 +243,7 @@ def cmd_connection_edit(args: Any) -> int:
239243

240244
file_path = getattr(args, "file_path", None)
241245
if file_path is not None:
242-
conn.file_path = file_path
246+
conn.set_option("file_path", file_path)
243247

244248
if (conn.password or conn.ssh_password) and not is_keyring_usable():
245249
if not _maybe_prompt_plaintext_credentials():
@@ -371,8 +375,9 @@ def cmd_query(
371375
print(f"Error: Connection '{args.connection}' not found.")
372376
return 1
373377

374-
if args.database and config.db_type == "mssql":
375-
config = replace(config, database=args.database)
378+
if args.database:
379+
adapter_class = get_adapter_class(config.db_type)
380+
config = adapter_class().apply_database_override(config, args.database)
376381

377382
config = _prompt_for_password(config)
378383

sqlit/config.py

Lines changed: 45 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
from __future__ import annotations
1212

13-
from dataclasses import dataclass, field
13+
from dataclasses import dataclass, field, fields
1414
from enum import Enum
15-
from typing import TYPE_CHECKING
15+
from typing import TYPE_CHECKING, Any, Mapping
1616

1717
# Only import what's needed to create the DatabaseType enum
1818
from .db.providers import get_supported_db_types as _get_supported_db_types
@@ -160,12 +160,6 @@ class ConnectionConfig:
160160
database: str = ""
161161
username: str = ""
162162
password: str | None = None
163-
# SQL Server specific fields
164-
auth_type: str = "sql"
165-
driver: str = field(default_factory=_get_default_driver)
166-
trusted_connection: bool = False
167-
# SQLite specific fields
168-
file_path: str = ""
169163
# SSH tunnel fields
170164
ssh_enabled: bool = False
171165
ssh_host: str = ""
@@ -174,35 +168,55 @@ class ConnectionConfig:
174168
ssh_auth_type: str = "key" # "key" or "password"
175169
ssh_password: str | None = None
176170
ssh_key_path: str = ""
177-
# Supabase specific fields
178-
supabase_region: str = ""
179-
supabase_project_id: str = ""
180-
# Oracle specific fields
181-
oracle_role: str = "normal" # "normal", "sysdba", "sysoper"
182171
# Source tracking (e.g., "docker" for auto-detected containers)
183172
source: str | None = None
184173
# Original connection URL if created from URL
185174
connection_url: str | None = None
186175
# Extra options from URL query parameters (e.g., sslmode=require)
187176
extra_options: dict[str, str] = field(default_factory=dict)
188-
189-
def __post_init__(self) -> None:
190-
"""Handle backwards compatibility with old configs."""
191-
# Old configs without db_type are SQL Server
192-
if not hasattr(self, "db_type") or not self.db_type:
193-
self.db_type = "mssql"
194-
195-
# Apply default port for server-based DBs if missing (lazy import)
196-
if not getattr(self, "port", None):
197-
from .db.providers import get_default_port
198-
default_port = get_default_port(self.db_type)
199-
if default_port:
200-
self.port = default_port
201-
202-
# Handle old SQL Server auth compatibility
203-
if self.db_type == "mssql":
204-
if self.auth_type == "windows" and not self.trusted_connection and self.username:
205-
self.auth_type = "sql"
177+
# Provider-specific options (auth_type, driver, file_path, etc.)
178+
options: dict[str, Any] = field(default_factory=dict)
179+
180+
@classmethod
181+
def from_dict(cls, data: Mapping[str, Any]) -> ConnectionConfig:
182+
"""Create a ConnectionConfig from a dict, with legacy key support."""
183+
payload = dict(data)
184+
185+
if "host" in payload and "server" not in payload:
186+
payload["server"] = payload.pop("host")
187+
188+
db_type = payload.get("db_type")
189+
if not isinstance(db_type, str) or not db_type:
190+
payload["db_type"] = "mssql"
191+
192+
raw_options = payload.pop("options", None)
193+
options: dict[str, Any] = {}
194+
if isinstance(raw_options, dict):
195+
options.update(raw_options)
196+
197+
base_fields = {f.name for f in fields(cls)}
198+
for key in list(payload.keys()):
199+
if key in base_fields:
200+
continue
201+
if key not in options:
202+
options[key] = payload.pop(key)
203+
else:
204+
payload.pop(key)
205+
206+
payload["options"] = options
207+
return cls(**payload)
208+
209+
def get_option(self, name: str, default: Any | None = None) -> Any:
210+
return self.options.get(name, default)
211+
212+
def set_option(self, name: str, value: Any) -> None:
213+
self.options[name] = value
214+
215+
def get_field_value(self, name: str, default: Any = "") -> Any:
216+
if name in self.__dataclass_fields__:
217+
value = getattr(self, name)
218+
return value if value is not None else default
219+
return self.options.get(name, default)
206220

207221
def get_db_type(self) -> DatabaseType:
208222
"""Get the DatabaseType enum value."""
@@ -211,71 +225,6 @@ def get_db_type(self) -> DatabaseType:
211225
except ValueError:
212226
return DatabaseType.MSSQL # type: ignore[attr-defined, no-any-return]
213227

214-
def get_auth_type(self) -> AuthType:
215-
"""Get the AuthType enum value."""
216-
try:
217-
return AuthType(self.auth_type)
218-
except ValueError:
219-
return AuthType.SQL_SERVER
220-
221-
def get_connection_string(self) -> str:
222-
"""Build the connection string for SQL Server.
223-
224-
.. deprecated::
225-
This method is deprecated. Connection string building is now
226-
handled internally by SQLServerAdapter._build_connection_string().
227-
Use SQLServerAdapter.connect() directly instead.
228-
"""
229-
import warnings
230-
231-
warnings.warn(
232-
"ConnectionConfig.get_connection_string() is deprecated. "
233-
"Connection string building is now handled internally by SQLServerAdapter.",
234-
DeprecationWarning,
235-
stacklevel=2,
236-
)
237-
238-
if self.db_type != "mssql":
239-
raise ValueError("get_connection_string() is only for SQL Server connections")
240-
241-
server_with_port = self.server
242-
if self.port and self.port != "1433":
243-
server_with_port = f"{self.server},{self.port}"
244-
245-
base = (
246-
f"DRIVER={{{self.driver}}};"
247-
f"SERVER={server_with_port};"
248-
f"DATABASE={self.database or 'master'};"
249-
f"TrustServerCertificate=yes;"
250-
)
251-
252-
auth = self.get_auth_type()
253-
254-
if auth == AuthType.WINDOWS:
255-
return base + "Trusted_Connection=yes;"
256-
elif auth == AuthType.SQL_SERVER:
257-
return base + f"UID={self.username};PWD={self.password};"
258-
elif auth == AuthType.AD_PASSWORD:
259-
return base + f"Authentication=ActiveDirectoryPassword;" f"UID={self.username};PWD={self.password};"
260-
elif auth == AuthType.AD_INTERACTIVE:
261-
return base + f"Authentication=ActiveDirectoryInteractive;" f"UID={self.username};"
262-
elif auth == AuthType.AD_INTEGRATED:
263-
return base + "Authentication=ActiveDirectoryIntegrated;"
264-
265-
return base + "Trusted_Connection=yes;"
266-
267-
def get_display_info(self) -> str:
268-
"""Get a display string for the connection."""
269-
from .db.providers import is_file_based
270-
if is_file_based(self.db_type):
271-
return self.file_path or self.name
272-
273-
if self.db_type == "supabase":
274-
return f"{self.name} ({self.supabase_region})"
275-
276-
db_part = f"@{self.database}" if self.database else ""
277-
return f"{self.name}{db_part}"
278-
279228
def get_source_emoji(self) -> str:
280229
"""Get emoji indicator for connection source (e.g., '🐳 ' for docker)."""
281230
return get_source_emoji(self.source)

0 commit comments

Comments
 (0)