4
4
import asyncpg
5
5
from sqlalchemy .engine .interfaces import Dialect
6
6
from sqlalchemy .sql import ClauseElement
7
+ from sqlalchemy .sql .ddl import DDLElement
7
8
8
9
from databases .backends .common .records import Record , create_column_maps
9
- from databases .backends .dialects .psycopg import compile_query , get_dialect
10
- from databases .core import DatabaseURL
10
+ from databases .backends .dialects .psycopg import dialect as psycopg_dialect
11
+ from databases .core import LOG_EXTRA , DatabaseURL
11
12
from databases .interfaces import (
12
13
ConnectionBackend ,
13
14
DatabaseBackend ,
@@ -24,9 +25,20 @@ def __init__(
24
25
) -> None :
25
26
self ._database_url = DatabaseURL (database_url )
26
27
self ._options = options
27
- self ._dialect = get_dialect ()
28
+ self ._dialect = self . _get_dialect ()
28
29
self ._pool = None
29
30
31
+ def _get_dialect (self ) -> Dialect :
32
+ dialect = psycopg_dialect (paramstyle = "pyformat" )
33
+ dialect .implicit_returning = True
34
+ dialect .supports_native_enum = True
35
+ dialect .supports_smallserial = True # 9.2+
36
+ dialect ._backslash_escapes = False
37
+ dialect .supports_sane_multi_rowcount = True # psycopg 2.0.9+
38
+ dialect ._has_native_hstore = True
39
+ dialect .supports_native_decimal = True
40
+ return dialect
41
+
30
42
def _get_connection_kwargs (self ) -> dict :
31
43
url_options = self ._database_url .options
32
44
@@ -87,15 +99,15 @@ async def release(self) -> None:
87
99
88
100
async def fetch_all (self , query : ClauseElement ) -> typing .List [RecordInterface ]:
89
101
assert self ._connection is not None , "Connection is not acquired"
90
- query_str , args , result_columns = compile_query ( query , self ._dialect )
102
+ query_str , args , result_columns = self ._compile ( query )
91
103
rows = await self ._connection .fetch (query_str , * args )
92
104
dialect = self ._dialect
93
105
column_maps = create_column_maps (result_columns )
94
106
return [Record (row , result_columns , dialect , column_maps ) for row in rows ]
95
107
96
108
async def fetch_one (self , query : ClauseElement ) -> typing .Optional [RecordInterface ]:
97
109
assert self ._connection is not None , "Connection is not acquired"
98
- query_str , args , result_columns = compile_query ( query , self ._dialect )
110
+ query_str , args , result_columns = self ._compile ( query )
99
111
row = await self ._connection .fetchrow (query_str , * args )
100
112
if row is None :
101
113
return None
@@ -123,7 +135,7 @@ async def fetch_val(
123
135
124
136
async def execute (self , query : ClauseElement ) -> typing .Any :
125
137
assert self ._connection is not None , "Connection is not acquired"
126
- query_str , args , _ = compile_query ( query , self ._dialect )
138
+ query_str , args , _ = self ._compile ( query )
127
139
return await self ._connection .fetchval (query_str , * args )
128
140
129
141
async def execute_many (self , queries : typing .List [ClauseElement ]) -> None :
@@ -132,25 +144,55 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
132
144
# loop through multiple executes here, which should all end up
133
145
# using the same prepared statement.
134
146
for single_query in queries :
135
- single_query , args , _ = compile_query ( single_query , self ._dialect )
147
+ single_query , args , _ = self ._compile ( single_query )
136
148
await self ._connection .execute (single_query , * args )
137
149
138
150
async def iterate (
139
151
self , query : ClauseElement
140
152
) -> typing .AsyncGenerator [typing .Any , None ]:
141
153
assert self ._connection is not None , "Connection is not acquired"
142
- query_str , args , result_columns = compile_query ( query , self ._dialect )
154
+ query_str , args , result_columns = self ._compile ( query )
143
155
column_maps = create_column_maps (result_columns )
144
156
async for row in self ._connection .cursor (query_str , * args ):
145
157
yield Record (row , result_columns , self ._dialect , column_maps )
146
158
147
159
def transaction (self ) -> TransactionBackend :
148
160
return AsyncpgTransaction (connection = self )
149
161
150
- @property
151
- def raw_connection (self ) -> asyncpg .connection .Connection :
152
- assert self ._connection is not None , "Connection is not acquired"
153
- return self ._connection
162
+ def _compile (self , query : ClauseElement ) -> typing .Tuple [str , list , tuple ]:
163
+ compiled = query .compile (
164
+ dialect = self ._dialect , compile_kwargs = {"render_postcompile" : True }
165
+ )
166
+
167
+ if not isinstance (query , DDLElement ):
168
+ compiled_params = sorted (compiled .params .items ())
169
+
170
+ mapping = {
171
+ key : "$" + str (i ) for i , (key , _ ) in enumerate (compiled_params , start = 1 )
172
+ }
173
+ compiled_query = compiled .string % mapping
174
+
175
+ processors = compiled ._bind_processors
176
+ args = [
177
+ processors [key ](val ) if key in processors else val
178
+ for key , val in compiled_params
179
+ ]
180
+ result_map = compiled ._result_columns
181
+ else :
182
+ compiled_query = compiled .string
183
+ args = []
184
+ result_map = None
185
+
186
+ query_message = compiled_query .replace (" \n " , " " ).replace ("\n " , " " )
187
+ logger .debug (
188
+ "Query: %s Args: %s" , query_message , repr (tuple (args )), extra = LOG_EXTRA
189
+ )
190
+ return compiled_query , args , result_map
191
+
192
+ @property
193
+ def raw_connection (self ) -> asyncpg .connection .Connection :
194
+ assert self ._connection is not None , "Connection is not acquired"
195
+ return self ._connection
154
196
155
197
156
198
class AsyncpgTransaction (TransactionBackend ):
0 commit comments