5
5
from types import TracebackType
6
6
from urllib .parse import SplitResult , parse_qsl , urlsplit
7
7
8
+ from sqlalchemy import text
8
9
from sqlalchemy .engine import RowProxy
9
10
from sqlalchemy .sql import ClauseElement
10
11
@@ -88,27 +89,36 @@ async def __aexit__(
88
89
) -> None :
89
90
await self .disconnect ()
90
91
91
- async def fetch_all (self , query : ClauseElement ) -> typing .List [RowProxy ]:
92
+ async def fetch_all (
93
+ self , query : typing .Union [ClauseElement , str ], values : dict = None
94
+ ) -> typing .List [RowProxy ]:
92
95
async with self .connection () as connection :
93
- return await connection .fetch_all (query = query )
96
+ return await connection .fetch_all (query = self . _build_query ( query , values ) )
94
97
95
- async def fetch_one (self , query : ClauseElement ) -> RowProxy :
98
+ async def fetch_one (
99
+ self , query : typing .Union [ClauseElement , str ], values : dict = None
100
+ ) -> RowProxy :
96
101
async with self .connection () as connection :
97
- return await connection .fetch_one (query = query )
102
+ return await connection .fetch_one (query = self . _build_query ( query , values ) )
98
103
99
- async def execute (self , query : ClauseElement , values : dict = None ) -> typing .Any :
104
+ async def execute (
105
+ self , query : typing .Union [ClauseElement , str ], values : dict = None
106
+ ) -> typing .Any :
100
107
async with self .connection () as connection :
101
- return await connection .execute (query = query , values = values )
108
+ return await connection .execute (self . _build_query ( query , values ) )
102
109
103
- async def execute_many (self , query : ClauseElement , values : list ) -> None :
110
+ async def execute_many (
111
+ self , query : typing .Union [ClauseElement , str ], values : list
112
+ ) -> None :
104
113
async with self .connection () as connection :
105
- return await connection .execute_many (query = query , values = values )
114
+ queries = [self ._build_query (query , values_set ) for values_set in values ]
115
+ return await connection .execute_many (queries )
106
116
107
117
async def iterate (
108
- self , query : ClauseElement
118
+ self , query : typing . Union [ ClauseElement , str ], values : dict = None
109
119
) -> typing .AsyncGenerator [RowProxy , None ]:
110
120
async with self .connection () as connection :
111
- async for record in connection .iterate (query ):
121
+ async for record in connection .iterate (self . _build_query ( query , values ) ):
112
122
yield record
113
123
114
124
def connection (self ) -> "Connection" :
@@ -125,6 +135,19 @@ def connection(self) -> "Connection":
125
135
def transaction (self , * , force_rollback : bool = False ) -> "Transaction" :
126
136
return self .connection ().transaction (force_rollback = force_rollback )
127
137
138
+ @staticmethod
139
+ def _build_query (
140
+ query : typing .Union [ClauseElement , str ], values : dict = None
141
+ ) -> ClauseElement :
142
+ if isinstance (query , str ):
143
+ query = text (query )
144
+
145
+ return query .bindparams (** values ) if values is not None else query
146
+ elif values :
147
+ return query .values (** values )
148
+
149
+ return query
150
+
128
151
129
152
class Connection :
130
153
def __init__ (self , backend : DatabaseBackend ) -> None :
@@ -162,11 +185,11 @@ async def fetch_all(self, query: ClauseElement) -> typing.Any:
162
185
async def fetch_one (self , query : ClauseElement ) -> typing .Any :
163
186
return await self ._connection .fetch_one (query = query )
164
187
165
- async def execute (self , query : ClauseElement , values : dict = None ) -> typing .Any :
166
- return await self ._connection .execute (query , values )
188
+ async def execute (self , query : ClauseElement ) -> typing .Any :
189
+ return await self ._connection .execute (query )
167
190
168
- async def execute_many (self , query : ClauseElement , values : list ) -> None :
169
- await self ._connection .execute_many (query , values )
191
+ async def execute_many (self , queries : typing . List [ ClauseElement ] ) -> None :
192
+ await self ._connection .execute_many (queries )
170
193
171
194
async def iterate (
172
195
self , query : ClauseElement
0 commit comments