|
1 | 1 | import logging
|
2 | 2 | import typing
|
3 |
| -from collections.abc import Sequence |
4 | 3 |
|
5 | 4 | import asyncpg
|
6 | 5 | from sqlalchemy.dialects.postgresql import pypostgresql
|
@@ -217,14 +216,12 @@ async def execute(self, query: ClauseElement) -> typing.Any:
|
217 | 216 | query_str, args, result_columns = self._compile(query)
|
218 | 217 | return await self._connection.fetchval(query_str, *args)
|
219 | 218 |
|
220 |
| - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: |
| 219 | + async def execute_many( |
| 220 | + self, queries: typing.List[ClauseElement], values: typing.List[dict] |
| 221 | + ) -> None: |
221 | 222 | assert self._connection is not None, "Connection is not acquired"
|
222 |
| - # asyncpg uses prepared statements under the hood, so we just |
223 |
| - # loop through multiple executes here, which should all end up |
224 |
| - # using the same prepared statement. |
225 |
| - for single_query in queries: |
226 |
| - single_query, args, result_columns = self._compile(single_query) |
227 |
| - await self._connection.execute(single_query, *args) |
| 223 | + query_str, values = self._compile_many(queries, values) |
| 224 | + await self._connection.executemany(query_str, values) |
228 | 225 |
|
229 | 226 | async def iterate(
|
230 | 227 | self, query: ClauseElement
|
@@ -269,6 +266,18 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
|
269 | 266 | )
|
270 | 267 | return compiled_query, args, result_map
|
271 | 268 |
|
| 269 | + def _compile_many( |
| 270 | + self, queries: typing.List[ClauseElement], values: typing.List[dict] |
| 271 | + ) -> typing.Tuple[str, list]: |
| 272 | + compiled = queries[0].compile( |
| 273 | + dialect=self._dialect, compile_kwargs={"render_postcompile": True} |
| 274 | + ) |
| 275 | + for args in values: |
| 276 | + for key, val in args.items(): |
| 277 | + if key in compiled._bind_processors: |
| 278 | + args[key] = compiled._bind_processors[key](val) |
| 279 | + return compiled.string, values |
| 280 | + |
272 | 281 | @staticmethod
|
273 | 282 | def _create_column_maps(
|
274 | 283 | result_columns: tuple,
|
|
0 commit comments