Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit a86edfa

Browse files
author
ansipunk
committed
S01E01
1 parent 8cbcccb commit a86edfa

File tree

3 files changed

+70
-56
lines changed

3 files changed

+70
-56
lines changed

databases/backends/asyncpg.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
import asyncpg
55
from sqlalchemy.engine.interfaces import Dialect
66
from sqlalchemy.sql import ClauseElement
7-
from sqlalchemy.sql.ddl import DDLElement
87

98
from databases.backends.common.records import Record, create_column_maps
10-
from databases.backends.dialects.psycopg import get_dialect
11-
from databases.core import LOG_EXTRA, DatabaseURL
9+
from databases.backends.dialects.psycopg import compile_query, get_dialect
10+
from databases.core import DatabaseURL
1211
from databases.interfaces import (
1312
ConnectionBackend,
1413
DatabaseBackend,
@@ -88,15 +87,15 @@ async def release(self) -> None:
8887

8988
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
9089
assert self._connection is not None, "Connection is not acquired"
91-
query_str, args, result_columns = self._compile(query)
90+
query_str, args, result_columns = compile_query(query, self._dialect)
9291
rows = await self._connection.fetch(query_str, *args)
9392
dialect = self._dialect
9493
column_maps = create_column_maps(result_columns)
9594
return [Record(row, result_columns, dialect, column_maps) for row in rows]
9695

9796
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
9897
assert self._connection is not None, "Connection is not acquired"
99-
query_str, args, result_columns = self._compile(query)
98+
query_str, args, result_columns = compile_query(query, self._dialect)
10099
row = await self._connection.fetchrow(query_str, *args)
101100
if row is None:
102101
return None
@@ -124,7 +123,7 @@ async def fetch_val(
124123

125124
async def execute(self, query: ClauseElement) -> typing.Any:
126125
assert self._connection is not None, "Connection is not acquired"
127-
query_str, args, _ = self._compile(query)
126+
query_str, args, _ = compile_query(query, self._dialect)
128127
return await self._connection.fetchval(query_str, *args)
129128

130129
async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
@@ -133,51 +132,21 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
133132
# loop through multiple executes here, which should all end up
134133
# using the same prepared statement.
135134
for single_query in queries:
136-
single_query, args, _ = self._compile(single_query)
135+
single_query, args, _ = compile_query(single_query, self._dialect)
137136
await self._connection.execute(single_query, *args)
138137

139138
async def iterate(
140139
self, query: ClauseElement
141140
) -> typing.AsyncGenerator[typing.Any, None]:
142141
assert self._connection is not None, "Connection is not acquired"
143-
query_str, args, result_columns = self._compile(query)
142+
query_str, args, result_columns = compile_query(query, self._dialect)
144143
column_maps = create_column_maps(result_columns)
145144
async for row in self._connection.cursor(query_str, *args):
146145
yield Record(row, result_columns, self._dialect, column_maps)
147146

148147
def transaction(self) -> TransactionBackend:
149148
return AsyncpgTransaction(connection=self)
150149

151-
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
152-
compiled = query.compile(
153-
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
154-
)
155-
156-
if not isinstance(query, DDLElement):
157-
compiled_params = sorted(compiled.params.items())
158-
159-
mapping = {
160-
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
161-
}
162-
compiled_query = compiled.string % mapping
163-
164-
processors = compiled._bind_processors
165-
args = [
166-
processors[key](val) if key in processors else val
167-
for key, val in compiled_params
168-
]
169-
result_map = compiled._result_columns
170-
else:
171-
compiled_query = compiled.string
172-
args = []
173-
result_map = None
174-
175-
query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
176-
logger.debug(
177-
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
178-
)
179-
return compiled_query, args, result_map
180-
181150
@property
182151
def raw_connection(self) -> asyncpg.connection.Connection:
183152
assert self._connection is not None, "Connection is not acquired"

databases/backends/dialects/psycopg.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from sqlalchemy import types, util
1111
from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
1212
from sqlalchemy.engine import processors
13+
from sqlalchemy.engine.interfaces import Dialect
14+
from sqlalchemy.sql import ClauseElement
15+
from sqlalchemy.sql.ddl import DDLElement
1316
from sqlalchemy.types import Float, Numeric
1417

1518

@@ -43,7 +46,7 @@ class PGDialect_psycopg(PGDialect):
4346
execution_ctx_cls = PGExecutionContext_psycopg
4447

4548

46-
def get_dialect() -> PGDialect_psycopg:
49+
def get_dialect() -> Dialect:
4750
dialect = PGDialect_psycopg(paramstyle="pyformat")
4851
dialect.implicit_returning = True
4952
dialect.supports_native_enum = True
@@ -53,3 +56,28 @@ def get_dialect() -> PGDialect_psycopg:
5356
dialect._has_native_hstore = True
5457
dialect.supports_native_decimal = True
5558
return dialect
59+
60+
61+
def compile_query(query: ClauseElement, dialect: Dialect) -> typing.Tuple[str, list, tuple]:
62+
compiled = query.compile(dialect=dialect, compile_kwargs={"render_postcompile": True})
63+
64+
if not isinstance(query, DDLElement):
65+
compiled_params = sorted(compiled.params.items())
66+
67+
mapping = {
68+
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
69+
}
70+
compiled_query = compiled.string % mapping
71+
72+
processors = compiled._bind_processors
73+
args = [
74+
processors[key](val) if key in processors else val
75+
for key, val in compiled_params
76+
]
77+
result_map = compiled._result_columns
78+
else:
79+
compiled_query = compiled.string
80+
args = []
81+
result_map = None
82+
83+
return compiled_query, args, result_map

databases/backends/psycopg.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,68 @@
11
import typing
2-
from collections.abc import Sequence
32

3+
import psycopg
44
import psycopg_pool
5+
from sqlalchemy.engine.interfaces import Dialect
56
from sqlalchemy.sql import ClauseElement
67

7-
from databases.backends.dialects.psycopg import get_dialect
8+
from databases.backends.common.records import Record, create_column_maps
9+
from databases.backends.dialects.psycopg import compile_query, get_dialect
810
from databases.core import DatabaseURL
911
from databases.interfaces import (
1012
ConnectionBackend,
1113
DatabaseBackend,
14+
Record as RecordInterface,
1215
TransactionBackend,
1316
)
1417

1518

1619
class PsycopgBackend(DatabaseBackend):
20+
_database_url: DatabaseURL
21+
_options: typing.Dict[str, typing.Any]
22+
_dialect: Dialect
23+
_pool: typing.Optional[psycopg_pool.AsyncConnectionPool]
24+
1725
def __init__(
18-
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
26+
self,
27+
database_url: typing.Union[DatabaseURL, str],
28+
**options: typing.Dict[str, typing.Any],
1929
) -> None:
2030
self._database_url = DatabaseURL(database_url)
2131
self._options = options
2232
self._dialect = get_dialect()
23-
self._pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None
33+
self._pool = None
2434

2535
async def connect(self) -> None:
2636
if self._pool is not None:
2737
return
2838

2939
self._pool = psycopg_pool.AsyncConnectionPool(
3040
self._database_url.url, open=False, **self._options)
41+
42+
# TODO: Add configurable timeouts
3143
await self._pool.open()
3244

3345
async def disconnect(self) -> None:
3446
if self._pool is None:
3547
return
3648

49+
# TODO: Add configurable timeouts
3750
await self._pool.close()
3851
self._pool = None
3952

4053
def connection(self) -> "PsycopgConnection":
41-
return PsycopgConnection(self)
54+
return PsycopgConnection(self, self._dialect)
4255

4356

4457
class PsycopgConnection(ConnectionBackend):
45-
def __init__(self, database: PsycopgBackend) -> None:
58+
_database: PsycopgBackend
59+
_dialect: Dialect
60+
_connection: typing.Optional[psycopg.AsyncConnection]
61+
62+
def __init__(self, database: PsycopgBackend, dialect: Dialect) -> None:
4663
self._database = database
64+
self._dialect = dialect
65+
self._connection = None
4766

4867
async def acquire(self) -> None:
4968
if self._connection is not None:
@@ -62,10 +81,17 @@ async def release(self) -> None:
6281
await self._database._pool.putconn(self._connection)
6382
self._connection = None
6483

65-
async def fetch_all(self, query: ClauseElement) -> typing.List["Record"]:
84+
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
85+
if self._connection is None:
86+
raise RuntimeError("Connection is not acquired")
87+
88+
query_str, args, result_columns = compile_query(query, self._dialect)
89+
rows = await self._connection.fetch(query_str, *args)
90+
column_maps = create_column_maps(result_columns)
91+
return [Record(row, result_columns, self._dialect, column_maps) for row in rows]
6692
raise NotImplementedError() # pragma: no cover
6793

68-
async def fetch_one(self, query: ClauseElement) -> typing.Optional["Record"]:
94+
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
6995
raise NotImplementedError() # pragma: no cover
7096

7197
async def fetch_val(
@@ -107,12 +133,3 @@ async def commit(self) -> None:
107133

108134
async def rollback(self) -> None:
109135
raise NotImplementedError() # pragma: no cover
110-
111-
112-
class Record(Sequence):
113-
@property
114-
def _mapping(self) -> typing.Mapping:
115-
raise NotImplementedError() # pragma: no cover
116-
117-
def __getitem__(self, key: typing.Any) -> typing.Any:
118-
raise NotImplementedError() # pragma: no cover

0 commit comments

Comments
 (0)