Skip to content

Commit acf43cb

Browse files
committed
Starlette GraphQL Context Propagation (#361)
* Add graphql sync tests to fastapi * Add starlette context propagation * Move starlette tests to starlette not fastapi
1 parent 49c95c3 commit acf43cb

File tree

10 files changed

+193
-86
lines changed

10 files changed

+193
-86
lines changed

newrelic/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2453,6 +2453,11 @@ def _process_module_builtin_defaults():
24532453
"newrelic.hooks.framework_starlette",
24542454
"instrument_starlette_background_task",
24552455
)
2456+
_process_module_definition(
2457+
"starlette.concurrency",
2458+
"newrelic.hooks.framework_starlette",
2459+
"instrument_starlette_concurrency",
2460+
)
24562461

24572462
_process_module_definition(
24582463
"strawberry.asgi",

newrelic/core/context.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2010 New Relic, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
This module implements utilities for context propagation for tracing across threads.
17+
"""
18+
19+
from newrelic.common.object_wrapper import function_wrapper
20+
from newrelic.core.trace_cache import trace_cache
21+
22+
class ContextOf(object):
23+
def __init__(self, trace_cache_id):
24+
self.trace_cache = trace_cache()
25+
self.trace = self.trace_cache._cache.get(trace_cache_id)
26+
self.thread_id = None
27+
self.restore = None
28+
29+
def __enter__(self):
30+
if self.trace:
31+
self.thread_id = self.trace_cache.current_thread_id()
32+
self.restore = self.trace_cache._cache.get(self.thread_id)
33+
self.trace_cache._cache[self.thread_id] = self.trace
34+
return self
35+
36+
def __exit__(self, exc, value, tb):
37+
if self.restore:
38+
self.trace_cache._cache[self.thread_id] = self.restore
39+
40+
41+
async def context_wrapper_async(awaitable, trace_cache_id):
42+
with ContextOf(trace_cache_id):
43+
return await awaitable
44+
45+
46+
def context_wrapper(func, trace_cache_id):
47+
@function_wrapper
48+
def _context_wrapper(wrapped, instance, args, kwargs):
49+
with ContextOf(trace_cache_id):
50+
return wrapped(*args, **kwargs)
51+
52+
return _context_wrapper(func)
53+
54+
55+
def current_thread_id():
56+
return trace_cache().current_thread_id()

newrelic/hooks/adapter_asgiref.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,12 @@
11
from newrelic.common.object_wrapper import wrap_function_wrapper
22
from newrelic.core.trace_cache import trace_cache
3+
from newrelic.core.context import context_wrapper_async, ContextOf
34

45

56
def _bind_thread_handler(loop, source_task, *args, **kwargs):
67
return source_task
78

89

9-
class ContextOf(object):
10-
def __init__(self, trace_cache_id):
11-
self.trace_cache = trace_cache()
12-
self.trace = self.trace_cache._cache.get(trace_cache_id)
13-
self.thread_id = None
14-
self.restore = None
15-
16-
def __enter__(self):
17-
if self.trace:
18-
self.thread_id = self.trace_cache.current_thread_id()
19-
self.restore = self.trace_cache._cache.get(self.thread_id)
20-
self.trace_cache._cache[self.thread_id] = self.trace
21-
return self
22-
23-
def __exit__(self, exc, value, tb):
24-
if self.restore:
25-
self.trace_cache._cache[self.thread_id] = self.restore
26-
27-
28-
async def context_wrapper_async(awaitable, trace_cache_id):
29-
with ContextOf(trace_cache_id):
30-
return await awaitable
31-
32-
3310
def thread_handler_wrapper(wrapped, instance, args, kwargs):
3411
task = _bind_thread_handler(*args, **kwargs)
3512
with ContextOf(id(task)):

newrelic/hooks/framework_graphql.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ def wrap_execute_operation(wrapped, instance, args, kwargs):
113113
_logger.warning(
114114
"Runtime instrumentation warning. GraphQL operation found without active GraphQLOperationTrace."
115115
)
116-
breakpoint()
117116
return wrapped(*args, **kwargs)
118117

119118
try:

newrelic/hooks/framework_starlette.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
wrap_function_wrapper,
2626
)
2727
from newrelic.core.config import should_ignore_error
28+
from newrelic.core.context import context_wrapper, current_thread_id
2829
from newrelic.core.trace_cache import trace_cache
2930

3031

@@ -204,6 +205,23 @@ def error_middleware_wrapper(wrapped, instance, args, kwargs):
204205
return FunctionTraceWrapper(wrapped)(*args, **kwargs)
205206

206207

208+
def bind_run_in_threadpool(func, *args, **kwargs):
209+
return func, args, kwargs
210+
211+
212+
async def wrap_run_in_threadpool(wrapped, instance, args, kwargs):
213+
transaction = current_transaction()
214+
trace = current_trace()
215+
216+
if not transaction or not trace:
217+
return await wrapped(*args, **kwargs)
218+
219+
func, args, kwargs = bind_run_in_threadpool(*args, **kwargs)
220+
func = context_wrapper(func, current_thread_id())
221+
222+
return await wrapped(func, *args, **kwargs)
223+
224+
207225
def instrument_starlette_applications(module):
208226
framework = framework_details()
209227
version_info = tuple(int(v) for v in framework[1].split(".", 3)[:3])
@@ -256,3 +274,7 @@ def instrument_starlette_exceptions(module):
256274

257275
def instrument_starlette_background_task(module):
258276
wrap_function_wrapper(module, "BackgroundTask.__call__", wrap_background_method)
277+
278+
279+
def instrument_starlette_concurrency(module):
280+
wrap_function_wrapper(module, "run_in_threadpool", wrap_run_in_threadpool)

tests/framework_fastapi/_target_application.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313
# limitations under the License.
1414

1515
from fastapi import FastAPI
16-
from graphene import ObjectType, String, Schema
17-
from graphql.execution.executors.asyncio import AsyncioExecutor
18-
from starlette.graphql import GraphQLApp
1916

2017
from newrelic.api.transaction import current_transaction
2118
from testing_support.asgi_testing import AsgiTest
@@ -35,13 +32,4 @@ async def non_sync():
3532
return {}
3633

3734

38-
class Query(ObjectType):
39-
hello = String()
40-
41-
def resolve_hello(self, info):
42-
return "Hello!"
43-
44-
45-
app.add_route("/graphql", GraphQLApp(executor_class=AsyncioExecutor, schema=Schema(query=Query)))
46-
4735
target_application = AsgiTest(app)

tests/framework_fastapi/test_application.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
# limitations under the License.
1414

1515
import pytest
16-
from testing_support.fixtures import dt_enabled, validate_transaction_metrics
17-
from testing_support.validators.validate_span_events import validate_span_events
16+
from testing_support.fixtures import validate_transaction_metrics
1817

1918

2019
@pytest.mark.parametrize("endpoint,transaction_name", (
@@ -29,49 +28,3 @@ def _test():
2928
assert response.status == 200
3029

3130
_test()
32-
33-
34-
@dt_enabled
35-
def test_graphql_endpoint(app):
36-
from graphql import __version__ as version
37-
38-
FRAMEWORK_METRICS = [
39-
("Python/Framework/GraphQL/%s" % version, 1),
40-
]
41-
_test_scoped_metrics = [
42-
("GraphQL/resolve/GraphQL/hello", 1),
43-
("GraphQL/operation/GraphQL/query/<anonymous>/hello", 1),
44-
]
45-
_test_unscoped_metrics = [
46-
("GraphQL/all", 1),
47-
("GraphQL/GraphQL/all", 1),
48-
("GraphQL/allWeb", 1),
49-
("GraphQL/GraphQL/allWeb", 1),
50-
] + _test_scoped_metrics
51-
52-
_expected_query_operation_attributes = {
53-
"graphql.operation.type": "query",
54-
"graphql.operation.name": "<anonymous>",
55-
"graphql.operation.query": "{ hello }",
56-
}
57-
_expected_query_resolver_attributes = {
58-
"graphql.field.name": "hello",
59-
"graphql.field.parentType": "Query",
60-
"graphql.field.path": "hello",
61-
"graphql.field.returnType": "String",
62-
}
63-
64-
@validate_span_events(exact_agents=_expected_query_operation_attributes)
65-
@validate_span_events(exact_agents=_expected_query_resolver_attributes)
66-
@validate_transaction_metrics(
67-
"query/<anonymous>/hello",
68-
"GraphQL",
69-
scoped_metrics=_test_scoped_metrics,
70-
rollup_metrics=_test_unscoped_metrics + FRAMEWORK_METRICS,
71-
)
72-
def _test():
73-
response = app.make_request("POST", "/graphql", params="query=%7B%20hello%20%7D")
74-
assert response.status == 200
75-
assert "Hello!" in response.body.decode("utf-8")
76-
77-
_test()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2010 New Relic, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from starlette.applications import Starlette
16+
from starlette.routing import Route
17+
from testing_support.asgi_testing import AsgiTest
18+
19+
from graphene import ObjectType, String, Schema
20+
from graphql.execution.executors.asyncio import AsyncioExecutor
21+
from starlette.graphql import GraphQLApp
22+
23+
24+
class Query(ObjectType):
25+
hello = String()
26+
27+
def resolve_hello(self, info):
28+
return "Hello!"
29+
30+
31+
routes = [
32+
Route("/async", GraphQLApp(executor_class=AsyncioExecutor, schema=Schema(query=Query))),
33+
Route("/sync", GraphQLApp(schema=Schema(query=Query))),
34+
]
35+
36+
app = Starlette(routes=routes)
37+
target_application = AsgiTest(app)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2010 New Relic, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import pytest
17+
from testing_support.fixtures import dt_enabled, validate_transaction_metrics
18+
from testing_support.validators.validate_span_events import validate_span_events
19+
20+
@pytest.fixture(scope="session")
21+
def target_application():
22+
import _test_graphql
23+
24+
return _test_graphql.target_application
25+
26+
@dt_enabled
27+
@pytest.mark.parametrize("endpoint", ("/async", "/sync"))
28+
def test_graphql_metrics_and_attrs(target_application, endpoint):
29+
from graphql import __version__ as version
30+
31+
FRAMEWORK_METRICS = [
32+
("Python/Framework/GraphQL/%s" % version, 1),
33+
]
34+
_test_scoped_metrics = [
35+
("GraphQL/resolve/GraphQL/hello", 1),
36+
("GraphQL/operation/GraphQL/query/<anonymous>/hello", 1),
37+
]
38+
_test_unscoped_metrics = [
39+
("GraphQL/all", 1),
40+
("GraphQL/GraphQL/all", 1),
41+
("GraphQL/allWeb", 1),
42+
("GraphQL/GraphQL/allWeb", 1),
43+
] + _test_scoped_metrics
44+
45+
_expected_query_operation_attributes = {
46+
"graphql.operation.type": "query",
47+
"graphql.operation.name": "<anonymous>",
48+
"graphql.operation.query": "{ hello }",
49+
}
50+
_expected_query_resolver_attributes = {
51+
"graphql.field.name": "hello",
52+
"graphql.field.parentType": "Query",
53+
"graphql.field.path": "hello",
54+
"graphql.field.returnType": "String",
55+
}
56+
57+
@validate_span_events(exact_agents=_expected_query_operation_attributes)
58+
@validate_span_events(exact_agents=_expected_query_resolver_attributes)
59+
@validate_transaction_metrics(
60+
"query/<anonymous>/hello",
61+
"GraphQL",
62+
scoped_metrics=_test_scoped_metrics,
63+
rollup_metrics=_test_unscoped_metrics + FRAMEWORK_METRICS,
64+
)
65+
def _test():
66+
response = target_application.make_request("POST", endpoint, body=json.dumps({"query": "{ hello }"}), headers={"Content-Type": "application/json"})
67+
assert response.status == 200
68+
assert "Hello!" in response.body.decode("utf-8")
69+
70+
_test()

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@ deps =
247247
framework_falcon-falcon0200: falcon<2.1
248248
framework_falcon-falconmaster: https://github.com/falconry/falcon/archive/master.zip
249249
framework_fastapi: fastapi
250-
framework_fastapi: graphene
251250
framework_fastapi: asyncio
252251
framework_flask: Flask-Compress
253252
framework_flask-flask0012: flask<0.13
@@ -283,6 +282,7 @@ deps =
283282
framework_sanic-sanic210300: sanic<21.3.1
284283
framework_sanic-saniclatest: sanic
285284
framework_sanic-sanic{1812,190301,1906}: aiohttp
285+
framework_starlette: graphene
286286
framework_starlette-starlette0014: starlette<0.15
287287
framework_starlette-starlettelatest: starlette
288288
framework_strawberry: starlette

0 commit comments

Comments
 (0)