Skip to content

Commit a9bc33d

Browse files
authored
Merge branch 'main' into greenlet-segfault
2 parents 709a605 + ea2eb35 commit a9bc33d

File tree

4 files changed

+149
-1
lines changed

4 files changed

+149
-1
lines changed

newrelic/hooks/database_aiomysql.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,27 @@ async def __aenter__(self):
3939
async def __aexit__(self, exc, val, tb):
4040
return await self.__wrapped__.__aexit__(exc, val, tb)
4141

42+
def __await__(self):
43+
# Handle bidirectional generator protocol using code from generator_wrapper
44+
g = self.__wrapped__.__await__()
45+
try:
46+
yielded = g.send(None)
47+
while True:
48+
try:
49+
sent = yield yielded
50+
except GeneratorExit as e:
51+
g.close()
52+
raise
53+
except BaseException as e:
54+
yielded = g.throw(e)
55+
else:
56+
yielded = g.send(sent)
57+
except StopIteration as e:
58+
# Catch the StopIteration and wrap the return value.
59+
cursor = e.value
60+
wrapped_cursor = self.__cursor_wrapper__(cursor, self._nr_dbapi2_module, self._nr_connect_params, self._nr_cursor_args)
61+
return wrapped_cursor # Return here instead of raising StopIteration to properly follow generator protocol
62+
4263

4364
class AsyncConnectionWrapper(DBAPI2AsyncConnectionWrapper):
4465
__cursor_wrapper__ = AsyncCursorContextManagerWrapper
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2010 New Relic, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from aiomysql.sa import create_engine
16+
from sqlalchemy.orm import declarative_base
17+
from sqlalchemy import Integer, String, Column, Float
18+
from sqlalchemy.schema import CreateTable, DropTable
19+
20+
from testing_support.db_settings import mysql_settings
21+
from testing_support.util import instance_hostname
22+
from testing_support.validators.validate_database_trace_inputs import (
23+
validate_database_trace_inputs,
24+
)
25+
from testing_support.validators.validate_transaction_metrics import (
26+
validate_transaction_metrics,
27+
)
28+
29+
from newrelic.api.background_task import background_task
30+
31+
DB_SETTINGS = mysql_settings()[0]
32+
TABLE_NAME = f"datastore_aiomysql_orm_{DB_SETTINGS['namespace']}"
33+
PROCEDURE_NAME = f"hello_{DB_SETTINGS['namespace']}"
34+
35+
HOST = instance_hostname(DB_SETTINGS["host"])
36+
PORT = DB_SETTINGS["port"]
37+
38+
39+
Base = declarative_base()
40+
41+
class ABCModel(Base):
42+
__tablename__ = TABLE_NAME
43+
44+
a = Column(Integer, primary_key=True)
45+
b = Column(Float)
46+
c = Column(String(100))
47+
48+
49+
ABCTable = ABCModel.__table__
50+
51+
52+
async def exercise(engine):
53+
async with engine.acquire() as conn:
54+
async with conn.begin():
55+
await conn.execute(DropTable(ABCTable, if_exists=True))
56+
await conn.execute(CreateTable(ABCTable))
57+
58+
input_rows = [(1, 1.0, "1.0"), (2, 2.2, "2.2"), (3, 3.3, "3.3")]
59+
await conn.execute(ABCTable.insert().values(input_rows))
60+
cursor = await conn.execute(ABCTable.select())
61+
62+
rows = []
63+
async for row in cursor:
64+
rows.append(row)
65+
66+
assert rows == input_rows, f"Expected: {input_rows}, Got: {rows}"
67+
68+
await conn.execute(ABCTable.update().where(ABCTable.columns.a == 1).values((4, 4.0, "4.0")))
69+
await conn.execute(ABCTable.delete().where(ABCTable.columns.a == 2))
70+
71+
72+
SCOPED_METRICS = [
73+
("Function/aiomysql.pool:Pool._acquire", 2),
74+
(f"Datastore/statement/MySQL/{TABLE_NAME}/select", 1),
75+
(f"Datastore/statement/MySQL/{TABLE_NAME}/insert", 1),
76+
(f"Datastore/statement/MySQL/{TABLE_NAME}/update", 1),
77+
(f"Datastore/statement/MySQL/{TABLE_NAME}/delete", 1),
78+
("Datastore/operation/MySQL/drop", 1),
79+
("Datastore/operation/MySQL/create", 1),
80+
("Datastore/operation/MySQL/commit", 1),
81+
("Datastore/operation/MySQL/begin", 1),
82+
]
83+
84+
ROLLUP_METRICS = [
85+
("Function/aiomysql.pool:Pool._acquire", 2),
86+
("Datastore/all", 10),
87+
("Datastore/allOther", 10),
88+
("Datastore/MySQL/all", 10),
89+
("Datastore/MySQL/allOther", 10),
90+
(f"Datastore/statement/MySQL/{TABLE_NAME}/select", 1),
91+
(f"Datastore/statement/MySQL/{TABLE_NAME}/insert", 1),
92+
(f"Datastore/statement/MySQL/{TABLE_NAME}/update", 1),
93+
(f"Datastore/statement/MySQL/{TABLE_NAME}/delete", 1),
94+
("Datastore/operation/MySQL/select", 1),
95+
("Datastore/operation/MySQL/insert", 1),
96+
("Datastore/operation/MySQL/update", 1),
97+
("Datastore/operation/MySQL/delete", 1),
98+
("Datastore/operation/MySQL/drop", 1),
99+
("Datastore/operation/MySQL/create", 1),
100+
("Datastore/operation/MySQL/commit", 1),
101+
(f"Datastore/instance/MySQL/{HOST}/{PORT}", 8),
102+
]
103+
104+
@validate_transaction_metrics(
105+
"test_sqlalchemy:test_execute_via_engine",
106+
scoped_metrics=SCOPED_METRICS,
107+
rollup_metrics=ROLLUP_METRICS,
108+
background_task=True,
109+
)
110+
@validate_database_trace_inputs(sql_parameters_type=dict)
111+
@background_task()
112+
def test_execute_via_engine(loop):
113+
async def _test():
114+
engine = await create_engine(
115+
db=DB_SETTINGS["name"],
116+
user=DB_SETTINGS["user"],
117+
password=DB_SETTINGS["password"],
118+
host=DB_SETTINGS["host"],
119+
port=DB_SETTINGS["port"],
120+
autocommit=True,
121+
)
122+
123+
async with engine:
124+
await exercise(engine)
125+
126+
loop.run_until_complete(_test())

tests/testing_support/validators/validate_database_trace_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _bind_params(
5555
assert isinstance(cursor_params[0], tuple)
5656
assert isinstance(cursor_params[1], dict)
5757

58-
assert sql_parameters is None or isinstance(sql_parameters, sql_parameters_type)
58+
assert sql_parameters is None or isinstance(sql_parameters, sql_parameters_type), f"Expected: {sql_parameters_type} Got: {type(sql_parameters)}"
5959

6060
if execute_params is not None:
6161
assert len(execute_params) == 2

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ deps =
256256
datastore_aiomcache: aiomcache
257257
datastore_aiomysql: aiomysql
258258
datastore_aiomysql: cryptography
259+
datastore_aiomysql: sqlalchemy<2
259260
datastore_bmemcached: python-binary-memcached
260261
datastore_cassandradriver-cassandralatest: cassandra-driver
261262
datastore_cassandradriver-cassandralatest: twisted

0 commit comments

Comments
 (0)