Skip to content

Commit 224b420

Browse files
Merge pull request #502 from Mause/refactor/type-cleanliness
refactor: type cleanliness for SQLA2 support
2 parents 95edc35 + ebd81f3 commit 224b420

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

duckdb_engine/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
33

44
import duckdb
5-
from sqlalchemy import pool
5+
from sqlalchemy import pool, text
66
from sqlalchemy import types as sqltypes
77
from sqlalchemy import util
88
from sqlalchemy.dialects.postgresql.base import PGInspector
@@ -61,7 +61,7 @@ def cursor(self) -> "Connection":
6161
return self
6262

6363
def fetchmany(self, size: Optional[int] = None) -> List:
64-
if hasattr(self.c, "fetchmany"):
64+
if hasattr(self.__c, "fetchmany"):
6565
# fetchmany was only added in 0.5.0
6666
if size is None:
6767
return self.__c.fetchmany()
@@ -223,8 +223,10 @@ def get_view_names(
223223
include: Optional[Any] = None,
224224
**kw: Any,
225225
) -> Any:
226-
s = "SELECT table_name FROM information_schema.tables WHERE table_type='VIEW' and table_schema=?"
227-
rs = connection.execute(s, schema if schema is not None else "main")
226+
s = "SELECT table_name FROM information_schema.tables WHERE table_type='VIEW' and table_schema=:schema_name"
227+
rs = connection.execute(
228+
text(s), {"schema_name": schema if schema is not None else "main"}
229+
)
228230

229231
return [row[0] for row in rs]
230232

duckdb_engine/tests/test_basic.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,11 @@ def test_get_views(engine: Engine) -> None:
156156
views = engine.dialect.get_view_names(con)
157157
assert views == []
158158

159-
engine.execute(text("create view test as select 1"))
160-
engine.execute(
159+
con.execute(text("create view test as select 1"))
160+
con.execute(
161161
text("create schema scheme; create view scheme.schema_test as select 1")
162162
)
163163

164-
con = engine.connect()
165164
views = engine.dialect.get_view_names(con)
166165
assert views == ["test"]
167166

@@ -183,7 +182,7 @@ def test_preload_extension() -> None:
183182
# check that we get an error indicating that the extension was loaded
184183
with engine.connect() as conn, raises(Exception, match="HTTP HEAD"):
185184
conn.execute(
186-
"SELECT * FROM read_parquet('https://domain/path/to/file.parquet');"
185+
text("SELECT * FROM read_parquet('https://domain/path/to/file.parquet');")
187186
)
188187

189188

@@ -307,8 +306,9 @@ def test_inmemory() -> None:
307306
shell = InteractiveShell()
308307
shell.run_cell("""import sqlalchemy as sa""")
309308
shell.run_cell("""eng = sa.create_engine("duckdb:///:memory:")""")
310-
shell.run_cell("""eng.execute("CREATE TABLE t (x int)")""")
311-
res = shell.run_cell("""eng.execute("SHOW TABLES").fetchall()""")
309+
shell.run_cell("""conn = eng.connect()""")
310+
shell.run_cell("""conn.execute(sa.text("CREATE TABLE t (x int)"))""")
311+
res = shell.run_cell("""conn.execute(sa.text("SHOW TABLES")).fetchall()""")
312312

313313
assert res.result == [("t",)]
314314

@@ -329,7 +329,8 @@ def test_config(tmp_path: Path) -> None:
329329
DBAPIError,
330330
match='Cannot execute statement of type "CREATE" (on database "test" which is attached )?in read-only mode!',
331331
):
332-
eng.execute("create table hello2 (i int)")
332+
with eng.connect() as conn:
333+
conn.execute(text("create table hello2 (i int)"))
333334

334335

335336
def test_do_ping(tmp_path: Path, caplog: LogCaptureFixture) -> None:
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
import pandas as pd
2+
from sqlalchemy import text
23
from sqlalchemy.engine import Engine
34

45

56
def test_integration(engine: Engine) -> None:
6-
engine.execute("register", ("test_df", pd.DataFrame([{"a": 1}])))
7+
with engine.connect() as conn:
8+
execute = (
9+
conn.exec_driver_sql if hasattr(conn, "exec_driver_sql") else conn.execute
10+
)
11+
params = ("test_df", pd.DataFrame([{"a": 1}]))
12+
execute("register", params) # type: ignore[operator]
713

8-
engine.execute("select * from test_df")
14+
conn.execute(text("select * from test_df"))

0 commit comments

Comments
 (0)