11import platform
2+ import time
23
34import pytest
45
6+ from aws_xray_sdk .core .sampling .sampling_rule import SamplingRule
7+ from aws_xray_sdk .core .sampling .rule_cache import RuleCache
8+ from aws_xray_sdk .core .sampling .sampler import DefaultSampler
59from aws_xray_sdk .version import VERSION
610from .util import get_new_stubbed_recorder
711
@@ -38,7 +42,6 @@ def test_default_runtime_context():
3842
3943
4044def test_subsegment_parenting ():
41-
4245 segment = xray_recorder .begin_segment ('name' )
4346 subsegment = xray_recorder .begin_subsegment ('name' )
4447 xray_recorder .end_subsegment ('name' )
@@ -97,7 +100,6 @@ def test_put_annotation_metadata():
97100
98101
99102def test_pass_through_with_missing_context ():
100-
101103 xray_recorder = get_new_stubbed_recorder ()
102104 xray_recorder .configure (sampling = False , context_missing = 'LOG_ERROR' )
103105 assert not xray_recorder .is_sampled ()
@@ -175,7 +177,6 @@ def test_in_segment_exception():
175177 assert segment .fault is True
176178 assert len (segment .cause ['exceptions' ]) == 1
177179
178-
179180 with pytest .raises (Exception ):
180181 with xray_recorder .in_segment ('name' ) as segment :
181182 with xray_recorder .in_subsegment ('name' ) as subsegment :
@@ -259,7 +260,6 @@ def test_disabled_get_context_entity():
259260 assert type (entity ) is DummySegment
260261
261262
262-
263263def test_max_stack_trace_zero ():
264264 xray_recorder .configure (max_trace_back = 1 )
265265 with pytest .raises (Exception ):
@@ -279,3 +279,41 @@ def test_max_stack_trace_zero():
279279
280280 assert len (segment_with_stack .cause ['exceptions' ][0 ].stack ) == 1
281281 assert len (segment_no_stack .cause ['exceptions' ][0 ].stack ) == 0
282+
283+
284+ # CustomSampler to mimic the DefaultSampler,
285+ # but without the rule and target polling logic.
286+ class CustomSampler (DefaultSampler ):
287+ def start (self ):
288+ pass
289+
290+ def should_trace (self , sampling_req = None ):
291+ rule_cache = RuleCache ()
292+ rule_cache .last_updated = int (time .time ())
293+ sampling_rule_a = SamplingRule (name = 'rule_a' ,
294+ priority = 2 ,
295+ rate = 0.5 ,
296+ reservoir_size = 1 ,
297+ service = 'app_a' )
298+ sampling_rule_b = SamplingRule (name = 'rule_b' ,
299+ priority = 2 ,
300+ rate = 0.5 ,
301+ reservoir_size = 1 ,
302+ service = 'app_b' )
303+ rule_cache .load_rules ([sampling_rule_a , sampling_rule_b ])
304+ now = int (time .time ())
305+ if sampling_req and not sampling_req .get ('service_type' , None ):
306+ sampling_req ['service_type' ] = self ._origin
307+ elif sampling_req is None :
308+ sampling_req = {'service_type' : self ._origin }
309+ matched_rule = rule_cache .get_matched_rule (sampling_req , now )
310+ if matched_rule :
311+ return self ._process_matched_rule (matched_rule , now )
312+ else :
313+ return self ._local_sampler .should_trace (sampling_req )
314+
315+
316+ def test_begin_segment_matches_sampling_rule_on_name ():
317+ xray_recorder .configure (sampling = True , sampler = CustomSampler ())
318+ segment = xray_recorder .begin_segment ("app_b" )
319+ assert segment .aws .get ('xray' ).get ('sampling_rule_name' ) == 'rule_b'
0 commit comments