Skip to content

Commit e8af7a3

Browse files
authored
Respect provided tracer provider when instrumenting SQLAlchemy (#728)
* respect provided tracer provider when instrumenting sqlalchemy This change updates the SQLALchemyInstrumentor to respect the tracer provider that is passed in through the kwargs when patching the `create_engine` functionality provided by SQLAlchemy. Previously, it would default to the global tracer provider. * feedback: pass in tracer_provider directly rather than kwargs * feedback: update changelog * build: lint
1 parent 5105820 commit e8af7a3

File tree

4 files changed

+66
-22
lines changed

4 files changed

+66
-22
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
([#713](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/713))
1313
- `opentelemetry-sdk-extension-aws` Move AWS X-Ray Propagator into its own `opentelemetry-propagators-aws` package
1414
([#720](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/720))
15+
- `opentelemetry-instrumentation-sqlalchemy` Respect provided tracer provider when instrumenting SQLAlchemy
16+
([#728](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/728))
1517

1618

1719
### Changed

instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,20 +88,23 @@ def _instrument(self, **kwargs):
8888
Returns:
8989
An instrumented engine if passed in as an argument, None otherwise.
9090
"""
91-
_w("sqlalchemy", "create_engine", _wrap_create_engine)
92-
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine)
91+
tracer_provider = kwargs.get("tracer_provider")
92+
_w("sqlalchemy", "create_engine", _wrap_create_engine(tracer_provider))
93+
_w(
94+
"sqlalchemy.engine",
95+
"create_engine",
96+
_wrap_create_engine(tracer_provider),
97+
)
9398
if parse_version(sqlalchemy.__version__).release >= (1, 4):
9499
_w(
95100
"sqlalchemy.ext.asyncio",
96101
"create_async_engine",
97-
_wrap_create_async_engine,
102+
_wrap_create_async_engine(tracer_provider),
98103
)
99104

100105
if kwargs.get("engine") is not None:
101106
return EngineTracer(
102-
_get_tracer(
103-
kwargs.get("engine"), kwargs.get("tracer_provider")
104-
),
107+
_get_tracer(kwargs.get("engine"), tracer_provider),
105108
kwargs.get("engine"),
106109
)
107110
return None

instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,30 @@ def _get_tracer(engine, tracer_provider=None):
4242
)
4343

4444

45-
# pylint: disable=unused-argument
46-
def _wrap_create_async_engine(func, module, args, kwargs):
47-
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
48-
object that will listen to SQLAlchemy events.
49-
"""
50-
engine = func(*args, **kwargs)
51-
EngineTracer(_get_tracer(engine), engine.sync_engine)
52-
return engine
45+
def _wrap_create_async_engine(tracer_provider=None):
46+
# pylint: disable=unused-argument
47+
def _wrap_create_async_engine_internal(func, module, args, kwargs):
48+
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
49+
object that will listen to SQLAlchemy events.
50+
"""
51+
engine = func(*args, **kwargs)
52+
EngineTracer(_get_tracer(engine, tracer_provider), engine.sync_engine)
53+
return engine
5354

55+
return _wrap_create_async_engine_internal
5456

55-
# pylint: disable=unused-argument
56-
def _wrap_create_engine(func, module, args, kwargs):
57-
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
58-
object that will listen to SQLAlchemy events.
59-
"""
60-
engine = func(*args, **kwargs)
61-
EngineTracer(_get_tracer(engine), engine)
62-
return engine
57+
58+
def _wrap_create_engine(tracer_provider=None):
59+
# pylint: disable=unused-argument
60+
def _wrap_create_engine_internal(func, module, args, kwargs):
61+
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
62+
object that will listen to SQLAlchemy events.
63+
"""
64+
engine = func(*args, **kwargs)
65+
EngineTracer(_get_tracer(engine, tracer_provider), engine)
66+
return engine
67+
68+
return _wrap_create_engine_internal
6369

6470

6571
class EngineTracer:

instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from opentelemetry import trace
2222
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
23+
from opentelemetry.sdk.resources import Resource
24+
from opentelemetry.sdk.trace import TracerProvider, export
2325
from opentelemetry.test.test_base import TestBase
2426

2527

@@ -95,6 +97,37 @@ def test_create_engine_wrapper(self):
9597
self.assertEqual(spans[0].name, "SELECT :memory:")
9698
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)
9799

100+
def test_custom_tracer_provider(self):
101+
provider = TracerProvider(
102+
resource=Resource.create(
103+
{
104+
"service.name": "test",
105+
"deployment.environment": "env",
106+
"service.version": "1234",
107+
},
108+
),
109+
)
110+
provider.add_span_processor(
111+
export.SimpleSpanProcessor(self.memory_exporter)
112+
)
113+
114+
SQLAlchemyInstrumentor().instrument(tracer_provider=provider)
115+
from sqlalchemy import create_engine # pylint: disable-all
116+
117+
engine = create_engine("sqlite:///:memory:")
118+
cnx = engine.connect()
119+
cnx.execute("SELECT 1 + 1;").fetchall()
120+
spans = self.memory_exporter.get_finished_spans()
121+
122+
self.assertEqual(len(spans), 1)
123+
self.assertEqual(spans[0].resource.attributes["service.name"], "test")
124+
self.assertEqual(
125+
spans[0].resource.attributes["deployment.environment"], "env"
126+
)
127+
self.assertEqual(
128+
spans[0].resource.attributes["service.version"], "1234"
129+
)
130+
98131
@pytest.mark.skipif(
99132
not sqlalchemy.__version__.startswith("1.4"),
100133
reason="only run async tests for 1.4",

0 commit comments

Comments
 (0)