Skip to content

Commit e127e6e

Browse files
bharlinghaotianw465
authored andcommitted
Support binding connection in sqlalchemy as well as engine (#78)
1 parent 1010db3 commit e127e6e

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

aws_xray_sdk/ext/sqlalchemy/util/decorators.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from future.standard_library import install_aliases
55
install_aliases()
66
from urllib.parse import urlparse, uses_netloc
7+
from sqlalchemy.engine.base import Connection
78

89

910
def decorate_all_functions(function_decorator):
@@ -86,7 +87,11 @@ def wrapper(*args, **kw):
8687
# }
8788
def parse_bind(bind):
8889
"""Parses a connection string and creates SQL trace metadata"""
89-
m = re.match(r"Engine\((.*?)\)", str(bind))
90+
if isinstance(bind, Connection):
91+
engine = bind.engine
92+
else:
93+
engine = bind
94+
m = re.match(r"Engine\((.*?)\)", str(engine))
9095
if m is not None:
9196
u = urlparse(m.group(1))
9297
# Add Scheme to uses_netloc or // will be missing from url.

tests/ext/sqlalchemy/test_query.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@ class User(Base):
2121

2222

2323
@pytest.fixture()
24-
def session():
24+
def engine():
25+
return create_engine('sqlite:///:memory:')
26+
27+
28+
@pytest.fixture()
29+
def session(engine):
2530
"""Test Fixture to Create DataBase Tables and start a trace segment"""
2631
engine = create_engine('sqlite:///:memory:')
2732
xray_recorder.configure(service='test', sampling=False, context=Context())
@@ -35,6 +40,21 @@ def session():
3540
xray_recorder.clear_trace_entities()
3641

3742

43+
@pytest.fixture()
44+
def connection(engine):
45+
conn = engine.connect()
46+
xray_recorder.configure(service='test', sampling=False, context=Context())
47+
xray_recorder.clear_trace_entities()
48+
xray_recorder.begin_segment('SQLAlchemyTest')
49+
Session = XRaySessionMaker(bind=conn)
50+
Base.metadata.create_all(engine)
51+
session = Session()
52+
yield session
53+
xray_recorder.end_segment()
54+
xray_recorder.clear_trace_entities()
55+
56+
57+
3858
def test_all(capsys, session):
3959
""" Test calling all() on get all records.
4060
Verify we run the query and return the SQL as metdata"""
@@ -46,6 +66,14 @@ def test_all(capsys, session):
4666
assert subsegment['sql']['url']
4767

4868

69+
def test_supports_connection(capsys, connection):
70+
""" Test that XRaySessionMaker supports connection as well as engine"""
71+
connection.query(User).all()
72+
subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy',
73+
'sqlalchemy.orm.query.all')
74+
assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.query.all'
75+
76+
4977
def test_add(capsys, session):
5078
""" Test calling add() on insert a row.
5179
Verify we that we capture trace for the add"""

0 commit comments

Comments
 (0)