Skip to content

Commit 1c0efb0

Browse files
committed
Clarify missing driving warnings for import failures
1 parent 85a7ed1 commit 1c0efb0

File tree

14 files changed

+170
-100
lines changed

14 files changed

+170
-100
lines changed

sqlit/db/adapters/base.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import importlib
6+
import importlib.util
67
from abc import ABC, abstractmethod
78
from dataclasses import dataclass
89
from pathlib import Path
@@ -74,6 +75,36 @@ class SequenceInfo:
7475
TableInfo = tuple[str, str]
7576

7677

78+
def import_driver_module(
79+
module_name: str,
80+
*,
81+
driver_name: str,
82+
extra_name: str | None,
83+
package_name: str | None,
84+
) -> Any:
85+
"""Import a driver module, raising MissingDriverError with detail if it fails."""
86+
if not extra_name or not package_name:
87+
return importlib.import_module(module_name)
88+
89+
if importlib.util.find_spec(module_name) is None:
90+
from ...db.exceptions import MissingDriverError
91+
92+
raise MissingDriverError(driver_name, extra_name, package_name, module_name=module_name)
93+
94+
try:
95+
return importlib.import_module(module_name)
96+
except ImportError as e:
97+
from ...db.exceptions import MissingDriverError
98+
99+
raise MissingDriverError(
100+
driver_name,
101+
extra_name,
102+
package_name,
103+
module_name=module_name,
104+
import_error=str(e),
105+
) from e
106+
107+
77108
class DatabaseAdapter(ABC):
78109
"""Abstract base class for database adapters.
79110
@@ -97,15 +128,13 @@ def ensure_driver_available(self) -> None:
97128
"""Verify required dependencies can be imported, raising MissingDriverError if not."""
98129
if not self.driver_import_names:
99130
return
100-
try:
101-
for module_name in self.driver_import_names:
102-
importlib.import_module(module_name)
103-
except ImportError as e:
104-
from ...db.exceptions import MissingDriverError
105-
106-
if not self.install_extra or not self.install_package:
107-
raise e
108-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
131+
for module_name in self.driver_import_names:
132+
import_driver_module(
133+
module_name,
134+
driver_name=self.name,
135+
extra_name=self.install_extra,
136+
package_name=self.install_package,
137+
)
109138

110139
@property
111140
def install_extra(self) -> str | None:

sqlit/db/adapters/cockroachdb.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import TYPE_CHECKING, Any
66

77
from ..schema import get_default_port
8-
from .base import PostgresBaseAdapter
8+
from .base import PostgresBaseAdapter, import_driver_module
99

1010
if TYPE_CHECKING:
1111
from ...config import ConnectionConfig
@@ -41,14 +41,12 @@ def supports_triggers(self) -> bool:
4141

4242
def connect(self, config: ConnectionConfig) -> Any:
4343
"""Connect to CockroachDB database."""
44-
try:
45-
import psycopg2
46-
except ImportError as e:
47-
from ...db.exceptions import MissingDriverError
48-
49-
if not self.install_extra or not self.install_package:
50-
raise e
51-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
44+
psycopg2 = import_driver_module(
45+
"psycopg2",
46+
driver_name=self.name,
47+
extra_name=self.install_extra,
48+
package_name=self.install_package,
49+
)
5250

5351
port = int(config.port or get_default_port("cockroachdb"))
5452
conn = psycopg2.connect(

sqlit/db/adapters/d1.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from typing import TYPE_CHECKING, Any, cast
88

9-
from .base import ColumnInfo, DatabaseAdapter, IndexInfo, SequenceInfo, TableInfo, TriggerInfo
9+
from .base import ColumnInfo, DatabaseAdapter, IndexInfo, SequenceInfo, TableInfo, TriggerInfo, import_driver_module
1010

1111
if TYPE_CHECKING:
1212
import requests
@@ -58,14 +58,12 @@ def supports_stored_procedures(self) -> bool:
5858

5959
def connect(self, config: ConnectionConfig) -> D1Connection:
6060
"""Establishes a 'connection' to D1 by preparing authenticated session."""
61-
try:
62-
import requests
63-
except ImportError as e:
64-
from ...db.exceptions import MissingDriverError
65-
66-
if not self.install_extra or not self.install_package:
67-
raise e
68-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
61+
requests = import_driver_module(
62+
"requests",
63+
driver_name=self.name,
64+
extra_name=self.install_extra,
65+
package_name=self.install_package,
66+
)
6967

7068
session = requests.Session()
7169
session.headers.update({"Authorization": f"Bearer {config.password}"})

sqlit/db/adapters/duckdb.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,16 @@
44

55
from typing import TYPE_CHECKING, Any
66

7-
from .base import ColumnInfo, DatabaseAdapter, IndexInfo, SequenceInfo, TableInfo, TriggerInfo, resolve_file_path
7+
from .base import (
8+
ColumnInfo,
9+
DatabaseAdapter,
10+
IndexInfo,
11+
SequenceInfo,
12+
TableInfo,
13+
TriggerInfo,
14+
import_driver_module,
15+
resolve_file_path,
16+
)
817

918
if TYPE_CHECKING:
1019
from ...config import ConnectionConfig
@@ -58,14 +67,12 @@ def connect(self, config: ConnectionConfig) -> Any:
5867
serialized via exclusive workers to ensure only one thread accesses
5968
the connection at a time.
6069
"""
61-
try:
62-
import duckdb
63-
except ImportError as e:
64-
from ...db.exceptions import MissingDriverError
65-
66-
if not self.install_extra or not self.install_package:
67-
raise e
68-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
70+
duckdb = import_driver_module(
71+
"duckdb",
72+
driver_name=self.name,
73+
extra_name=self.install_extra,
74+
package_name=self.install_package,
75+
)
6976

7077
file_path = resolve_file_path(config.file_path)
7178
duckdb_any: Any = duckdb

sqlit/db/adapters/mariadb.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import TYPE_CHECKING, Any
77

88
from ..schema import get_default_port
9-
from .base import ColumnInfo, IndexInfo, MySQLBaseAdapter, SequenceInfo, TableInfo, TriggerInfo
9+
from .base import ColumnInfo, IndexInfo, MySQLBaseAdapter, SequenceInfo, TableInfo, TriggerInfo, import_driver_module
1010

1111
if TYPE_CHECKING:
1212
from ...config import ConnectionConfig
@@ -42,14 +42,12 @@ def supports_sequences(self) -> bool:
4242

4343
def connect(self, config: ConnectionConfig) -> Any:
4444
"""Connect to MariaDB database."""
45-
try:
46-
import mariadb
47-
except ImportError as e:
48-
from ...db.exceptions import MissingDriverError
49-
50-
if not self.install_extra or not self.install_package:
51-
raise e
52-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
45+
mariadb = import_driver_module(
46+
"mariadb",
47+
driver_name=self.name,
48+
extra_name=self.install_extra,
49+
package_name=self.install_package,
50+
)
5351

5452
port = int(config.port or get_default_port("mariadb"))
5553
mariadb_any: Any = mariadb

sqlit/db/adapters/mssql.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import TYPE_CHECKING, Any
66

7-
from .base import ColumnInfo, DatabaseAdapter, IndexInfo, SequenceInfo, TableInfo, TriggerInfo
7+
from .base import ColumnInfo, DatabaseAdapter, IndexInfo, SequenceInfo, TableInfo, TriggerInfo, import_driver_module
88

99
if TYPE_CHECKING:
1010
from ...config import ConnectionConfig
@@ -85,14 +85,12 @@ def _build_connection_string(self, config: ConnectionConfig) -> str:
8585

8686
def connect(self, config: ConnectionConfig) -> Any:
8787
"""Connect to SQL Server using pyodbc."""
88-
try:
89-
import pyodbc
90-
except ImportError as e:
91-
from ...db.exceptions import MissingDriverError
92-
93-
if not self.install_extra or not self.install_package:
94-
raise e
95-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
88+
pyodbc = import_driver_module(
89+
"pyodbc",
90+
driver_name=self.name,
91+
extra_name=self.install_extra,
92+
package_name=self.install_package,
93+
)
9694

9795
installed = list(pyodbc.drivers())
9896
if config.driver not in installed:

sqlit/db/adapters/mysql.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import TYPE_CHECKING, Any
66

77
from ..schema import get_default_port
8-
from .base import MySQLBaseAdapter
8+
from .base import MySQLBaseAdapter, import_driver_module
99

1010
if TYPE_CHECKING:
1111
from ...config import ConnectionConfig
@@ -32,17 +32,15 @@ def driver_import_names(self) -> tuple[str, ...]:
3232

3333
def connect(self, config: ConnectionConfig) -> Any:
3434
"""Connect to MySQL database."""
35-
try:
36-
import mysql.connector
37-
except ImportError as e:
38-
from ...db.exceptions import MissingDriverError
39-
40-
if not self.install_extra or not self.install_package:
41-
raise e
42-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
35+
mysql_connector = import_driver_module(
36+
"mysql.connector",
37+
driver_name=self.name,
38+
extra_name=self.install_extra,
39+
package_name=self.install_package,
40+
)
4341

4442
port = int(config.port or get_default_port("mysql"))
45-
return mysql.connector.connect(
43+
return mysql_connector.connect(
4644
host=config.server,
4745
port=port,
4846
database=config.database or None,

sqlit/db/adapters/oracle.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import TYPE_CHECKING, Any
66

77
from ..schema import get_default_port
8-
from .base import ColumnInfo, DatabaseAdapter, IndexInfo, SequenceInfo, TableInfo, TriggerInfo
8+
from .base import ColumnInfo, DatabaseAdapter, IndexInfo, SequenceInfo, TableInfo, TriggerInfo, import_driver_module
99

1010
if TYPE_CHECKING:
1111
from ...config import ConnectionConfig
@@ -54,14 +54,12 @@ def test_query(self) -> str:
5454

5555
def connect(self, config: ConnectionConfig) -> Any:
5656
"""Connect to Oracle database."""
57-
try:
58-
import oracledb
59-
except ImportError as e:
60-
from ...db.exceptions import MissingDriverError
61-
62-
if not self.install_extra or not self.install_package:
63-
raise e
64-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
57+
oracledb = import_driver_module(
58+
"oracledb",
59+
driver_name=self.name,
60+
extra_name=self.install_extra,
61+
package_name=self.install_package,
62+
)
6563

6664
port = int(config.port or get_default_port("oracle"))
6765
# Use Easy Connect string format: host:port/service_name

sqlit/db/adapters/postgresql.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import TYPE_CHECKING, Any
66

77
from ..schema import get_default_port
8-
from .base import PostgresBaseAdapter
8+
from .base import PostgresBaseAdapter, import_driver_module
99

1010
if TYPE_CHECKING:
1111
from ...config import ConnectionConfig
@@ -32,14 +32,12 @@ def driver_import_names(self) -> tuple[str, ...]:
3232

3333
def connect(self, config: ConnectionConfig) -> Any:
3434
"""Connect to PostgreSQL database."""
35-
try:
36-
import psycopg2
37-
except ImportError as e:
38-
from ...db.exceptions import MissingDriverError
39-
40-
if not self.install_extra or not self.install_package:
41-
raise e
42-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
35+
psycopg2 = import_driver_module(
36+
"psycopg2",
37+
driver_name=self.name,
38+
extra_name=self.install_extra,
39+
package_name=self.install_package,
40+
)
4341

4442
port = int(config.port or get_default_port("postgresql"))
4543
conn = psycopg2.connect(

sqlit/db/adapters/turso.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import TYPE_CHECKING, Any
66

7-
from .base import ColumnInfo, DatabaseAdapter, IndexInfo, SequenceInfo, TableInfo, TriggerInfo
7+
from .base import ColumnInfo, DatabaseAdapter, IndexInfo, SequenceInfo, TableInfo, TriggerInfo, import_driver_module
88

99
if TYPE_CHECKING:
1010
from ...config import ConnectionConfig
@@ -54,22 +54,20 @@ def connect(self, config: ConnectionConfig) -> Any:
5454
Uses config.server for the database URL and config.password for the auth token.
5555
Supports libsql://, https://, and http:// URLs.
5656
"""
57-
try:
58-
from libsql_client import create_client_sync
59-
except ImportError as e:
60-
from ...db.exceptions import MissingDriverError
61-
62-
if not self.install_extra or not self.install_package:
63-
raise e
64-
raise MissingDriverError(self.name, self.install_extra, self.install_package) from e
57+
libsql_client = import_driver_module(
58+
"libsql_client",
59+
driver_name=self.name,
60+
extra_name=self.install_extra,
61+
package_name=self.install_package,
62+
)
6563

6664
url = config.server
6765
# Ensure URL has proper scheme
6866
if not url.startswith(("libsql://", "https://", "http://")):
6967
url = f"libsql://{url}"
7068

7169
auth_token = config.password if config.password else None
72-
client = create_client_sync(url, auth_token=auth_token)
70+
client = libsql_client.create_client_sync(url, auth_token=auth_token)
7371
return client
7472

7573
def get_databases(self, conn: Any) -> list[str]:

0 commit comments

Comments
 (0)