Skip to content

Commit 9e48e85

Browse files
zeromakejettify
authored andcommitted
fix: support executemany (#324)
* fix: support executemany * fix: flake8 * fix: flake8 * fix: test support executemany( * fix: local variable referenced before assignment * test: add more executemany test * fix: flake8 * fix: review * fix: lint
1 parent 2446872 commit 9e48e85

File tree

2 files changed

+101
-27
lines changed

2 files changed

+101
-27
lines changed

aiomysql/sa/connection.py

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
from ..utils import _TransactionContextManager, _SAConnectionContextManager
1414

1515

16+
def noop(k):
17+
return k
18+
19+
1620
class SAConnection:
1721

1822
def __init__(self, connection, engine, compiled_cache=None):
@@ -64,16 +68,79 @@ def execute(self, query, *multiparams, **params):
6468
coro = self._execute(query, *multiparams, **params)
6569
return _SAConnectionContextManager(coro)
6670

71+
def _base_params(self, query, dp, compiled, is_update):
72+
"""
73+
handle params
74+
"""
75+
if dp and isinstance(dp, (list, tuple)):
76+
if is_update:
77+
dp = {c.key: pval for c, pval in zip(query.table.c, dp)}
78+
else:
79+
raise exc.ArgumentError(
80+
"Don't mix sqlalchemy SELECT "
81+
"clause with positional "
82+
"parameters"
83+
)
84+
compiled_params = compiled.construct_params(dp)
85+
processors = compiled._bind_processors
86+
params = [{
87+
key: processors.get(key, noop)(compiled_params[key])
88+
for key in compiled_params
89+
}]
90+
post_processed_params = self._dialect.execute_sequence_format(params)
91+
return post_processed_params[0]
92+
93+
async def _executemany(self, query, dps, cursor):
94+
"""
95+
executemany
96+
"""
97+
result_map = None
98+
if isinstance(query, str):
99+
await cursor.executemany(query, dps)
100+
elif isinstance(query, DDLElement):
101+
raise exc.ArgumentError(
102+
"Don't mix sqlalchemy DDL clause "
103+
"and execution with parameters"
104+
)
105+
elif isinstance(query, ClauseElement):
106+
compiled = query.compile(dialect=self._dialect)
107+
params = []
108+
is_update = isinstance(query, UpdateBase)
109+
for dp in dps:
110+
params.append(
111+
self._base_params(
112+
query,
113+
dp,
114+
compiled,
115+
is_update,
116+
)
117+
)
118+
await cursor.executemany(str(compiled), params)
119+
result_map = compiled._result_columns
120+
else:
121+
raise exc.ArgumentError(
122+
"sql statement should be str or "
123+
"SQLAlchemy data "
124+
"selection/modification clause"
125+
)
126+
ret = await create_result_proxy(
127+
self,
128+
cursor,
129+
self._dialect,
130+
result_map
131+
)
132+
self._weak_results.add(ret)
133+
return ret
134+
67135
async def _execute(self, query, *multiparams, **params):
68136
cursor = await self._connection.cursor()
69137
dp = _distill_params(multiparams, params)
70138
if len(dp) > 1:
71-
raise exc.ArgumentError("aiomysql doesn't support executemany")
139+
return await self._executemany(query, dp, cursor)
72140
elif dp:
73141
dp = dp[0]
74142

75143
result_map = None
76-
77144
if isinstance(query, str):
78145
await cursor.execute(query, dp or None)
79146
elif isinstance(query, ClauseElement):
@@ -90,35 +157,20 @@ async def _execute(self, query, *multiparams, **params):
90157
compiled = query.compile(dialect=self._dialect)
91158

92159
if not isinstance(query, DDLElement):
93-
if dp and isinstance(dp, (list, tuple)):
94-
if isinstance(query, UpdateBase):
95-
dp = {c.key: pval
96-
for c, pval in zip(query.table.c, dp)}
97-
else:
98-
raise exc.ArgumentError("Don't mix sqlalchemy SELECT "
99-
"clause with positional "
100-
"parameters")
101-
compiled_parameters = [compiled.construct_params(
102-
dp)]
103-
processed_parameters = []
104-
processors = compiled._bind_processors
105-
for compiled_params in compiled_parameters:
106-
params = {key: (processors[key](compiled_params[key])
107-
if key in processors
108-
else compiled_params[key])
109-
for key in compiled_params}
110-
processed_parameters.append(params)
111-
post_processed_params = self._dialect.execute_sequence_format(
112-
processed_parameters)
160+
post_processed_params = self._base_params(
161+
query,
162+
dp,
163+
compiled,
164+
isinstance(query, UpdateBase)
165+
)
113166
result_map = compiled._result_columns
114-
115167
else:
116168
if dp:
117169
raise exc.ArgumentError("Don't mix sqlalchemy DDL clause "
118170
"and execution with parameters")
119-
post_processed_params = [compiled.construct_params()]
171+
post_processed_params = compiled.construct_params()
120172
result_map = None
121-
await cursor.execute(str(compiled), post_processed_params[0])
173+
await cursor.execute(str(compiled), post_processed_params)
122174
else:
123175
raise exc.ArgumentError("sql statement should be str or "
124176
"SQLAlchemy data "

tests/sa/test_sa_connection.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from sqlalchemy import MetaData, Table, Column, Integer, String
1010
from sqlalchemy.schema import DropTable, CreateTable
11+
from sqlalchemy.sql.expression import bindparam
1112

1213

1314
meta = MetaData()
@@ -269,10 +270,31 @@ async def go():
269270
def test_raw_insert_with_executemany(self):
270271
async def go():
271272
conn = await self.connect()
273+
# with self.assertRaises(sa.ArgumentError):
274+
await conn.execute(
275+
"INSERT INTO sa_tbl (id, name) VALUES (%(id)s, %(name)s)",
276+
[{"id": 2, "name": 'third'}, {"id": 3, "name": 'forth'}])
277+
await conn.execute(
278+
tbl.update().where(
279+
tbl.c.id == bindparam("id")
280+
).values(
281+
{"name": bindparam("name")}
282+
),
283+
[
284+
{"id": 2, "name": "t2"},
285+
{"id": 3, "name": "t3"}
286+
]
287+
)
288+
with self.assertRaises(sa.ArgumentError):
289+
await conn.execute(
290+
DropTable(tbl),
291+
[{}, {}]
292+
)
272293
with self.assertRaises(sa.ArgumentError):
273294
await conn.execute(
274-
"INSERT INTO sa_tbl (id, name) VALUES (%(id)s, %(name)s)",
275-
[(2, 'third'), (3, 'forth')])
295+
{},
296+
[{}, {}]
297+
)
276298
self.loop.run_until_complete(go())
277299

278300
def test_raw_select_with_wildcard(self):

0 commit comments

Comments
 (0)