@@ -237,7 +237,6 @@ def extract_context_from_sqs_or_sns_event_or_context(
237237 Falls back to lambda context if no trace data is found in the SQS message attributes.
238238 Set a DSM checkpoint if DSM is enabled and the method for context propagation is supported.
239239 """
240- source_arn = ""
241240 event_type = "sqs" if event_source .equals (EventTypes .SQS ) else "sns"
242241
243242 # EventBridge => SQS
@@ -248,100 +247,107 @@ def extract_context_from_sqs_or_sns_event_or_context(
248247 except Exception :
249248 logger .debug ("Failed extracting context as EventBridge to SQS." )
250249
251- context = None
250+ apm_context = None
252251 for idx , record in enumerate (event .get ("Records" , [])):
252+ source_arn = (
253+ record .get ("eventSourceARN" )
254+ if event_type == "sqs"
255+ else record .get ("Sns" , {}).get ("TopicArn" )
256+ )
257+ dd_data = None
253258 try :
254- source_arn = record .get ("eventSourceARN" , "" )
255- dd_ctx = None
256-
257- # logic to deal with SNS => SQS event
258- if "body" in record :
259- body_str = record .get ("body" )
260- try :
261- body = json .loads (body_str )
262- if body .get ("Type" , "" ) == "Notification" and "TopicArn" in body :
263- logger .debug ("Found SNS message inside SQS event" )
264- record = get_first_record (create_sns_event (body ))
265- except Exception :
266- pass
267-
268- msg_attributes = record .get ("messageAttributes" )
269- if msg_attributes is None :
270- sns_record = record .get ("Sns" ) or {}
271- # SNS->SQS event would extract SNS arn without this check
272- if event_source .equals (EventTypes .SNS ):
273- source_arn = sns_record .get ("TopicArn" , "" )
274- msg_attributes = sns_record .get ("MessageAttributes" ) or {}
275- dd_payload = msg_attributes .get ("_datadog" )
276- if dd_payload :
277- # SQS uses dataType and binaryValue/stringValue
278- # SNS uses Type and Value
279- # fmt: off
280- dd_json_data = None
281- dd_json_data_type = dd_payload .get ("Type" ) or dd_payload .get ("dataType" )
282- if dd_json_data_type == "Binary" :
283- import base64
284-
285- dd_json_data = dd_payload .get ("binaryValue" ) or dd_payload .get ("Value" )
286- if dd_json_data :
287- dd_json_data = base64 .b64decode (dd_json_data )
288- elif dd_json_data_type == "String" :
289- dd_json_data = dd_payload .get ("stringValue" ) or dd_payload .get ("Value" )
290- # fmt: on
259+ dd_data = _extract_context_from_sqs_or_sns_record (record )
260+ if idx == 0 :
261+ if dd_data and is_step_function_event (dd_data ):
262+ try :
263+ return extract_context_from_step_functions (dd_data , None )
264+ except Exception :
265+ logger .debug (
266+ "Failed to extract Step Functions context from SQS/SNS event."
267+ )
268+ elif not dd_data :
269+ apm_context = _extract_context_from_xray (record )
291270 else :
292- logger .debug (
293- "Datadog Lambda Python only supports extracting trace"
294- "context from String or Binary SQS/SNS message attributes"
295- )
296-
297- if dd_json_data :
298- dd_data = json .loads (dd_json_data )
299-
300- if idx == 0 :
301- if is_step_function_event (dd_data ):
302- try :
303- return extract_context_from_step_functions (
304- dd_data , None
305- )
306- except Exception :
307- logger .debug (
308- "Failed to extract Step Functions context from SQS/SNS event."
309- )
310- context = propagator .extract (dd_data )
311- if not config .data_streams_enabled :
312- break
313- dd_ctx = dd_data
314- elif idx == 0 :
315- # Handle case where trace context is injected into attributes.AWSTraceHeader
316- # example:Root=1-654321ab-000000001234567890abcdef;Parent=0123456789abcdef;Sampled=1
317- attrs = event .get ("Records" )[0 ].get ("attributes" )
318- if attrs :
319- x_ray_header = attrs .get ("AWSTraceHeader" )
320- if x_ray_header :
321- x_ray_context = parse_xray_header (x_ray_header )
322- trace_id_parts = x_ray_context .get ("trace_id" , "" ).split ("-" )
323- if len (trace_id_parts ) > 2 and trace_id_parts [2 ].startswith (
324- DD_TRACE_JAVA_TRACE_ID_PADDING
325- ):
326- # If it starts with eight 0's padding,
327- # then this AWSTraceHeader contains Datadog injected trace context
328- logger .debug (
329- "Found dd-trace injected trace context from AWSTraceHeader"
330- )
331- context = Context (
332- trace_id = int (trace_id_parts [2 ][8 :], 16 ),
333- span_id = int (x_ray_context ["parent_id" ], 16 ),
334- sampling_priority = float (x_ray_context ["sampled" ]),
335- )
336- if not config .data_streams_enabled :
337- break
271+ apm_context = propagator .extract (dd_data )
338272 except Exception as e :
339273 logger .debug ("The trace extractor returned with error %s" , e )
274+ if config .data_streams_enabled :
275+ _dsm_set_checkpoint (dd_data , event_type , source_arn )
276+ if not config .data_streams_enabled :
277+ break
278+
279+ return (
280+ apm_context
281+ if apm_context
282+ else extract_context_from_lambda_context (lambda_context )
283+ )
340284
341- # Set DSM checkpoint once per record
342- _dsm_set_checkpoint (dd_ctx , event_type , source_arn )
343285
344- return context if context else extract_context_from_lambda_context (lambda_context )
286+ def _extract_context_from_sqs_or_sns_record (record ):
287+ # logic to deal with SNS => SQS event
288+ if "body" in record :
289+ body_str = record .get ("body" )
290+ try :
291+ body = json .loads (body_str )
292+ if body .get ("Type" , "" ) == "Notification" and "TopicArn" in body :
293+ logger .debug ("Found SNS message inside SQS event" )
294+ record = get_first_record (create_sns_event (body ))
295+ except Exception :
296+ pass
297+
298+ msg_attributes = record .get ("messageAttributes" )
299+ if msg_attributes is None :
300+ sns_record = record .get ("Sns" ) or {}
301+ msg_attributes = sns_record .get ("MessageAttributes" ) or {}
302+ dd_payload = msg_attributes .get ("_datadog" )
303+ if dd_payload :
304+ # SQS uses dataType and binaryValue/stringValue
305+ # SNS uses Type and Value
306+ # fmt: off
307+ dd_json_data = None
308+ dd_json_data_type = dd_payload .get ("Type" ) or dd_payload .get ("dataType" )
309+ if dd_json_data_type == "Binary" :
310+ import base64
311+
312+ dd_json_data = dd_payload .get ("binaryValue" ) or dd_payload .get ("Value" )
313+ if dd_json_data :
314+ dd_json_data = base64 .b64decode (dd_json_data )
315+ elif dd_json_data_type == "String" :
316+ dd_json_data = dd_payload .get ("stringValue" ) or dd_payload .get ("Value" )
317+ # fmt: on
318+ else :
319+ logger .debug (
320+ "Datadog Lambda Python only supports extracting trace"
321+ "context from String or Binary SQS/SNS message attributes"
322+ )
323+
324+ if dd_json_data :
325+ dd_data = json .loads (dd_json_data )
326+ return dd_data
327+ return None
328+
329+
330+ def _extract_context_from_xray (record ):
331+ attrs = record .get ("attributes" )
332+ if attrs :
333+ x_ray_header = attrs .get ("AWSTraceHeader" )
334+ if x_ray_header :
335+ x_ray_context = parse_xray_header (x_ray_header )
336+ trace_id_parts = x_ray_context .get ("trace_id" , "" ).split ("-" )
337+ if len (trace_id_parts ) > 2 and trace_id_parts [2 ].startswith (
338+ DD_TRACE_JAVA_TRACE_ID_PADDING
339+ ):
340+ # If it starts with eight 0's padding,
341+ # then this AWSTraceHeader contains Datadog injected trace context
342+ logger .debug (
343+ "Found dd-trace injected trace context from AWSTraceHeader"
344+ )
345+ return Context (
346+ trace_id = int (trace_id_parts [2 ][8 :], 16 ),
347+ span_id = int (x_ray_context ["parent_id" ], 16 ),
348+ sampling_priority = float (x_ray_context ["sampled" ]),
349+ )
350+ return None
345351
346352
347353def _extract_context_from_eventbridge_sqs_event (event ):
0 commit comments