Skip to content

Commit 2e09255

Browse files
add tests
1 parent 0867597 commit 2e09255

File tree

1 file changed

+253
-0
lines changed

1 file changed

+253
-0
lines changed

tests/test_dsm.py

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
import json
2+
import unittest
3+
import base64
4+
import os
5+
from unittest.mock import patch
6+
7+
from ddtrace.trace import Context
8+
9+
from datadog_lambda.tracing import (
10+
_extract_context,
11+
_create_carrier_get,
12+
extract_context_from_sqs_or_sns_event_or_context,
13+
extract_context_from_kinesis_event,
14+
)
15+
from tests.utils import get_mock_context
16+
17+
18+
class TestExtractContext(unittest.TestCase):
19+
def setUp(self):
20+
patcher = patch("datadog_lambda.tracing.propagator.extract")
21+
self.mock_extract = patcher.start()
22+
self.addCleanup(patcher.stop)
23+
24+
checkpoint_patcher = patch("ddtrace.data_streams.set_consume_checkpoint")
25+
self.mock_checkpoint = checkpoint_patcher.start()
26+
self.addCleanup(checkpoint_patcher.stop)
27+
28+
logger_patcher = patch("datadog_lambda.tracing.logger")
29+
self.mock_logger = logger_patcher.start()
30+
self.addCleanup(logger_patcher.stop)
31+
32+
def test_extract_context_data_streams_disabled(self):
33+
with patch.dict(os.environ, {'DD_DATA_STREAMS_ENABLED': 'false'}):
34+
context_json = {"dd-pathway-ctx-base64": "12345"}
35+
event_type = "sqs"
36+
arn = "arn:aws:sqs:us-east-1:123456789012:test-queue"
37+
38+
mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1)
39+
self.mock_extract.return_value = mock_context
40+
41+
result = _extract_context(context_json, event_type, arn)
42+
43+
self.mock_extract.assert_called_once_with(context_json)
44+
self.mock_checkpoint.assert_not_called()
45+
self.assertEqual(result, mock_context)
46+
47+
def test_extract_context_data_streams_enabled_complete_context(self):
48+
with patch.dict(os.environ, {'DD_DATA_STREAMS_ENABLED': 'true'}):
49+
context_json = {"dd-pathway-ctx-base64": "12345"}
50+
event_type = "sqs"
51+
arn = "arn:aws:sqs:us-east-1:123456789012:test-queue"
52+
53+
mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1)
54+
self.mock_extract.return_value = mock_context
55+
56+
result = _extract_context(context_json, event_type, arn)
57+
58+
self.mock_extract.assert_called_once_with(context_json)
59+
self.mock_checkpoint.assert_called_once()
60+
args, kwargs = self.mock_checkpoint.call_args
61+
self.assertEqual(args[0], event_type)
62+
self.assertEqual(args[1], arn)
63+
self.assertTrue(callable(args[2]))
64+
self.assertEqual(kwargs["manual_checkpoint"], False)
65+
self.assertEqual(result, mock_context)
66+
67+
def test_extract_context_data_streams_enabled_incomplete_context(self):
68+
with patch.dict(os.environ, {'DD_DATA_STREAMS_ENABLED': 'true'}):
69+
context_json = {"dd-pathway-ctx-base64": "12345"}
70+
event_type = "sqs"
71+
arn = "arn:aws:sqs:us-east-1:123456789012:test-queue"
72+
73+
mock_context = Context(trace_id=12345, span_id=None, sampling_priority=1)
74+
self.mock_extract.return_value = mock_context
75+
76+
result = _extract_context(context_json, event_type, arn)
77+
78+
self.mock_extract.assert_called_once_with(context_json)
79+
self.mock_checkpoint.assert_not_called()
80+
self.assertEqual(result, mock_context)
81+
82+
def test_extract_context_exception_path(self):
83+
with patch.dict(os.environ, {'DD_DATA_STREAMS_ENABLED': 'true'}):
84+
context_json = {"dd-pathway-ctx-base64": "12345"}
85+
event_type = "sqs"
86+
arn = "arn:aws:sqs:us-east-1:123456789012:test-queue"
87+
88+
mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1)
89+
self.mock_extract.return_value = mock_context
90+
91+
test_exception = Exception("Test exception")
92+
self.mock_checkpoint.side_effect = test_exception
93+
94+
result = _extract_context(context_json, event_type, arn)
95+
96+
self.mock_extract.assert_called_once_with(context_json)
97+
self.mock_checkpoint.assert_called_once()
98+
self.mock_logger.debug.assert_called_once()
99+
self.assertEqual(result, mock_context)
100+
101+
102+
class TestCreateCarrierGet(unittest.TestCase):
103+
def test_create_carrier_get_with_valid_data(self):
104+
context_json = {
105+
"x-datadog-trace-id": "12345",
106+
"x-datadog-parent-id": "67890",
107+
"x-datadog-sampling-priority": "1"
108+
}
109+
110+
carrier_get = _create_carrier_get(context_json)
111+
112+
self.assertTrue(callable(carrier_get))
113+
self.assertEqual(carrier_get("x-datadog-trace-id"), "12345")
114+
self.assertEqual(carrier_get("x-datadog-parent-id"), "67890")
115+
self.assertEqual(carrier_get("x-datadog-sampling-priority"), "1")
116+
117+
def test_create_carrier_get_with_missing_key(self):
118+
context_json = {"x-datadog-trace-id": "12345"}
119+
120+
carrier_get = _create_carrier_get(context_json)
121+
122+
self.assertTrue(callable(carrier_get))
123+
self.assertEqual(carrier_get("x-datadog-trace-id"), "12345")
124+
self.assertIsNone(carrier_get("x-datadog-parent-id"))
125+
126+
def test_create_carrier_get_with_empty_context(self):
127+
context_json = {}
128+
129+
carrier_get = _create_carrier_get(context_json)
130+
131+
self.assertTrue(callable(carrier_get))
132+
self.assertIsNone(carrier_get("any-key"))
133+
134+
135+
class TestExtractContextFromSqsOrSnsEvent(unittest.TestCase):
136+
def setUp(self):
137+
self.lambda_context = get_mock_context()
138+
139+
@patch("datadog_lambda.tracing._extract_context")
140+
def test_sqs_event_with_datadog_message_attributes(self, mock_extract_context):
141+
dd_data = {"dd-pathway-ctx-base64": "12345"}
142+
dd_json_data = json.dumps(dd_data)
143+
144+
event = {
145+
"Records": [{
146+
"eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue",
147+
"messageAttributes": {
148+
"_datadog": {
149+
"dataType": "String",
150+
"stringValue": dd_json_data
151+
}
152+
}
153+
}]
154+
}
155+
156+
mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1)
157+
mock_extract_context.return_value = mock_context
158+
159+
result = extract_context_from_sqs_or_sns_event_or_context(event, self.lambda_context)
160+
161+
mock_extract_context.assert_called_once_with(
162+
dd_data, "sqs", "arn:aws:sqs:us-east-1:123456789012:test-queue"
163+
)
164+
self.assertEqual(result, mock_context)
165+
166+
@patch("datadog_lambda.tracing._extract_context")
167+
def test_sqs_event_with_binary_datadog_message_attributes(self, mock_extract_context):
168+
dd_data = {"dd-pathway-ctx-base64": "12345"}
169+
dd_json_data = json.dumps(dd_data)
170+
encoded_data = base64.b64encode(dd_json_data.encode()).decode()
171+
172+
event = {
173+
"Records": [{
174+
"eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:test-queue",
175+
"messageAttributes": {
176+
"_datadog": {
177+
"dataType": "Binary",
178+
"binaryValue": encoded_data
179+
}
180+
}
181+
}]
182+
}
183+
184+
mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1)
185+
mock_extract_context.return_value = mock_context
186+
187+
result = extract_context_from_sqs_or_sns_event_or_context(event, self.lambda_context)
188+
189+
mock_extract_context.assert_called_once_with(
190+
dd_data, "sqs", "arn:aws:sqs:us-east-1:123456789012:test-queue"
191+
)
192+
self.assertEqual(result, mock_context)
193+
194+
@patch("datadog_lambda.tracing._extract_context")
195+
def test_sns_event_with_datadog_message_attributes(self, mock_extract_context):
196+
dd_data = {"dd-pathway-ctx-base64": "12345"}
197+
dd_json_data = json.dumps(dd_data)
198+
199+
event = {
200+
"Records": [{
201+
"eventSourceARN": "",
202+
"Sns": {
203+
"TopicArn": "arn:aws:sns:us-east-1:123456789012:test-topic",
204+
"MessageAttributes": {
205+
"_datadog": {
206+
"Type": "String",
207+
"Value": dd_json_data
208+
}
209+
}
210+
}
211+
}]
212+
}
213+
214+
mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1)
215+
mock_extract_context.return_value = mock_context
216+
217+
result = extract_context_from_sqs_or_sns_event_or_context(event, self.lambda_context)
218+
219+
mock_extract_context.assert_called_once_with(
220+
dd_data, "sns", "arn:aws:sns:us-east-1:123456789012:test-topic"
221+
)
222+
self.assertEqual(result, mock_context)
223+
224+
225+
class TestExtractContextFromKinesisEvent(unittest.TestCase):
226+
def setUp(self):
227+
self.lambda_context = get_mock_context()
228+
229+
@patch("datadog_lambda.tracing._extract_context")
230+
def test_kinesis_event_with_datadog_data(self, mock_extract_context):
231+
dd_data = {"dd-pathway-ctx-base64": "12345"}
232+
kinesis_data = {"_datadog": dd_data, "message": "test"}
233+
kinesis_data_str = json.dumps(kinesis_data)
234+
encoded_data = base64.b64encode(kinesis_data_str.encode()).decode()
235+
236+
event = {
237+
"Records": [{
238+
"eventSourceARN": "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream",
239+
"kinesis": {
240+
"data": encoded_data
241+
}
242+
}]
243+
}
244+
245+
mock_context = Context(trace_id=12345, span_id=67890, sampling_priority=1)
246+
mock_extract_context.return_value = mock_context
247+
248+
result = extract_context_from_kinesis_event(event, self.lambda_context)
249+
250+
mock_extract_context.assert_called_once_with(
251+
dd_data, "kinesis", "arn:aws:kinesis:us-east-1:123456789012:stream/test-stream"
252+
)
253+
self.assertEqual(result, mock_context)

0 commit comments

Comments
 (0)