@@ -51,6 +51,8 @@ def get_child_parent_new_carrier(old_carrier):
51
51
52
52
53
53
class TestB3Format (unittest .TestCase ):
54
+ # pylint: disable=too-many-public-methods
55
+
54
56
@classmethod
55
57
def setUpClass (cls ):
56
58
generator = id_generator .RandomIdGenerator ()
@@ -215,6 +217,31 @@ def test_flags_and_sampling(self):
215
217
216
218
self .assertEqual (new_carrier [FORMAT .SAMPLED_KEY ], "1" )
217
219
220
+ def test_derived_ctx_is_returned_for_success (self ):
221
+ """Ensure returned context is derived from the given context."""
222
+ old_ctx = {"k1" : "v1" }
223
+ new_ctx = FORMAT .extract (
224
+ {
225
+ FORMAT .TRACE_ID_KEY : self .serialized_trace_id ,
226
+ FORMAT .SPAN_ID_KEY : self .serialized_span_id ,
227
+ FORMAT .FLAGS_KEY : "1" ,
228
+ },
229
+ old_ctx ,
230
+ )
231
+ self .assertIn ("current-span" , new_ctx )
232
+ for key , value in old_ctx .items ():
233
+ self .assertIn (key , new_ctx )
234
+ self .assertEqual (new_ctx [key ], value )
235
+
236
+ def test_derived_ctx_is_returned_for_failure (self ):
237
+ """Ensure returned context is derived from the given context."""
238
+ old_ctx = {"k2" : "v2" }
239
+ new_ctx = FORMAT .extract ({}, old_ctx )
240
+ self .assertNotIn ("current-span" , new_ctx )
241
+ for key , value in old_ctx .items ():
242
+ self .assertIn (key , new_ctx )
243
+ self .assertEqual (new_ctx [key ], value )
244
+
218
245
def test_64bit_trace_id (self ):
219
246
"""64 bit trace ids should be padded to 128 bit trace ids."""
220
247
trace_id_64_bit = self .serialized_trace_id [:16 ]
@@ -334,3 +361,12 @@ def test_fields(self):
334
361
inject_fields .add (call [1 ][1 ])
335
362
336
363
self .assertEqual (FORMAT .fields , inject_fields )
364
+
365
+ def test_extract_none_context (self ):
366
+ """Given no trace ID, do not modify context"""
367
+ old_ctx = None
368
+
369
+ carrier = {}
370
+ new_ctx = FORMAT .extract (carrier , old_ctx )
371
+ self .assertIsNotNone (new_ctx )
372
+ self .assertEqual (new_ctx ["current-span" ], trace_api .INVALID_SPAN )
0 commit comments