Skip to content

Commit 7a643b4

Browse files
set checkpoint in the caller of extract
1 parent ef52dc1 commit 7a643b4

File tree

4 files changed

+107
-119
lines changed

4 files changed

+107
-119
lines changed

datadog_lambda/tracing.py

Lines changed: 82 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -216,26 +216,70 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
216216
"""
217217

218218
# EventBridge => SQS
219+
is_sqs = False
219220
try:
220221
context = _extract_context_from_eventbridge_sqs_event(event)
221222
if _is_context_complete(context):
222-
return context
223+
return context, None, None
223224
except Exception:
224225
logger.debug("Failed extracting context as EventBridge to SQS.")
225226

226227
try:
227-
dd_json_data, _ = extract_dd_context_from_sqs_or_sns_event(event)
228+
first_record = event.get("Records")[0]
229+
arn = first_record.get("eventSourceARN", "")
230+
if arn:
231+
is_sqs = True
232+
dd_json_data = None
233+
# logic to deal with SNS => SQS event
234+
if "body" in first_record:
235+
body_str = first_record.get("body")
236+
try:
237+
body = json.loads(body_str)
238+
if body.get("Type", "") == "Notification" and "TopicArn" in body:
239+
logger.debug("Found SNS message inside SQS event")
240+
first_record = get_first_record(create_sns_event(body))
241+
except Exception:
242+
pass
243+
244+
msg_attributes = first_record.get("messageAttributes")
245+
if msg_attributes is None:
246+
sns_record = first_record.get("Sns") or {}
247+
if not is_sqs:
248+
arn = sns_record.get("TopicArn", "")
249+
msg_attributes = sns_record.get("MessageAttributes") or {}
250+
dd_payload = msg_attributes.get("_datadog")
251+
if dd_payload:
252+
# SQS uses dataType and binaryValue/stringValue
253+
# SNS uses Type and Value
254+
dd_json_data_type = dd_payload.get("Type") or dd_payload.get("dataType")
255+
if dd_json_data_type == "Binary":
256+
import base64
257+
258+
dd_json_data = dd_payload.get("binaryValue") or dd_payload.get("Value")
259+
if dd_json_data:
260+
dd_json_data = base64.b64decode(dd_json_data)
261+
elif dd_json_data_type == "String":
262+
dd_json_data = dd_payload.get("stringValue") or dd_payload.get("Value")
263+
else:
264+
logger.debug(
265+
"Datadog Lambda Python only supports extracting trace"
266+
"context from String or Binary SQS/SNS message attributes"
267+
)
228268
if dd_json_data:
229269
dd_data = json.loads(dd_json_data)
230270

231271
if is_step_function_event(dd_data):
232272
try:
233-
return extract_context_from_step_functions(dd_data, None)
273+
return (
274+
extract_context_from_step_functions(dd_data, None),
275+
None,
276+
None,
277+
)
234278
except Exception:
235279
logger.debug(
236280
"Failed to extract Step Functions context from SQS/SNS event."
237281
)
238-
return propagator.extract(dd_data)
282+
return propagator.extract(dd_data), dd_data, arn
239283
else:
240284
# Handle case where trace context is injected into attributes.AWSTraceHeader
241285
# example: Root=1-654321ab-000000001234567890abcdef;Parent=0123456789abcdef;Sampled=1
@@ -253,60 +297,19 @@ def extract_context_from_sqs_or_sns_event_or_context(event, lambda_context):
253297
logger.debug(
254298
"Found dd-trace injected trace context from AWSTraceHeader"
255299
)
256-
return Context(
257-
trace_id=int(trace_id_parts[2][8:], 16),
258-
span_id=int(x_ray_context["parent_id"], 16),
259-
sampling_priority=float(x_ray_context["sampled"]),
300+
return (
301+
Context(
302+
trace_id=int(trace_id_parts[2][8:], 16),
303+
span_id=int(x_ray_context["parent_id"], 16),
304+
sampling_priority=float(x_ray_context["sampled"]),
305+
),
306+
None,
307+
None,
260308
)
261-
return extract_context_from_lambda_context(lambda_context)
309+
return extract_context_from_lambda_context(lambda_context), None, None
262310
except Exception as e:
263311
logger.debug("The trace extractor returned with error %s", e)
264-
return extract_context_from_lambda_context(lambda_context)
265-
266-
267-
def extract_dd_context_from_sqs_or_sns_event(event):
268-
first_record = event.get("Records")[0]
269-
arn = first_record.get("eventSourceARN", "")
270-
is_sqs = bool(arn)
271-
272-
# logic to deal with SNS => SQS event
273-
if "body" in first_record:
274-
body_str = first_record.get("body")
275-
try:
276-
body = json.loads(body_str)
277-
if body.get("Type", "") == "Notification" and "TopicArn" in body:
278-
logger.debug("Found SNS message inside SQS event")
279-
first_record = get_first_record(create_sns_event(body))
280-
except Exception:
281-
pass
282-
283-
msg_attributes = first_record.get("messageAttributes")
284-
if msg_attributes is None:
285-
sns_record = first_record.get("Sns") or {}
286-
if not is_sqs:
287-
arn = sns_record.get("TopicArn", "")
288-
msg_attributes = sns_record.get("MessageAttributes") or {}
289-
dd_payload = msg_attributes.get("_datadog")
290-
if dd_payload:
291-
# SQS uses dataType and binaryValue/stringValue
292-
# SNS uses Type and Value
293-
dd_json_data = None
294-
dd_json_data_type = dd_payload.get("Type") or dd_payload.get("dataType")
295-
if dd_json_data_type == "Binary":
296-
import base64
297-
298-
dd_json_data = dd_payload.get("binaryValue") or dd_payload.get("Value")
299-
if dd_json_data:
300-
dd_json_data = base64.b64decode(dd_json_data)
301-
elif dd_json_data_type == "String":
302-
dd_json_data = dd_payload.get("stringValue") or dd_payload.get("Value")
303-
else:
304-
logger.debug(
305-
"Datadog Lambda Python only supports extracting trace"
306-
"context from String or Binary SQS/SNS message attributes"
307-
)
308-
return dd_json_data, arn
309-
return None, None
312+
return extract_context_from_lambda_context(lambda_context), None, None
310313

311314

312315
def _extract_context_from_eventbridge_sqs_event(event):
@@ -366,32 +369,26 @@ def extract_context_from_kinesis_event(event, lambda_context):
366369
Extract datadog trace context from a Kinesis Stream's base64 encoded data string
367370
"""
368371
try:
369-
dd_ctx, _ = extract_dd_context_from_kinesis_event(event, lambda_context)
372+
record = get_first_record(event)
373+
arn = record.get("eventSourceARN", "")
374+
kinesis = record.get("kinesis")
375+
if not kinesis:
376+
return extract_context_from_lambda_context(lambda_context)
377+
data = kinesis.get("data")
378+
if data:
379+
import base64
380+
381+
b64_bytes = data.encode("ascii")
382+
str_bytes = base64.b64decode(b64_bytes)
383+
data_str = str_bytes.decode("ascii")
384+
data_obj = json.loads(data_str)
385+
dd_ctx = data_obj.get("_datadog")
370386
if dd_ctx:
371-
return propagator.extract(dd_ctx)
387+
return propagator.extract(dd_ctx), dd_ctx, arn
372388
except Exception as e:
373389
logger.debug("The trace extractor returned with error %s", e)
374390

375-
return extract_context_from_lambda_context(lambda_context)
376-
377-
378-
def extract_dd_context_from_kinesis_event(event, lambda_context):
379-
record = get_first_record(event)
380-
arn = record.get("eventSourceARN", "")
381-
kinesis = record.get("kinesis")
382-
if not kinesis:
383-
return extract_context_from_lambda_context(lambda_context)
384-
data = kinesis.get("data")
385-
if data:
386-
import base64
387-
388-
b64_bytes = data.encode("ascii")
389-
str_bytes = base64.b64decode(b64_bytes)
390-
data_str = str_bytes.decode("ascii")
391-
data_obj = json.loads(data_str)
392-
dd_ctx = data_obj.get("_datadog")
393-
return dd_ctx, arn
394-
return None, None
391+
return extract_context_from_lambda_context(lambda_context), None, None
395392

396393

397394
def _deterministic_sha256_hash(s: str, part: str) -> int:
@@ -599,6 +596,9 @@ def extract_dd_trace_context(
599596
global dd_trace_context
600597
trace_context_source = None
601598
event_source = parse_event_source(event)
599+
context = None
600+
dd_json_data = None
601+
arn = None
602602

603603
if extractor is not None:
604604
context = extract_context_custom_extractor(extractor, event, lambda_context)
@@ -607,13 +607,15 @@ def extract_dd_trace_context(
607607
event, lambda_context, event_source, decode_authorizer_context
608608
)
609609
elif event_source.equals(EventTypes.SNS) or event_source.equals(EventTypes.SQS):
610-
context = extract_context_from_sqs_or_sns_event_or_context(
610+
context, dd_json_data, arn = extract_context_from_sqs_or_sns_event_or_context(
611611
event, lambda_context
612612
)
613613
elif event_source.equals(EventTypes.EVENTBRIDGE):
614614
context = extract_context_from_eventbridge_event(event, lambda_context)
615615
elif event_source.equals(EventTypes.KINESIS):
616-
context = extract_context_from_kinesis_event(event, lambda_context)
616+
context, dd_json_data, arn = extract_context_from_kinesis_event(
617+
event, lambda_context
618+
)
617619
elif event_source.equals(EventTypes.STEPFUNCTIONS):
618620
context = extract_context_from_step_functions(event, lambda_context)
619621
else:
@@ -630,7 +632,7 @@ def extract_dd_trace_context(
630632
if dd_trace_context:
631633
trace_context_source = TraceContextSource.XRAY
632634
logger.debug("extracted dd trace context %s", dd_trace_context)
633-
return dd_trace_context, trace_context_source, event_source
635+
return dd_trace_context, trace_context_source, event_source, dd_json_data, arn
634636

635637

636638
def get_dd_trace_context_obj():

datadog_lambda/wrapper.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,13 @@ def _before(self, event, context):
215215

216216
self.trigger_tags = extract_trigger_tags(event, context)
217217
# Extract Datadog trace context and source from incoming requests
218-
dd_context, trace_context_source, event_source = extract_dd_trace_context(
218+
(
219+
dd_context,
220+
trace_context_source,
221+
event_source,
222+
dd_json_data,
223+
arn,
224+
) = extract_dd_trace_context(
219225
event,
220226
context,
221227
extractor=self.trace_extractor,
@@ -240,31 +246,6 @@ def _before(self, event, context):
240246
event, context, event_source, config.decode_authorizer_context
241247
)
242248
if config.data_streams_enabled:
243-
from datadog_lambda.trigger import EventTypes
244-
245-
from datadog_lambda.tracing import (
246-
extract_dd_context_from_sqs_or_sns_event,
247-
extract_dd_context_from_kinesis_event,
248-
)
249-
250-
dd_json_data = None
251-
arn = None
252-
253-
try:
254-
if event_source.equals(EventTypes.SQS) or event_source.equals(
255-
EventTypes.SNS
256-
):
257-
(
258-
dd_json_data,
259-
arn,
260-
) = extract_dd_context_from_sqs_or_sns_event(event)
261-
elif event_source.equals(EventTypes.KINESIS):
262-
dd_json_data, arn = extract_dd_context_from_kinesis_event(
263-
event, context
264-
)
265-
except Exception as e:
266-
logger.debug(f"Failed to extract DSM checkpoint: {e}")
267-
268249
if dd_json_data:
269250
set_dsm_checkpoint(dd_json_data, event_source.to_string(), arn)
270251

@@ -377,8 +358,13 @@ def set_dsm_checkpoint(dd_json_data, event_source, arn):
377358

378359
try:
379360
if type(dd_json_data) is not dict:
380-
dd_json_data = json.loads(dd_json_data)
361+
logger.debug("Failed to set DSM checkpoint: context is not a dict")
362+
return
363+
381364
if "dd-pathway-ctx-base64" not in dd_json_data:
365+
logger.debug(
366+
"Failed to set DSM checkpoint: dd-pathway-ctx-base64 not found"
367+
)
382368
return
383369

384370
def create_carrier_get(dd_json_data):
@@ -388,7 +374,7 @@ def carrier_get(key):
388374
return carrier_get
389375

390376
carrier_get = create_carrier_get(dd_json_data)
391-
set_consume_checkpoint(event_source, arn, carrier_get, manual_checkpoint=False)
377+
set_consume_checkpoint(event_source, arn, carrier_get)
392378
except Exception as e:
393379
logger.debug(f"Failed to set DSM checkpoint: {e}")
394380

tests/test_tracing.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def test_extract_dd_trace_context(event, expect):
242242
event = json.load(f)
243243
ctx = get_mock_context()
244244

245-
actual, _, _ = extract_dd_trace_context(event, ctx)
245+
actual, _, _, _, _ = extract_dd_trace_context(event, ctx)
246246
assert (expect is None) is (actual is None)
247247
assert (expect is None) or actual.trace_id == expect.trace_id
248248
assert (expect is None) or actual.span_id == expect.span_id
@@ -266,7 +266,7 @@ def tearDown(self):
266266
@with_trace_propagation_style("datadog")
267267
def test_without_datadog_trace_headers(self):
268268
lambda_ctx = get_mock_context()
269-
ctx, source, event_source = extract_dd_trace_context({}, lambda_ctx)
269+
ctx, source, _, _, _ = extract_dd_trace_context({}, lambda_ctx)
270270
self.assertEqual(source, "xray")
271271
self.assertEqual(
272272
ctx,
@@ -289,7 +289,7 @@ def test_without_datadog_trace_headers(self):
289289
@with_trace_propagation_style("datadog")
290290
def test_with_non_object_event(self):
291291
lambda_ctx = get_mock_context()
292-
ctx, source, event_source = extract_dd_trace_context(b"", lambda_ctx)
292+
ctx, source, _, _, _ = extract_dd_trace_context(b"", lambda_ctx)
293293
self.assertEqual(source, "xray")
294294
self.assertEqual(
295295
ctx,
@@ -312,7 +312,7 @@ def test_with_non_object_event(self):
312312
@with_trace_propagation_style("datadog")
313313
def test_with_incomplete_datadog_trace_headers(self):
314314
lambda_ctx = get_mock_context()
315-
ctx, source, event_source = extract_dd_trace_context(
315+
ctx, source, _, _, _ = extract_dd_trace_context(
316316
{"headers": {TraceHeader.TRACE_ID: "123"}},
317317
lambda_ctx,
318318
)
@@ -337,7 +337,7 @@ def test_with_incomplete_datadog_trace_headers(self):
337337
def common_tests_with_trace_context_extraction_injection(
338338
self, headers, event_containing_headers, lambda_context=get_mock_context()
339339
):
340-
ctx, source, event_source = extract_dd_trace_context(
340+
ctx, source, _, _, _ = extract_dd_trace_context(
341341
event_containing_headers,
342342
lambda_context,
343343
)
@@ -390,7 +390,7 @@ def extractor_foo(event, context):
390390
return trace_id, parent_id, sampling_priority
391391

392392
lambda_ctx = get_mock_context()
393-
ctx, ctx_source, event_source = extract_dd_trace_context(
393+
ctx, ctx_source, _, _, _ = extract_dd_trace_context(
394394
{
395395
"foo": {
396396
TraceHeader.TRACE_ID: "123",
@@ -425,7 +425,7 @@ def extractor_raiser(event, context):
425425
raise Exception("kreator")
426426

427427
lambda_ctx = get_mock_context()
428-
ctx, ctx_source, event_source = extract_dd_trace_context(
428+
ctx, ctx_source, _, _, _ = extract_dd_trace_context(
429429
{
430430
"foo": {
431431
TraceHeader.TRACE_ID: "123",
@@ -624,7 +624,7 @@ def _test_step_function_trace_data_common(
624624
TraceHeader.TAGS: f"_dd.p.tid={expected_tid}",
625625
}
626626

627-
ctx, source, _ = extract_dd_trace_context(event, lambda_ctx)
627+
ctx, source, _, _, _ = extract_dd_trace_context(event, lambda_ctx)
628628

629629
self.assertEqual(source, "event")
630630
self.assertEqual(ctx, expected_context)
@@ -1145,7 +1145,7 @@ def test_mixed_parent_context_when_merging(self):
11451145
# use the dd-trace trace-id and the x-ray parent-id
11461146
# This allows parenting relationships like dd-trace -> x-ray -> dd-trace
11471147
lambda_ctx = get_mock_context()
1148-
ctx, source, event_type = extract_dd_trace_context(
1148+
ctx, source, _, _, _ = extract_dd_trace_context(
11491149
{
11501150
"headers": {
11511151
TraceHeader.TRACE_ID: "123",
@@ -1171,7 +1171,7 @@ def test_set_dd_trace_py_root_no_span_id(self):
11711171
os.environ["_X_AMZN_TRACE_ID"] = "Root=1-5e272390-8c398be037738dc042009320"
11721172

11731173
lambda_ctx = get_mock_context()
1174-
ctx, source, event_type = extract_dd_trace_context(
1174+
ctx, source, _, _, _ = extract_dd_trace_context(
11751175
{
11761176
"headers": {
11771177
TraceHeader.TRACE_ID: "123",

0 commit comments

Comments
 (0)