Skip to content

Commit 208344e

Browse files
Starlette Instrumentation Updates (#254)
* Fix starlette background task bug Co-authored-by: Uma Annamalai <[email protected]> * Formatting Co-authored-by: Uma Annamalai <[email protected]>
1 parent b891adc commit 208344e

File tree

7 files changed

+365
-109
lines changed

7 files changed

+365
-109
lines changed

newrelic/hooks/framework_starlette.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,18 @@
1414

1515
from newrelic.api.asgi_application import wrap_asgi_application
1616
from newrelic.api.background_task import BackgroundTaskWrapper
17-
from newrelic.api.time_trace import current_trace
18-
from newrelic.api.function_trace import FunctionTraceWrapper, wrap_function_trace
17+
from newrelic.api.function_trace import FunctionTraceWrapper
18+
from newrelic.api.time_trace import current_trace, notice_error
19+
from newrelic.api.transaction import current_transaction
20+
from newrelic.common.coroutine import is_coroutine_function
1921
from newrelic.common.object_names import callable_name
20-
from newrelic.common.object_wrapper import wrap_function_wrapper, function_wrapper, FunctionWrapper
21-
from newrelic.core.trace_cache import trace_cache
22-
from newrelic.api.time_trace import notice_error
22+
from newrelic.common.object_wrapper import (
23+
FunctionWrapper,
24+
function_wrapper,
25+
wrap_function_wrapper,
26+
)
2327
from newrelic.core.config import should_ignore_error
24-
from newrelic.common.coroutine import is_coroutine_function
25-
from newrelic.api.transaction import current_transaction
28+
from newrelic.core.trace_cache import trace_cache
2629

2730

2831
def framework_details():
@@ -93,10 +96,19 @@ def wrap_request(wrapped, instance, args, kwargs):
9396
def wrap_background_method(wrapped, instance, args, kwargs):
9497
func = getattr(instance, "func", None)
9598
if func:
96-
instance.func = BackgroundTaskWrapper(func)
99+
instance.func = wrap_background_task(func)
97100
return wrapped(*args, **kwargs)
98101

99102

103+
@function_wrapper
104+
def wrap_background_task(wrapped, instance, args, kwargs):
105+
transaction = current_transaction(active_only=False)
106+
if not transaction:
107+
return BackgroundTaskWrapper(wrapped)(*args, **kwargs)
108+
else:
109+
return FunctionTraceWrapper(wrapped)(*args, **kwargs)
110+
111+
100112
async def middleware_wrapper(wrapped, instance, args, kwargs):
101113
transaction = current_transaction()
102114
if transaction:
@@ -158,7 +170,9 @@ async def wrap_exception_handler_async(coro, exc):
158170

159171
def wrap_exception_handler(wrapped, instance, args, kwargs):
160172
if is_coroutine_function(wrapped):
161-
return wrap_exception_handler_async(FunctionTraceWrapper(wrapped)(*args, **kwargs), bind_exc(*args, **kwargs))
173+
return wrap_exception_handler_async(
174+
FunctionTraceWrapper(wrapped)(*args, **kwargs), bind_exc(*args, **kwargs)
175+
)
162176
else:
163177
with RequestContext(bind_request(*args, **kwargs)):
164178
response = FunctionTraceWrapper(wrapped)(*args, **kwargs)
@@ -168,14 +182,16 @@ def wrap_exception_handler(wrapped, instance, args, kwargs):
168182

169183
def wrap_server_error_handler(wrapped, instance, args, kwargs):
170184
result = wrapped(*args, **kwargs)
171-
handler = getattr(instance, 'handler', None)
185+
handler = getattr(instance, "handler", None)
172186
if handler:
173187
instance.handler = FunctionWrapper(handler, wrap_exception_handler)
174188
return result
175189

176190

177191
def wrap_add_exception_handler(wrapped, instance, args, kwargs):
178-
exc_class_or_status_code, handler, args, kwargs = bind_add_exception_handler(*args, **kwargs)
192+
exc_class_or_status_code, handler, args, kwargs = bind_add_exception_handler(
193+
*args, **kwargs
194+
)
179195
handler = FunctionWrapper(handler, wrap_exception_handler)
180196
return wrapped(exc_class_or_status_code, handler, *args, **kwargs)
181197

@@ -207,25 +223,36 @@ def instrument_starlette_requests(module):
207223

208224

209225
def instrument_starlette_middleware_errors(module):
210-
wrap_function_wrapper(module, "ServerErrorMiddleware.__call__", error_middleware_wrapper)
226+
wrap_function_wrapper(
227+
module, "ServerErrorMiddleware.__call__", error_middleware_wrapper
228+
)
211229

212-
wrap_function_wrapper(module, "ServerErrorMiddleware.__init__", wrap_server_error_handler)
230+
wrap_function_wrapper(
231+
module, "ServerErrorMiddleware.__init__", wrap_server_error_handler
232+
)
213233

214-
wrap_function_wrapper(module, "ServerErrorMiddleware.error_response", wrap_exception_handler)
234+
wrap_function_wrapper(
235+
module, "ServerErrorMiddleware.error_response", wrap_exception_handler
236+
)
215237

216-
wrap_function_wrapper(module, "ServerErrorMiddleware.debug_response", wrap_exception_handler)
238+
wrap_function_wrapper(
239+
module, "ServerErrorMiddleware.debug_response", wrap_exception_handler
240+
)
217241

218242

219243
def instrument_starlette_exceptions(module):
220-
wrap_function_wrapper(module, "ExceptionMiddleware.__call__", error_middleware_wrapper)
244+
wrap_function_wrapper(
245+
module, "ExceptionMiddleware.__call__", error_middleware_wrapper
246+
)
221247

222-
wrap_function_wrapper(module, "ExceptionMiddleware.http_exception",
223-
wrap_exception_handler)
248+
wrap_function_wrapper(
249+
module, "ExceptionMiddleware.http_exception", wrap_exception_handler
250+
)
224251

225-
wrap_function_wrapper(module, "ExceptionMiddleware.add_exception_handler",
226-
wrap_add_exception_handler)
252+
wrap_function_wrapper(
253+
module, "ExceptionMiddleware.add_exception_handler", wrap_add_exception_handler
254+
)
227255

228256

229257
def instrument_starlette_background_task(module):
230258
wrap_function_wrapper(module, "BackgroundTask.__call__", wrap_background_method)
231-

tests/framework_starlette/_target_application.py renamed to tests/framework_starlette/_test_application.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515
from starlette.applications import Starlette
1616
from starlette.background import BackgroundTasks
17+
from starlette.exceptions import HTTPException
1718
from starlette.responses import PlainTextResponse
1819
from starlette.routing import Route
19-
from starlette.exceptions import HTTPException
2020
from testing_support.asgi_testing import AsgiTest
21-
from newrelic.api.transaction import current_transaction
21+
2222
from newrelic.api.function_trace import FunctionTrace
23+
from newrelic.api.transaction import current_transaction
2324
from newrelic.common.object_names import callable_name
2425

2526
try:
@@ -101,6 +102,7 @@ async def bg_task_async():
101102
def bg_task_non_async():
102103
pass
103104

105+
104106
routes = [
105107
Route("/index", index),
106108
Route("/418", teapot),
@@ -130,7 +132,11 @@ async def middleware_decorator(request, call_next):
130132
# Generating target applications
131133
app_name_map = {
132134
"no_error_handler": (True, False, {}),
133-
"async_error_handler_no_middleware": (False, False, {Exception: async_error_handler}),
135+
"async_error_handler_no_middleware": (
136+
False,
137+
False,
138+
{Exception: async_error_handler},
139+
),
134140
"non_async_error_handler_no_middleware": (False, False, {}),
135141
"no_middleware": (False, False, {}),
136142
"debug_no_middleware": (False, True, {}),
@@ -145,12 +151,21 @@ async def middleware_decorator(request, call_next):
145151

146152
# Instantiate app
147153
if not middleware_on:
148-
app = Starlette(debug=debug, routes=routes, exception_handlers=exception_handlers)
154+
app = Starlette(
155+
debug=debug, routes=routes, exception_handlers=exception_handlers
156+
)
149157
else:
150158
if Middleware:
151-
app = Starlette(debug=debug, routes=routes, middleware=[Middleware(middleware_factory)], exception_handlers=exception_handlers)
159+
app = Starlette(
160+
debug=debug,
161+
routes=routes,
162+
middleware=[Middleware(middleware_factory)],
163+
exception_handlers=exception_handlers,
164+
)
152165
else:
153-
app = Starlette(debug=debug, routes=routes, exception_handlers=exception_handlers)
166+
app = Starlette(
167+
debug=debug, routes=routes, exception_handlers=exception_handlers
168+
)
154169
# in earlier versions of starlette, middleware is not a legal argument on the Starlette application class
155170
# In order to keep the counts the same, we add the middleware twice using the add_middleware interface
156171
app.add_middleware(middleware_factory)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2010 New Relic, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from starlette.applications import Starlette
16+
from starlette.background import BackgroundTasks
17+
from starlette.middleware.base import BaseHTTPMiddleware
18+
from starlette.responses import PlainTextResponse
19+
from starlette.routing import Route
20+
from testing_support.asgi_testing import AsgiTest
21+
22+
23+
class ASGIStyleMiddleware:
24+
def __init__(self, app):
25+
self.app = app
26+
27+
async def __call__(self, scope, receive, send):
28+
response = await self.app(scope, receive, send)
29+
return response
30+
31+
32+
class BaseHTTPStyleMiddleware(BaseHTTPMiddleware):
33+
async def dispatch(self, request, call_next):
34+
# simple middleware that does absolutely nothing
35+
response = await call_next(request)
36+
return response
37+
38+
39+
async def run_async_bg_task(request):
40+
tasks = BackgroundTasks()
41+
tasks.add_task(async_bg_task)
42+
return PlainTextResponse("Hello, world!", background=tasks)
43+
44+
45+
async def run_sync_bg_task(request):
46+
tasks = BackgroundTasks()
47+
tasks.add_task(sync_bg_task)
48+
return PlainTextResponse("Hello, world!", background=tasks)
49+
50+
51+
async def async_bg_task():
52+
pass
53+
54+
55+
async def sync_bg_task():
56+
pass
57+
58+
59+
routes = [
60+
Route("/async", run_async_bg_task),
61+
Route("/sync", run_sync_bg_task),
62+
]
63+
64+
# Generating target applications
65+
target_application = {}
66+
67+
app = Starlette(routes=routes)
68+
app.add_middleware(ASGIStyleMiddleware)
69+
target_application["asgi"] = AsgiTest(app)
70+
71+
app = Starlette(routes=routes)
72+
app.add_middleware(BaseHTTPStyleMiddleware)
73+
target_application["basehttp"] = AsgiTest(app)
74+
75+
app = Starlette(routes=routes)
76+
target_application["none"] = AsgiTest(app)

tests/framework_starlette/conftest.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import pytest
16-
1715
from testing_support.fixtures import (
1816
code_coverage_fixture,
1917
collector_agent_registration_fixture,
@@ -38,10 +36,3 @@
3836
app_name="Python Agent Test (framework_starlette)",
3937
default_settings=_default_settings,
4038
)
41-
42-
43-
@pytest.fixture(scope="session")
44-
def target_application():
45-
import _target_application
46-
47-
return _target_application.target_application

0 commit comments

Comments
 (0)