Skip to content

Commit 2b520c0

Browse files
committed
Ensure all XRay Sampler functionality is under ParentBased logic
1 parent 831b76b commit 2b520c0

File tree

6 files changed

+198
-78
lines changed

6 files changed

+198
-78
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule_applier.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTarget
1313
from opentelemetry.context import Context
1414
from opentelemetry.sdk.resources import Resource
15-
from opentelemetry.sdk.trace.sampling import Decision, ParentBased, Sampler, SamplingResult, TraceIdRatioBased
15+
from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult, TraceIdRatioBased
1616
from opentelemetry.semconv.resource import CloudPlatformValues, ResourceAttributes
1717
from opentelemetry.semconv.trace import SpanAttributes
1818
from opentelemetry.trace import Link, SpanKind
@@ -42,7 +42,7 @@ def __init__(
4242
self.__borrowing = False
4343

4444
if target is None:
45-
self.__fixed_rate_sampler = ParentBased(TraceIdRatioBased(self.sampling_rule.FixedRate))
45+
self.__fixed_rate_sampler = TraceIdRatioBased(self.sampling_rule.FixedRate)
4646
# Until targets are fetched, initialize as borrowing=True if there will be a quota > 0
4747
if self.sampling_rule.ReservoirSize > 0:
4848
self.__reservoir_sampler = self.__create_reservoir_sampler(quota=1)
@@ -55,7 +55,7 @@ def __init__(
5555
new_quota = target.ReservoirQuota if target.ReservoirQuota is not None else 0
5656
new_fixed_rate = target.FixedRate if target.FixedRate is not None else 0
5757
self.__reservoir_sampler = self.__create_reservoir_sampler(quota=new_quota)
58-
self.__fixed_rate_sampler = ParentBased(TraceIdRatioBased(new_fixed_rate))
58+
self.__fixed_rate_sampler = TraceIdRatioBased(new_fixed_rate)
5959
if target.ReservoirQuotaTTL is not None:
6060
self.__reservoir_expiry = self._clock.from_timestamp(target.ReservoirQuotaTTL)
6161
else:
@@ -159,7 +159,7 @@ def matches(self, resource: Resource, attributes: Attributes) -> bool:
159159
)
160160

161161
def __create_reservoir_sampler(self, quota: int) -> Sampler:
162-
return ParentBased(_RateLimitingSampler(quota, self._clock))
162+
return _RateLimitingSampler(quota, self._clock)
163163

164164
# pylint: disable=no-self-use
165165
def __get_service_type(self, resource: Resource) -> str:

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,48 @@
2424
DEFAULT_SAMPLING_PROXY_ENDPOINT = "http://127.0.0.1:2000"
2525

2626

27+
# Wrapper class to ensure that all XRay Sampler Functionality in _AwsXRayRemoteSampler
28+
# uses ParentBased logic to respect the parent span's sampling decision
2729
class AwsXRayRemoteSampler(Sampler):
30+
def __init__(
31+
self,
32+
resource: Resource,
33+
endpoint: str = None,
34+
polling_interval: int = None,
35+
log_level=None,
36+
):
37+
self._root = ParentBased(
38+
_AwsXRayRemoteSampler(
39+
resource=resource, endpoint=endpoint, polling_interval=polling_interval, log_level=log_level
40+
)
41+
)
42+
43+
# pylint: disable=no-self-use
44+
@override
45+
def should_sample(
46+
self,
47+
parent_context: Optional[Context],
48+
trace_id: int,
49+
name: str,
50+
kind: SpanKind = None,
51+
attributes: Attributes = None,
52+
links: Sequence[Link] = None,
53+
trace_state: TraceState = None,
54+
) -> SamplingResult:
55+
return self._root.should_sample(
56+
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
57+
)
58+
59+
# pylint: disable=no-self-use
60+
@override
61+
def get_description(self) -> str:
62+
return f"AwsXRayRemoteSampler{{root:{self._root.get_description()}}}"
63+
64+
65+
# _AwsXRayRemoteSampler contains all core XRay Sampler Functionality,
66+
# however it is NOT Parent-based (e.g. Sample logic runs for each span)
67+
# Not intended for external use, use Parent-based `AwsXRayRemoteSampler` instead.
68+
class _AwsXRayRemoteSampler(Sampler):
2869
"""
2970
Remote Sampler for OpenTelemetry that gets sampling configurations from AWS X-Ray
3071
@@ -58,7 +99,7 @@ def __init__(
5899
self.__client_id = self.__generate_client_id()
59100
self._clock = _Clock()
60101
self.__xray_client = _AwsXRaySamplingClient(endpoint, log_level=log_level)
61-
self.__fallback_sampler = ParentBased(_FallbackSampler(self._clock))
102+
self.__fallback_sampler = _FallbackSampler(self._clock)
62103

63104
self.__polling_interval = polling_interval
64105
self.__target_polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS
@@ -114,7 +155,7 @@ def should_sample(
114155
# pylint: disable=no-self-use
115156
@override
116157
def get_description(self) -> str:
117-
description = "AwsXRayRemoteSampler{remote sampling with AWS X-Ray}"
158+
description = "_AwsXRayRemoteSampler{remote sampling with AWS X-Ray}"
118159
return description
119160

120161
def __get_and_update_sampling_rules(self) -> None:

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py

Lines changed: 110 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111
from mock_clock import MockClock
1212

13-
from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler
13+
from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler, _AwsXRayRemoteSampler
1414
from opentelemetry.sdk.resources import Resource
15+
from opentelemetry.sdk.trace import Tracer, TracerProvider
1516
from opentelemetry.sdk.trace.sampling import Decision
1617

1718
TEST_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -55,27 +56,31 @@ def setUp(self):
5556
def tearDown(self):
5657
# Clean up timers
5758
if self.rs is not None:
58-
self.rs._rules_timer.cancel()
59-
self.rs._targets_timer.cancel()
59+
self.rs._root._root._rules_timer.cancel()
60+
self.rs._root._root._targets_timer.cancel()
6061

6162
def test_create_remote_sampler_with_empty_resource(self):
6263
self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty())
63-
self.assertIsNotNone(self.rs._rules_timer)
64-
self.assertEqual(self.rs._AwsXRayRemoteSampler__polling_interval, 300)
65-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__xray_client)
66-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__resource)
67-
self.assertTrue(len(self.rs._AwsXRayRemoteSampler__client_id), 24)
64+
self.assertIsNotNone(self.rs._root._root._rules_timer)
65+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 300)
66+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__xray_client)
67+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource)
68+
self.assertTrue(len(self.rs._root._root._AwsXRayRemoteSampler__client_id), 24)
6869

6970
def test_create_remote_sampler_with_populated_resource(self):
7071
self.rs = AwsXRayRemoteSampler(
7172
resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
7273
)
73-
self.assertIsNotNone(self.rs._rules_timer)
74-
self.assertEqual(self.rs._AwsXRayRemoteSampler__polling_interval, 300)
75-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__xray_client)
76-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__resource)
77-
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name")
78-
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform")
74+
self.assertIsNotNone(self.rs._root._root._rules_timer)
75+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 300)
76+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__xray_client)
77+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource)
78+
self.assertEqual(
79+
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name"
80+
)
81+
self.assertEqual(
82+
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform"
83+
)
7984

8085
def test_create_remote_sampler_with_all_fields_populated(self):
8186
self.rs = AwsXRayRemoteSampler(
@@ -84,35 +89,39 @@ def test_create_remote_sampler_with_all_fields_populated(self):
8489
polling_interval=120,
8590
log_level=DEBUG,
8691
)
87-
self.assertIsNotNone(self.rs._rules_timer)
88-
self.assertEqual(self.rs._AwsXRayRemoteSampler__polling_interval, 120)
89-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__xray_client)
90-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__resource)
92+
self.assertIsNotNone(self.rs._root._root._rules_timer)
93+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 120)
94+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__xray_client)
95+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource)
9196
self.assertEqual(
92-
self.rs._AwsXRayRemoteSampler__xray_client._AwsXRaySamplingClient__get_sampling_rules_endpoint,
97+
self.rs._root._root._AwsXRayRemoteSampler__xray_client._AwsXRaySamplingClient__get_sampling_rules_endpoint,
9398
"http://abc.com/GetSamplingRules",
9499
)
95-
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name")
96-
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform")
100+
self.assertEqual(
101+
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name"
102+
)
103+
self.assertEqual(
104+
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform"
105+
)
97106

98107
@patch("requests.Session.post", side_effect=mocked_requests_get)
99108
@patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", 2)
100109
def test_update_sampling_rules_and_targets_with_pollers_and_should_sample(self, mock_post=None):
101110
self.rs = AwsXRayRemoteSampler(
102111
resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
103112
)
104-
self.assertEqual(self.rs._AwsXRayRemoteSampler__target_polling_interval, 2)
113+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__target_polling_interval, 2)
105114

106115
time.sleep(1.0)
107116
self.assertEqual(
108-
self.rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0].sampling_rule.RuleName,
117+
self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0].sampling_rule.RuleName,
109118
"test",
110119
)
111120
self.assertEqual(self.rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision, Decision.DROP)
112121

113122
# wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s)
114123
time.sleep(2.0)
115-
self.assertEqual(self.rs._AwsXRayRemoteSampler__target_polling_interval, 1000)
124+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__target_polling_interval, 1000)
116125
self.assertEqual(
117126
self.rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision,
118127
Decision.RECORD_AND_SAMPLE,
@@ -162,9 +171,9 @@ def test_multithreading_with_large_reservoir_with_otel_sdk(self, mock_post=None)
162171
threads[idx].join()
163172
sum_sampled += sampled_array[idx]
164173

165-
test_rule_applier = self.rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0]
174+
test_rule_applier = self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0]
166175
self.assertEqual(
167-
test_rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
176+
test_rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
168177
100000,
169178
)
170179
self.assertEqual(sum_sampled, 100000)
@@ -183,19 +192,19 @@ def test_multithreading_with_some_reservoir_with_otel_sdk(self, mock_post=None):
183192
# which will eat up more than 1 second of reservoir. Using MockClock we can freeze time
184193
# and pretend all thread jobs start and end at the exact same time,
185194
# assume and test exactly 1 second of reservoir (100 quota) only
186-
mock_clock: MockClock = self.rs._clock
195+
mock_clock: MockClock = self.rs._root._root._clock
187196

188197
time.sleep(1.0)
189198
mock_clock.add_time(1.0)
190-
self.assertEqual(mock_clock.now(), self.rs._clock.now())
199+
self.assertEqual(mock_clock.now(), self.rs._root._root._clock.now())
191200
self.assertEqual(
192201
self.rs.should_sample(None, 0, "name", attributes=attributes).decision, Decision.RECORD_AND_SAMPLE
193202
)
194203

195204
# wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s)
196205
time.sleep(2.0)
197206
mock_clock.add_time(2.0)
198-
self.assertEqual(mock_clock.now(), self.rs._clock.now())
207+
self.assertEqual(mock_clock.now(), self.rs._root._root._clock.now())
199208

200209
number_of_spans = 100
201210
thread_count = 1000
@@ -219,9 +228,79 @@ def test_multithreading_with_some_reservoir_with_otel_sdk(self, mock_post=None):
219228
threads[idx].join()
220229
sum_sampled += sampled_array[idx]
221230

222-
default_rule_applier = self.rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
231+
default_rule_applier = self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
223232
self.assertEqual(
224-
default_rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
233+
default_rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
225234
100,
226235
)
227236
self.assertEqual(sum_sampled, 100)
237+
238+
def test_get_description(self) -> str:
239+
self.rs: AwsXRayRemoteSampler = AwsXRayRemoteSampler(resource=Resource.create({"service.name": "dummy_name"}))
240+
self.assertEqual(
241+
self.rs.get_description(),
242+
"AwsXRayRemoteSampler{root:ParentBased{root:_AwsXRayRemoteSampler{remote sampling with AWS X-Ray},remoteParentSampled:AlwaysOnSampler,remoteParentNotSampled:AlwaysOffSampler,localParentSampled:AlwaysOnSampler,localParentNotSampled:AlwaysOffSampler}}", # noqa: E501
243+
)
244+
245+
@patch("requests.Session.post", side_effect=mocked_requests_get)
246+
def test_parent_based_xray_sampler_creates_updates_statistics_once_for_one_parent_span_with_two_children(
247+
self, mock_post=None
248+
):
249+
self.rs: AwsXRayRemoteSampler = AwsXRayRemoteSampler(
250+
resource=Resource.create({"service.name": "use-default-sample-all-rule"})
251+
)
252+
time.sleep(1.0)
253+
254+
provider = TracerProvider(sampler=self.rs)
255+
tracer: Tracer = provider.get_tracer("test_tracer_1")
256+
257+
# child1 and child2 are child spans of root parent0
258+
# For AwsXRayRemoteSampler (ParentBased), expect only parent0 to update statistics
259+
with tracer.start_as_current_span("parent0") as _:
260+
with tracer.start_as_current_span("child1") as _:
261+
pass
262+
with tracer.start_as_current_span("child2") as _:
263+
pass
264+
default_rule_applier = self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
265+
self.assertEqual(
266+
default_rule_applier._SamplingRuleApplier__statistics.RequestCount,
267+
1,
268+
)
269+
self.assertEqual(
270+
default_rule_applier._SamplingRuleApplier__statistics.SampleCount,
271+
1,
272+
)
273+
274+
@patch("requests.Session.post", side_effect=mocked_requests_get)
275+
def test_non_parent_based_xray_sampler_creates_updates_statistics_thrice_for_one_parent_span_with_two_children(
276+
self, mock_post=None
277+
):
278+
non_parent_based_xray_sampler: _AwsXRayRemoteSampler = _AwsXRayRemoteSampler(
279+
resource=Resource.create({"service.name": "use-default-sample-all-rule"})
280+
)
281+
time.sleep(1.0)
282+
283+
provider = TracerProvider(sampler=non_parent_based_xray_sampler)
284+
tracer: Tracer = provider.get_tracer("test_tracer_2")
285+
286+
# child1 and child2 are child spans of root parent0
287+
# For _AwsXRayRemoteSampler (Non-ParentBased), expect all 3 spans to update statistics
288+
with tracer.start_as_current_span("parent0") as _:
289+
with tracer.start_as_current_span("child1") as _:
290+
pass
291+
with tracer.start_as_current_span("child2") as _:
292+
pass
293+
default_rule_applier = (
294+
non_parent_based_xray_sampler._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
295+
)
296+
self.assertEqual(
297+
default_rule_applier._SamplingRuleApplier__statistics.RequestCount,
298+
3,
299+
)
300+
self.assertEqual(
301+
default_rule_applier._SamplingRuleApplier__statistics.SampleCount,
302+
3,
303+
)
304+
305+
non_parent_based_xray_sampler._rules_timer.cancel()
306+
non_parent_based_xray_sampler._targets_timer.cancel()

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,15 @@ def test_update_sampling_targets(self):
136136
# quota should be 1 because of borrowing=true until targets are updated
137137
rule_applier_0 = rule_cache._RuleCache__rule_appliers[0]
138138
self.assertEqual(
139-
rule_applier_0._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 1
139+
rule_applier_0._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, 1
140140
)
141-
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._root._rate, sampling_rule_2.FixedRate)
141+
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._rate, sampling_rule_2.FixedRate)
142142

143143
rule_applier_1 = rule_cache._RuleCache__rule_appliers[1]
144144
self.assertEqual(
145-
rule_applier_1._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 1
145+
rule_applier_1._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, 1
146146
)
147-
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._root._rate, sampling_rule_1.FixedRate)
147+
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._rate, sampling_rule_1.FixedRate)
148148

149149
target_1 = {
150150
"FixedRate": 0.05,
@@ -179,17 +179,17 @@ def test_update_sampling_targets(self):
179179
# borrowing=false, use quota from targets
180180
rule_applier_0 = rule_cache._RuleCache__rule_appliers[0]
181181
self.assertEqual(
182-
rule_applier_0._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
182+
rule_applier_0._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
183183
target_2["ReservoirQuota"],
184184
)
185-
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._root._rate, target_2["FixedRate"])
185+
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._rate, target_2["FixedRate"])
186186

187187
rule_applier_1 = rule_cache._RuleCache__rule_appliers[1]
188188
self.assertEqual(
189-
rule_applier_1._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
189+
rule_applier_1._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
190190
target_1["ReservoirQuota"],
191191
)
192-
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._root._rate, target_1["FixedRate"])
192+
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._rate, target_1["FixedRate"])
193193

194194
# Test target response modified after Rule cache's last modified date
195195
target_response.LastRuleModification = mock_clock.now().timestamp() + 1

0 commit comments

Comments
 (0)