Skip to content

Commit 088fdb1

Browse files
committed
Reviewer suggestions and more tests
1 parent c77a0a2 commit 088fdb1

File tree

6 files changed

+475
-93
lines changed

6 files changed

+475
-93
lines changed

newrelic/hooks/application_celery.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,26 @@ def task_info(instance, *args, **kwargs):
6464
return task_name, task_source
6565

6666

67+
# def create_DT_headers_and_attach(wrapped, transaction):
68+
# try:
69+
# request = wrapped and wrapped.request_stack and wrapped.request_stack.top
70+
# original_headers = getattr(request, "headers", None) or vars(request)
71+
# dt_headers = MessageTrace.generate_request_headers(transaction)
72+
# if dt_headers:
73+
# if not original_headers:
74+
# wrapped.request.headers = dict(dt_headers)
75+
# else:
76+
# dt_headers.update(original_headers)
77+
# wrapped.request.headers = dt_headers = dict(dt_headers)
78+
# # original_headers.update(dict(dt_headers))
79+
# # wrapped.request.headers = original_headers
80+
81+
# except Exception:
82+
# pass
83+
84+
# return wrapped
85+
86+
6787
def wrap_task_call(wrapped, instance, args, kwargs):
6888
transaction = current_transaction(active_only=False)
6989

@@ -104,22 +124,23 @@ def wrap_task_call(wrapped, instance, args, kwargs):
104124
try:
105125
# Headers on earlier versions of Celery may end up as attributes
106126
# on the request context instead of as custom headers. Handle this
107-
# by defaulting to using vars() if headers is not available
108-
request = wrapped.request
127+
# by defaulting to using `vars()` if headers is not available
128+
129+
# If there is no request, the request property will return
130+
# a new instance of `celery.Context()` instead of `None`, so
131+
# this will be handled by accessing the request_stack directly.
132+
request = wrapped and wrapped.request_stack and wrapped.request_stack.top
109133
headers = getattr(request, "headers", None) or vars(request)
110134

111135
settings = transaction.settings
112136
if headers is not None and settings is not None:
113137
if settings.distributed_tracing.enabled:
138+
# Generate DT headers if they do not already exist in the incoming request
114139
if not transaction.accept_distributed_trace_headers(headers, transport_type="AMQP"):
115140
try:
116141
dt_headers = MessageTrace.generate_request_headers(transaction)
117142
if dt_headers:
118-
if not headers:
119-
wrapped.request.headers = dict(dt_headers)
120-
else:
121-
headers.update(dict(dt_headers))
122-
wrapped.request.headers = headers
143+
headers.update(dict(dt_headers))
123144
except Exception:
124145
pass
125146
elif transaction.settings.cross_application_tracer.enabled:
@@ -143,7 +164,6 @@ def run(self, *args, **kwargs):
143164
task = bound_args.get("task", None)
144165

145166
task = TaskWrapper(task, wrap_task_call)
146-
task.__module__ = wrapped.__module__ # Ensure module is set for monkeypatching detection
147167
bound_args["task"] = task
148168

149169
return wrapped(**bound_args)
@@ -193,22 +213,23 @@ def wrapper(wrapped, instance, args, kwargs):
193213
try:
194214
# Headers on earlier versions of Celery may end up as attributes
195215
# on the request context instead of as custom headers. Handle this
196-
# by defaulting to using vars() if headers is not available
197-
request = instance.request
216+
# by defaulting to using `vars()` if headers is not available
217+
218+
# If there is no request, the request property will return
219+
# a new instance of `celery.Context()` instead of `None`, so
220+
# this will be handled by accessing the request_stack directly.
221+
request = instance and instance.request_stack and instance.request_stack.top
198222
headers = getattr(request, "headers", None) or vars(request)
199223

200224
settings = transaction.settings
201225
if headers is not None and settings is not None:
202226
if settings.distributed_tracing.enabled:
227+
# Generate DT headers if they do not already exist in the incoming request
203228
if not transaction.accept_distributed_trace_headers(headers, transport_type="AMQP"):
204229
try:
205230
dt_headers = MessageTrace.generate_request_headers(transaction)
206231
if dt_headers:
207-
if not headers:
208-
instance.request.headers = dict(dt_headers)
209-
else:
210-
headers.update(dict(dt_headers))
211-
instance.request.headers = headers
232+
headers.update(dict(dt_headers))
212233
except Exception:
213234
pass
214235
elif transaction.settings.cross_application_tracer.enabled:
@@ -238,7 +259,7 @@ def wrapper(wrapped, instance, args, kwargs):
238259
# Celery has included a monkey-patching provision which did not perform this
239260
# optimization on functions that were monkey-patched. Unfortunately, our
240261
# wrappers are too transparent for celery to detect that they've even been
241-
# monky-patched. To circumvent this, we set the __module__ of our wrapped task
262+
# monkey-patched. To circumvent this, we set the __module__ of our wrapped task
242263
# to this file which causes celery to properly detect that it has been patched.
243264
#
244265
# For versions of celery 2.5.3 to 2.5.5
@@ -252,7 +273,6 @@ def run(self, *args, **kwargs):
252273
return self.__call__(*args, **kwargs)
253274

254275
wrapped_task = TaskWrapper(wrapped, wrapper)
255-
# Reset __module__ to be less transparent so celery detects our monkey-patching
256276
wrapped_task.__module__ = CeleryTaskWrapper.__module__
257277

258278
return wrapped_task
@@ -262,7 +282,7 @@ def instrument_celery_local(module):
262282
if hasattr(module, "Proxy"):
263283
# This is used in the case where the function is
264284
# called directly on the Proxy object (rather than
265-
# using "delay" or "apply_async")
285+
# using `delay` or `apply_async`)
266286
module.Proxy.__call__ = CeleryTaskWrapper(module.Proxy.__call__)
267287

268288

tests/application_celery/_target_application.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,33 @@
2626
)
2727

2828

29+
@app.task
30+
def add(x, y):
31+
return x + y
32+
33+
34+
@app.task
35+
def tsum(nums):
36+
return sum(nums)
37+
38+
39+
@app.task
40+
def nested_add(x, y):
41+
return add(x, y)
42+
43+
44+
@shared_task
45+
def shared_task_add(x, y):
46+
return x + y
47+
48+
2949
class CustomCeleryTaskWithSuper(Task):
3050
def __call__(self, *args, **kwargs):
3151
transaction = current_transaction()
3252
if transaction:
3353
transaction.add_custom_attribute("custom_task_attribute", "Called with super")
3454
return super().__call__(*args, **kwargs)
3555

36-
3756
class CustomCeleryTaskWithRun(Task):
3857
def __call__(self, *args, **kwargs):
3958
transaction = current_transaction()
@@ -42,31 +61,11 @@ def __call__(self, *args, **kwargs):
4261
return self.run(*args, **kwargs)
4362

4463

45-
@app.task
46-
def add(x, y):
47-
return x + y
48-
49-
5064
@app.task(base=CustomCeleryTaskWithSuper)
5165
def add_with_super(x, y):
5266
return x + y
5367

5468

5569
@app.task(base=CustomCeleryTaskWithRun)
5670
def add_with_run(x, y):
57-
return x + y
58-
59-
60-
@app.task
61-
def tsum(nums):
62-
return sum(nums)
63-
64-
65-
@app.task
66-
def nested_add(x, y):
67-
return add(x, y)
68-
69-
70-
@shared_task
71-
def shared_task_add(x, y):
72-
return x + y
71+
return x + y

tests/application_celery/conftest.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
import pytest
16+
from celery.app.trace import setup_worker_optimizations, reset_worker_optimizations
1517
from testing_support.fixtures import collector_agent_registration_fixture, collector_available_fixture
1618

1719
_default_settings = {
@@ -27,7 +29,6 @@
2729
app_name="Python Agent Test (application_celery)", default_settings=_default_settings
2830
)
2931

30-
3132
@pytest.fixture(scope="session")
3233
def celery_config():
3334
# Used by celery pytest plugin to configure Celery instance
@@ -43,3 +44,12 @@ def celery_worker_parameters():
4344
@pytest.fixture(scope="session", autouse=True)
4445
def celery_worker_available(celery_session_worker):
4546
return celery_session_worker
47+
48+
49+
@pytest.fixture(scope="session", autouse=True, params=[False, True], ids=["unpatched", "patched"])
50+
def with_worker_optimizations(request, celery_worker_available):
51+
if request.param:
52+
setup_worker_optimizations(celery_worker_available.app)
53+
54+
yield request.param
55+
reset_worker_optimizations()

tests/application_celery/test_distributed_tracing.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
2020

2121
from newrelic.api.background_task import background_task
22+
from newrelic.api.transaction import insert_distributed_trace_headers
2223

2324

2425
@pytest.mark.parametrize("dt_enabled", [True, False])
25-
def test_celery_task_distributed_tracing_inside_background_task(dt_enabled):
26+
def test_DT_inside_transaction_delay(dt_enabled):
2627
@override_application_settings({"distributed_tracing.enabled": dt_enabled})
2728
@validate_transaction_metrics(
2829
name="_target_application.add",
@@ -35,7 +36,7 @@ def test_celery_task_distributed_tracing_inside_background_task(dt_enabled):
3536
index=-2,
3637
)
3738
@validate_transaction_metrics(
38-
name="test_distributed_tracing:test_celery_task_distributed_tracing_inside_background_task.<locals>._test",
39+
name="test_distributed_tracing:test_DT_inside_transaction_delay.<locals>._test",
3940
rollup_metrics=[
4041
("Supportability/TraceContext/Create/Success", 1 if dt_enabled else None),
4142
("Supportability/DistributedTrace/CreatePayload/Success", 1 if dt_enabled else None),
@@ -46,15 +47,15 @@ def test_celery_task_distributed_tracing_inside_background_task(dt_enabled):
4647
# One for the background task, one for the Celery task. Runs in different processes.
4748
@background_task()
4849
def _test():
49-
result = add.apply_async((1, 2))
50+
result = add.delay(1, 2)
5051
result = result.get()
5152
assert result == 3
5253

5354
_test()
5455

5556

5657
@pytest.mark.parametrize("dt_enabled", [True, False])
57-
def test_celery_task_distributed_tracing_outside_background_task(dt_enabled):
58+
def test_DT_outside_transaction_delay(dt_enabled):
5859
@override_application_settings({"distributed_tracing.enabled": dt_enabled})
5960
@validate_transaction_metrics(
6061
name="_target_application.add",
@@ -67,21 +68,23 @@ def test_celery_task_distributed_tracing_outside_background_task(dt_enabled):
6768
)
6869
@validate_transaction_count(1)
6970
def _test():
70-
result = add.apply_async((1, 2))
71+
result = add.delay(1, 2)
7172
result = result.get()
7273
assert result == 3
7374

7475
_test()
7576

76-
77-
# In this case, the background task creating the transaction
78-
# has not generated a distributed trace header, so the Celery
79-
# task will not have a distributed trace header to accept.
8077
@pytest.mark.parametrize("dt_enabled", [True, False])
81-
def test_celery_task_distributed_tracing_inside_background_task_apply(dt_enabled):
78+
def test_DT_inside_transaction_apply(dt_enabled):
8279
@override_application_settings({"distributed_tracing.enabled": dt_enabled})
8380
@validate_transaction_metrics(
84-
name="test_distributed_tracing:test_celery_task_distributed_tracing_inside_background_task_apply.<locals>._test",
81+
name="test_distributed_tracing:test_DT_inside_transaction_apply.<locals>._test",
82+
rollup_metrics=[
83+
("Function/_target_application.add", 1),
84+
],
85+
scoped_metrics=[
86+
("Function/_target_application.add", 1),
87+
],
8588
background_task=True,
8689
)
8790
@validate_transaction_count(1) # In the same process, so only one transaction
@@ -95,7 +98,34 @@ def _test():
9598

9699

97100
@pytest.mark.parametrize("dt_enabled", [True, False])
98-
def test_celery_task_distributed_tracing_outside_background_task_apply(dt_enabled):
101+
def test_DT_inside_transaction_apply_with_added_headers(dt_enabled):
102+
@override_application_settings({"distributed_tracing.enabled": dt_enabled})
103+
@validate_transaction_metrics(
104+
name="test_distributed_tracing:test_DT_inside_transaction_apply_with_added_headers.<locals>._test",
105+
rollup_metrics=[
106+
("Function/_target_application.add", 1),
107+
("Supportability/TraceContext/Create/Success", 1 if dt_enabled else None),
108+
("Supportability/DistributedTrace/CreatePayload/Success", 1 if dt_enabled else None),
109+
],
110+
scoped_metrics=[
111+
("Function/_target_application.add", 1),
112+
],
113+
background_task=True,
114+
)
115+
@validate_transaction_count(1) # In the same process, so only one transaction
116+
@background_task()
117+
def _test():
118+
headers = []
119+
insert_distributed_trace_headers(headers)
120+
result = add.apply((1, 2), headers=headers)
121+
result = result.get()
122+
assert result == 3
123+
124+
_test()
125+
126+
127+
@pytest.mark.parametrize("dt_enabled", [True, False])
128+
def test_DT_outside_transaction_apply(dt_enabled):
99129
@override_application_settings({"distributed_tracing.enabled": dt_enabled})
100130
@validate_transaction_metrics(
101131
name="_target_application.add",
@@ -113,3 +143,41 @@ def _test():
113143
assert result == 3
114144

115145
_test()
146+
147+
148+
@pytest.mark.parametrize("dt_enabled", [True, False])
149+
def test_DT_inside_transaction__call__(dt_enabled):
150+
@override_application_settings({"distributed_tracing.enabled": dt_enabled})
151+
@validate_transaction_metrics(
152+
name="test_distributed_tracing:test_DT_inside_transaction__call__.<locals>._test",
153+
rollup_metrics=[
154+
("Function/_target_application.add", 1),
155+
],
156+
scoped_metrics=[
157+
("Function/_target_application.add", 1),
158+
],
159+
background_task=True,
160+
)
161+
@validate_transaction_count(1) # In the same process, so only one transaction
162+
@background_task()
163+
def _test():
164+
result = add(1, 2)
165+
assert result == 3
166+
167+
_test()
168+
169+
170+
@pytest.mark.parametrize("dt_enabled", [True, False])
171+
def test_DT_outside_transaction__call__(dt_enabled):
172+
@override_application_settings({"distributed_tracing.enabled": dt_enabled})
173+
@validate_transaction_metrics(
174+
name="_target_application.add",
175+
group="Celery",
176+
background_task=True,
177+
)
178+
@validate_transaction_count(1) # In the same process, so only one transaction
179+
def _test():
180+
result = add(1, 2)
181+
assert result == 3
182+
183+
_test()

0 commit comments

Comments
 (0)