|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import unittest
|
16 |
| -from unittest.mock import Mock, patch |
| 16 | +from unittest.mock import Mock |
17 | 17 |
|
18 | 18 | import opentelemetry.propagators.b3 as b3_format # pylint: disable=no-name-in-module,import-error
|
19 | 19 | import opentelemetry.sdk.trace as trace
|
@@ -231,89 +231,73 @@ def test_64bit_trace_id(self):
|
231 | 231 | new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit
|
232 | 232 | )
|
233 | 233 |
|
234 |
| - def test_invalid_single_header(self): |
235 |
| - """If an invalid single header is passed, return an |
236 |
| - invalid SpanContext. |
237 |
| - """ |
| 234 | + def test_extract_invalid_single_header(self): |
| 235 | + """Given unparsable header, do not modify context""" |
| 236 | + old_ctx = {} |
| 237 | + |
238 | 238 | carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"}
|
239 |
| - ctx = FORMAT.extract(carrier) |
240 |
| - span_context = trace_api.get_current_span(ctx).get_span_context() |
241 |
| - self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) |
242 |
| - self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) |
| 239 | + new_ctx = FORMAT.extract(carrier, old_ctx) |
| 240 | + |
| 241 | + self.assertDictEqual(new_ctx, old_ctx) |
| 242 | + |
| 243 | + def test_extract_missing_trace_id(self): |
| 244 | + """Given no trace ID, do not modify context""" |
| 245 | + old_ctx = {} |
243 | 246 |
|
244 |
| - def test_missing_trace_id(self): |
245 |
| - """If a trace id is missing, populate an invalid trace id.""" |
246 | 247 | carrier = {
|
247 | 248 | FORMAT.SPAN_ID_KEY: self.serialized_span_id,
|
248 | 249 | FORMAT.FLAGS_KEY: "1",
|
249 | 250 | }
|
| 251 | + new_ctx = FORMAT.extract(carrier, old_ctx) |
250 | 252 |
|
251 |
| - ctx = FORMAT.extract(carrier) |
252 |
| - span_context = trace_api.get_current_span(ctx).get_span_context() |
253 |
| - self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) |
254 |
| - |
255 |
| - @patch( |
256 |
| - "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_trace_id" |
257 |
| - ) |
258 |
| - @patch( |
259 |
| - "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_span_id" |
260 |
| - ) |
261 |
| - def test_invalid_trace_id( |
262 |
| - self, mock_generate_span_id, mock_generate_trace_id |
263 |
| - ): |
264 |
| - """If a trace id is invalid, generate a trace id.""" |
| 253 | + self.assertDictEqual(new_ctx, old_ctx) |
265 | 254 |
|
266 |
| - mock_generate_trace_id.configure_mock(return_value=1) |
267 |
| - mock_generate_span_id.configure_mock(return_value=2) |
| 255 | + def test_extract_invalid_trace_id(self): |
| 256 | + """Given invalid trace ID, do not modify context""" |
| 257 | + old_ctx = {} |
268 | 258 |
|
269 | 259 | carrier = {
|
270 | 260 | FORMAT.TRACE_ID_KEY: "abc123",
|
271 | 261 | FORMAT.SPAN_ID_KEY: self.serialized_span_id,
|
272 | 262 | FORMAT.FLAGS_KEY: "1",
|
273 | 263 | }
|
| 264 | + new_ctx = FORMAT.extract(carrier, old_ctx) |
274 | 265 |
|
275 |
| - ctx = FORMAT.extract(carrier) |
276 |
| - span_context = trace_api.get_current_span(ctx).get_span_context() |
| 266 | + self.assertDictEqual(new_ctx, old_ctx) |
277 | 267 |
|
278 |
| - self.assertEqual(span_context.trace_id, 1) |
279 |
| - self.assertEqual(span_context.span_id, 2) |
280 |
| - |
281 |
| - @patch( |
282 |
| - "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_trace_id" |
283 |
| - ) |
284 |
| - @patch( |
285 |
| - "opentelemetry.sdk.trace.id_generator.RandomIdGenerator.generate_span_id" |
286 |
| - ) |
287 |
| - def test_invalid_span_id( |
288 |
| - self, mock_generate_span_id, mock_generate_trace_id |
289 |
| - ): |
290 |
| - """If a span id is invalid, generate a trace id.""" |
291 |
| - |
292 |
| - mock_generate_trace_id.configure_mock(return_value=1) |
293 |
| - mock_generate_span_id.configure_mock(return_value=2) |
| 268 | + def test_extract_invalid_span_id(self): |
| 269 | + """Given invalid span ID, do not modify context""" |
| 270 | + old_ctx = {} |
294 | 271 |
|
295 | 272 | carrier = {
|
296 | 273 | FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
|
297 | 274 | FORMAT.SPAN_ID_KEY: "abc123",
|
298 | 275 | FORMAT.FLAGS_KEY: "1",
|
299 | 276 | }
|
| 277 | + new_ctx = FORMAT.extract(carrier, old_ctx) |
300 | 278 |
|
301 |
| - ctx = FORMAT.extract(carrier) |
302 |
| - span_context = trace_api.get_current_span(ctx).get_span_context() |
| 279 | + self.assertDictEqual(new_ctx, old_ctx) |
303 | 280 |
|
304 |
| - self.assertEqual(span_context.trace_id, 1) |
305 |
| - self.assertEqual(span_context.span_id, 2) |
| 281 | + def test_extract_missing_span_id(self): |
| 282 | + """Given no span ID, do not modify context""" |
| 283 | + old_ctx = {} |
306 | 284 |
|
307 |
| - def test_missing_span_id(self): |
308 |
| - """If a trace id is missing, populate an invalid trace id.""" |
309 | 285 | carrier = {
|
310 | 286 | FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
|
311 | 287 | FORMAT.FLAGS_KEY: "1",
|
312 | 288 | }
|
| 289 | + new_ctx = FORMAT.extract(carrier, old_ctx) |
| 290 | + |
| 291 | + self.assertDictEqual(new_ctx, old_ctx) |
| 292 | + |
| 293 | + def test_extract_empty_carrier(self): |
| 294 | + """Given no headers at all, do not modify context""" |
| 295 | + old_ctx = {} |
| 296 | + |
| 297 | + carrier = {} |
| 298 | + new_ctx = FORMAT.extract(carrier, old_ctx) |
313 | 299 |
|
314 |
| - ctx = FORMAT.extract(carrier) |
315 |
| - span_context = trace_api.get_current_span(ctx).get_span_context() |
316 |
| - self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) |
| 300 | + self.assertDictEqual(new_ctx, old_ctx) |
317 | 301 |
|
318 | 302 | @staticmethod
|
319 | 303 | def test_inject_empty_context():
|
|
0 commit comments