diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 9ad12f63..d2ad98a6 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -168,7 +168,9 @@ async def execute(self, query: ClauseElement) -> typing.Any: finally: cursor.close() - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" cursor = await self._connection.cursor() try: diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index e15dfa45..8c327c71 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -158,7 +158,9 @@ async def execute(self, query: ClauseElement) -> typing.Any: finally: await cursor.close() - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" async with self._connection.cursor() as cursor: try: diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 2a0a8425..2922f3f2 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -158,13 +158,14 @@ async def execute(self, query: ClauseElement) -> typing.Any: finally: await cursor.close() - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" cursor = await self._connection.cursor() + query_str, values = self._compile_many(queries, values) try: - for single_query in queries: - single_query, args, context = self._compile(single_query) - await cursor.execute(single_query, args) + await cursor.executemany(query_str, values) finally: await cursor.close() @@ -220,6 +221,21 @@ def _compile( logger.debug("Query: %s Args: %s", query_message, repr(args), extra=LOG_EXTRA) return compiled.string, args, CompilationContext(execution_context) + def _compile_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> typing.Tuple[str, list]: + compiled = queries[0].compile( + dialect=self._dialect, compile_kwargs={"render_postcompile": True} + ) + if not isinstance(queries[0], DDLElement): + for args in values: + for key, val in args.items(): + if key in compiled._bind_processors: + args[key] = compiled._bind_processors[key](val) + else: + values = [] + return compiled.string, values + @property def raw_connection(self) -> aiomysql.connection.Connection: assert self._connection is not None, "Connection is not acquired" diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 3e1a6fff..77a40171 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -216,14 +216,12 @@ async def execute(self, query: ClauseElement) -> typing.Any: query_str, args, result_columns = self._compile(query) return await self._connection.fetchval(query_str, *args) - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" - # asyncpg uses prepared statements under the hood, so we just - # loop through multiple executes here, which should all end up - # using the same prepared statement. - for single_query in queries: - single_query, args, result_columns = self._compile(single_query) - await self._connection.execute(single_query, *args) + query_str, values = self._compile_many(queries, values) + await self._connection.executemany(query_str, values) async def iterate( self, query: ClauseElement @@ -268,6 +266,30 @@ def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: ) return compiled_query, args, result_map + def _compile_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> typing.Tuple[str, list]: + compiled = queries[0].compile( + dialect=self._dialect, compile_kwargs={"render_postcompile": True} + ) + new_values = [] + if not isinstance(queries[0], DDLElement): + for args in values: + sorted_args = sorted(args.items()) + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(sorted_args, start=1) + } + compiled_query = compiled.string % mapping + processors = compiled._bind_processors + values = [ + processors[key](val) if key in processors else val + for key, val in sorted_args + ] + new_values.append(values) + else: + compiled_query = compiled.string + return compiled_query, new_values + @staticmethod def _create_column_maps( result_columns: tuple, diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 9626dcf8..5e88c061 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -135,10 +135,13 @@ async def execute(self, query: ClauseElement) -> typing.Any: return cursor.rowcount return cursor.lastrowid - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: assert self._connection is not None, "Connection is not acquired" - for single_query in queries: - await self.execute(single_query) + query_str, values = self._compile_many(queries, values) + async with self._connection.cursor() as cursor: + await cursor.executemany(query_str, values) async def iterate( self, query: ClauseElement @@ -194,6 +197,26 @@ def _compile( ) return compiled.string, args, CompilationContext(execution_context) + def _compile_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> typing.Tuple[str, list]: + compiled = queries[0].compile( + dialect=self._dialect, compile_kwargs={"render_postcompile": True} + ) + new_values = [] + if not isinstance(queries[0], DDLElement): + for args in values: + temp_arr = [] + for key in compiled.positiontup: + raw_val = args[key] + if key in compiled._bind_processors: + val = compiled._bind_processors[key](raw_val) + else: + val = raw_val + temp_arr.append(val) + new_values.append(temp_arr) + return compiled.string, new_values + @property def raw_connection(self) -> aiosqlite.core.Connection: assert self._connection is not None, "Connection is not acquired" diff --git a/databases/core.py b/databases/core.py index 0e27227c..310d5b18 100644 --- a/databases/core.py +++ b/databases/core.py @@ -278,7 +278,7 @@ async def execute_many( ) -> None: queries = [self._build_query(query, values_set) for values_set in values] async with self._query_lock: - await self._connection.execute_many(queries) + await self._connection.execute_many(queries, values) async def iterate( self, query: typing.Union[ClauseElement, str], values: dict = None diff --git a/databases/interfaces.py b/databases/interfaces.py index c2109a23..6c7d649c 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -37,7 +37,9 @@ async def fetch_val( async def execute(self, query: ClauseElement) -> typing.Any: raise NotImplementedError() # pragma: no cover - async def execute_many(self, queries: typing.List[ClauseElement]) -> None: + async def execute_many( + self, queries: typing.List[ClauseElement], values: typing.List[dict] + ) -> None: raise NotImplementedError() # pragma: no cover async def iterate(