5
5
from types import TracebackType
6
6
from urllib .parse import SplitResult , parse_qsl , urlsplit
7
7
8
+ from sqlalchemy import text
8
9
from sqlalchemy .engine import RowProxy
9
10
from sqlalchemy .sql import ClauseElement
10
11
@@ -88,27 +89,35 @@ async def __aexit__(
88
89
) -> None :
89
90
await self .disconnect ()
90
91
91
- async def fetch_all (self , query : ClauseElement ) -> typing .List [RowProxy ]:
92
+ async def fetch_all (
93
+ self , query : typing .Union [ClauseElement , str ], values : dict = None
94
+ ) -> typing .List [RowProxy ]:
92
95
async with self .connection () as connection :
93
- return await connection .fetch_all (query = query )
96
+ return await connection .fetch_all (query , values )
94
97
95
- async def fetch_one (self , query : ClauseElement ) -> RowProxy :
98
+ async def fetch_one (
99
+ self , query : typing .Union [ClauseElement , str ], values : dict = None
100
+ ) -> RowProxy :
96
101
async with self .connection () as connection :
97
- return await connection .fetch_one (query = query )
102
+ return await connection .fetch_one (query , values )
98
103
99
- async def execute (self , query : ClauseElement , values : dict = None ) -> typing .Any :
104
+ async def execute (
105
+ self , query : typing .Union [ClauseElement , str ], values : dict = None
106
+ ) -> typing .Any :
100
107
async with self .connection () as connection :
101
- return await connection .execute (query = query , values = values )
108
+ return await connection .execute (query , values )
102
109
103
- async def execute_many (self , query : ClauseElement , values : list ) -> None :
110
+ async def execute_many (
111
+ self , query : typing .Union [ClauseElement , str ], values : list
112
+ ) -> None :
104
113
async with self .connection () as connection :
105
- return await connection .execute_many (query = query , values = values )
114
+ return await connection .execute_many (query , values )
106
115
107
116
async def iterate (
108
- self , query : ClauseElement
117
+ self , query : typing . Union [ ClauseElement , str ], values : dict = None
109
118
) -> typing .AsyncGenerator [RowProxy , None ]:
110
119
async with self .connection () as connection :
111
- async for record in connection .iterate (query ):
120
+ async for record in connection .iterate (query , values ):
112
121
yield record
113
122
114
123
def connection (self ) -> "Connection" :
@@ -156,22 +165,31 @@ async def __aexit__(
156
165
if self ._connection_counter == 0 :
157
166
await self ._connection .release ()
158
167
159
- async def fetch_all (self , query : ClauseElement ) -> typing .Any :
160
- return await self ._connection .fetch_all (query = query )
168
+ async def fetch_all (
169
+ self , query : typing .Union [ClauseElement , str ], values : dict = None
170
+ ) -> typing .Any :
171
+ return await self ._connection .fetch_all (self ._build_query (query , values ))
161
172
162
- async def fetch_one (self , query : ClauseElement ) -> typing .Any :
163
- return await self ._connection .fetch_one (query = query )
173
+ async def fetch_one (
174
+ self , query : typing .Union [ClauseElement , str ], values : dict = None
175
+ ) -> typing .Any :
176
+ return await self ._connection .fetch_one (self ._build_query (query , values ))
164
177
165
- async def execute (self , query : ClauseElement , values : dict = None ) -> typing .Any :
166
- return await self ._connection .execute (query , values )
178
+ async def execute (
179
+ self , query : typing .Union [ClauseElement , str ], values : dict = None
180
+ ) -> typing .Any :
181
+ return await self ._connection .execute (self ._build_query (query , values ))
167
182
168
- async def execute_many (self , query : ClauseElement , values : list ) -> None :
169
- await self ._connection .execute_many (query , values )
183
+ async def execute_many (
184
+ self , query : typing .Union [ClauseElement , str ], values : list
185
+ ) -> None :
186
+ queries = [self ._build_query (query , values_set ) for values_set in values ]
187
+ await self ._connection .execute_many (queries )
170
188
171
189
async def iterate (
172
- self , query : ClauseElement
190
+ self , query : typing . Union [ ClauseElement , str ], values : dict = None
173
191
) -> typing .AsyncGenerator [typing .Any , None ]:
174
- async for record in self ._connection .iterate (query ):
192
+ async for record in self ._connection .iterate (self . _build_query ( query , values ) ):
175
193
yield record
176
194
177
195
def transaction (self , * , force_rollback : bool = False ) -> "Transaction" :
@@ -181,6 +199,19 @@ def transaction(self, *, force_rollback: bool = False) -> "Transaction":
181
199
def raw_connection (self ) -> typing .Any :
182
200
return self ._connection .raw_connection
183
201
202
+ @staticmethod
203
+ def _build_query (
204
+ query : typing .Union [ClauseElement , str ], values : dict = None
205
+ ) -> ClauseElement :
206
+ if isinstance (query , str ):
207
+ query = text (query )
208
+
209
+ return query .bindparams (** values ) if values is not None else query
210
+ elif values :
211
+ return query .values (** values )
212
+
213
+ return query
214
+
184
215
185
216
class Transaction :
186
217
def __init__ (self , connection : Connection , force_rollback : bool ) -> None :
0 commit comments