Skip to content

Commit 0a936d5

Browse files
jj22eepxaws
andauthored
[XRay Sampler] Fix - Ensure all XRay Sampler functionality is under ParentBased logic (#269)
### Issue #, if available: This PR fixes aws-observability/aws-otel-js-instrumentation#79, but in Python This fixes the bug where: - While the subcomponents of XRay Sampler uses ParentBased logic to short-circuit the Parent's Sampling Decision, the XRay Sampling Statistics recording logic is not skipped when a Parent's Sampling Decision is found. This causes XRay Sampler to produce Sampling Statistics based on number of Spans (regardless of Parent's Sampling Decision), while XRay Sampler should produce Sampling Statistics based on number of Traces (aka the number of root Spans that makes the Sampling Decision) ### Description of changes: 1. Wrap entire XRay Sampler Internal Logic under a single ParentBased Sampler a. This will reduce lock contention and not over-count sampling statistics. 3. Remove use of redundant ParentBased Sampler logic for internal subcomponents for XRay Sampler. ### Testing: 1. Added unit tests to verify ParentBased sampling decisions 2. Original Sampler functionality tested locally using [Sampler Test Bed](https://github.com/aws-observability/aws-otel-community/tree/master/centralized-sampling-tests) By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Ping Xiang <[email protected]>
1 parent abca88c commit 0a936d5

File tree

6 files changed

+198
-78
lines changed

6 files changed

+198
-78
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule_applier.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTarget
1313
from opentelemetry.context import Context
1414
from opentelemetry.sdk.resources import Resource
15-
from opentelemetry.sdk.trace.sampling import Decision, ParentBased, Sampler, SamplingResult, TraceIdRatioBased
15+
from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult, TraceIdRatioBased
1616
from opentelemetry.semconv.resource import CloudPlatformValues, ResourceAttributes
1717
from opentelemetry.semconv.trace import SpanAttributes
1818
from opentelemetry.trace import Link, SpanKind
@@ -42,7 +42,7 @@ def __init__(
4242
self.__borrowing = False
4343

4444
if target is None:
45-
self.__fixed_rate_sampler = ParentBased(TraceIdRatioBased(self.sampling_rule.FixedRate))
45+
self.__fixed_rate_sampler = TraceIdRatioBased(self.sampling_rule.FixedRate)
4646
# Until targets are fetched, initialize as borrowing=True if there will be a quota > 0
4747
if self.sampling_rule.ReservoirSize > 0:
4848
self.__reservoir_sampler = self.__create_reservoir_sampler(quota=1)
@@ -55,7 +55,7 @@ def __init__(
5555
new_quota = target.ReservoirQuota if target.ReservoirQuota is not None else 0
5656
new_fixed_rate = target.FixedRate if target.FixedRate is not None else 0
5757
self.__reservoir_sampler = self.__create_reservoir_sampler(quota=new_quota)
58-
self.__fixed_rate_sampler = ParentBased(TraceIdRatioBased(new_fixed_rate))
58+
self.__fixed_rate_sampler = TraceIdRatioBased(new_fixed_rate)
5959
if target.ReservoirQuotaTTL is not None:
6060
self.__reservoir_expiry = self._clock.from_timestamp(target.ReservoirQuotaTTL)
6161
else:
@@ -159,7 +159,7 @@ def matches(self, resource: Resource, attributes: Attributes) -> bool:
159159
)
160160

161161
def __create_reservoir_sampler(self, quota: int) -> Sampler:
162-
return ParentBased(_RateLimitingSampler(quota, self._clock))
162+
return _RateLimitingSampler(quota, self._clock)
163163

164164
# pylint: disable=no-self-use
165165
def __get_service_type(self, resource: Resource) -> str:

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,48 @@
2424
DEFAULT_SAMPLING_PROXY_ENDPOINT = "http://127.0.0.1:2000"
2525

2626

27+
# Wrapper class to ensure that all XRay Sampler Functionality in _AwsXRayRemoteSampler
28+
# uses ParentBased logic to respect the parent span's sampling decision
2729
class AwsXRayRemoteSampler(Sampler):
30+
def __init__(
31+
self,
32+
resource: Resource,
33+
endpoint: str = None,
34+
polling_interval: int = None,
35+
log_level=None,
36+
):
37+
self._root = ParentBased(
38+
_AwsXRayRemoteSampler(
39+
resource=resource, endpoint=endpoint, polling_interval=polling_interval, log_level=log_level
40+
)
41+
)
42+
43+
# pylint: disable=no-self-use
44+
@override
45+
def should_sample(
46+
self,
47+
parent_context: Optional[Context],
48+
trace_id: int,
49+
name: str,
50+
kind: SpanKind = None,
51+
attributes: Attributes = None,
52+
links: Sequence[Link] = None,
53+
trace_state: TraceState = None,
54+
) -> SamplingResult:
55+
return self._root.should_sample(
56+
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
57+
)
58+
59+
# pylint: disable=no-self-use
60+
@override
61+
def get_description(self) -> str:
62+
return f"AwsXRayRemoteSampler{{root:{self._root.get_description()}}}"
63+
64+
65+
# _AwsXRayRemoteSampler contains all core XRay Sampler Functionality,
66+
# however it is NOT Parent-based (e.g. Sample logic runs for each span)
67+
# Not intended for external use, use Parent-based `AwsXRayRemoteSampler` instead.
68+
class _AwsXRayRemoteSampler(Sampler):
2869
"""
2970
Remote Sampler for OpenTelemetry that gets sampling configurations from AWS X-Ray
3071
@@ -58,7 +99,7 @@ def __init__(
5899
self.__client_id = self.__generate_client_id()
59100
self._clock = _Clock()
60101
self.__xray_client = _AwsXRaySamplingClient(endpoint, log_level=log_level)
61-
self.__fallback_sampler = ParentBased(_FallbackSampler(self._clock))
102+
self.__fallback_sampler = _FallbackSampler(self._clock)
62103

63104
self.__polling_interval = polling_interval
64105
self.__target_polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS
@@ -114,7 +155,7 @@ def should_sample(
114155
# pylint: disable=no-self-use
115156
@override
116157
def get_description(self) -> str:
117-
description = "AwsXRayRemoteSampler{remote sampling with AWS X-Ray}"
158+
description = "_AwsXRayRemoteSampler{remote sampling with AWS X-Ray}"
118159
return description
119160

120161
def __get_and_update_sampling_rules(self) -> None:

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py

Lines changed: 110 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111
from mock_clock import MockClock
1212

13-
from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler
13+
from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler, _AwsXRayRemoteSampler
1414
from opentelemetry.sdk.resources import Resource
15+
from opentelemetry.sdk.trace import Tracer, TracerProvider
1516
from opentelemetry.sdk.trace.sampling import Decision
1617

1718
TEST_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -55,27 +56,31 @@ def setUp(self):
5556
def tearDown(self):
5657
# Clean up timers
5758
if self.rs is not None:
58-
self.rs._rules_timer.cancel()
59-
self.rs._targets_timer.cancel()
59+
self.rs._root._root._rules_timer.cancel()
60+
self.rs._root._root._targets_timer.cancel()
6061

6162
def test_create_remote_sampler_with_empty_resource(self):
6263
self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty())
63-
self.assertIsNotNone(self.rs._rules_timer)
64-
self.assertEqual(self.rs._AwsXRayRemoteSampler__polling_interval, 300)
65-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__xray_client)
66-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__resource)
67-
self.assertTrue(len(self.rs._AwsXRayRemoteSampler__client_id), 24)
64+
self.assertIsNotNone(self.rs._root._root._rules_timer)
65+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 300)
66+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__xray_client)
67+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource)
68+
self.assertTrue(len(self.rs._root._root._AwsXRayRemoteSampler__client_id), 24)
6869

6970
def test_create_remote_sampler_with_populated_resource(self):
7071
self.rs = AwsXRayRemoteSampler(
7172
resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
7273
)
73-
self.assertIsNotNone(self.rs._rules_timer)
74-
self.assertEqual(self.rs._AwsXRayRemoteSampler__polling_interval, 300)
75-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__xray_client)
76-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__resource)
77-
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name")
78-
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform")
74+
self.assertIsNotNone(self.rs._root._root._rules_timer)
75+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 300)
76+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__xray_client)
77+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource)
78+
self.assertEqual(
79+
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name"
80+
)
81+
self.assertEqual(
82+
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform"
83+
)
7984

8085
def test_create_remote_sampler_with_all_fields_populated(self):
8186
self.rs = AwsXRayRemoteSampler(
@@ -84,35 +89,39 @@ def test_create_remote_sampler_with_all_fields_populated(self):
8489
polling_interval=120,
8590
log_level=DEBUG,
8691
)
87-
self.assertIsNotNone(self.rs._rules_timer)
88-
self.assertEqual(self.rs._AwsXRayRemoteSampler__polling_interval, 120)
89-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__xray_client)
90-
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__resource)
92+
self.assertIsNotNone(self.rs._root._root._rules_timer)
93+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 120)
94+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__xray_client)
95+
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource)
9196
self.assertEqual(
92-
self.rs._AwsXRayRemoteSampler__xray_client._AwsXRaySamplingClient__get_sampling_rules_endpoint,
97+
self.rs._root._root._AwsXRayRemoteSampler__xray_client._AwsXRaySamplingClient__get_sampling_rules_endpoint,
9398
"http://abc.com/GetSamplingRules",
9499
)
95-
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name")
96-
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform")
100+
self.assertEqual(
101+
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name"
102+
)
103+
self.assertEqual(
104+
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform"
105+
)
97106

98107
@patch("requests.Session.post", side_effect=mocked_requests_get)
99108
@patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", 2)
100109
def test_update_sampling_rules_and_targets_with_pollers_and_should_sample(self, mock_post=None):
101110
self.rs = AwsXRayRemoteSampler(
102111
resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
103112
)
104-
self.assertEqual(self.rs._AwsXRayRemoteSampler__target_polling_interval, 2)
113+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__target_polling_interval, 2)
105114

106115
time.sleep(1.0)
107116
self.assertEqual(
108-
self.rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0].sampling_rule.RuleName,
117+
self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0].sampling_rule.RuleName,
109118
"test",
110119
)
111120
self.assertEqual(self.rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision, Decision.DROP)
112121

113122
# wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s)
114123
time.sleep(2.0)
115-
self.assertEqual(self.rs._AwsXRayRemoteSampler__target_polling_interval, 1000)
124+
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__target_polling_interval, 1000)
116125
self.assertEqual(
117126
self.rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision,
118127
Decision.RECORD_AND_SAMPLE,
@@ -162,9 +171,9 @@ def test_multithreading_with_large_reservoir_with_otel_sdk(self, mock_post=None)
162171
threads[idx].join()
163172
sum_sampled += sampled_array[idx]
164173

165-
test_rule_applier = self.rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0]
174+
test_rule_applier = self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0]
166175
self.assertEqual(
167-
test_rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
176+
test_rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
168177
100000,
169178
)
170179
self.assertEqual(sum_sampled, 100000)
@@ -183,19 +192,19 @@ def test_multithreading_with_some_reservoir_with_otel_sdk(self, mock_post=None):
183192
# which will eat up more than 1 second of reservoir. Using MockClock we can freeze time
184193
# and pretend all thread jobs start and end at the exact same time,
185194
# assume and test exactly 1 second of reservoir (100 quota) only
186-
mock_clock: MockClock = self.rs._clock
195+
mock_clock: MockClock = self.rs._root._root._clock
187196

188197
time.sleep(1.0)
189198
mock_clock.add_time(1.0)
190-
self.assertEqual(mock_clock.now(), self.rs._clock.now())
199+
self.assertEqual(mock_clock.now(), self.rs._root._root._clock.now())
191200
self.assertEqual(
192201
self.rs.should_sample(None, 0, "name", attributes=attributes).decision, Decision.RECORD_AND_SAMPLE
193202
)
194203

195204
# wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s)
196205
time.sleep(2.0)
197206
mock_clock.add_time(2.0)
198-
self.assertEqual(mock_clock.now(), self.rs._clock.now())
207+
self.assertEqual(mock_clock.now(), self.rs._root._root._clock.now())
199208

200209
number_of_spans = 100
201210
thread_count = 1000
@@ -219,9 +228,79 @@ def test_multithreading_with_some_reservoir_with_otel_sdk(self, mock_post=None):
219228
threads[idx].join()
220229
sum_sampled += sampled_array[idx]
221230

222-
default_rule_applier = self.rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
231+
default_rule_applier = self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
223232
self.assertEqual(
224-
default_rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
233+
default_rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
225234
100,
226235
)
227236
self.assertEqual(sum_sampled, 100)
237+
238+
def test_get_description(self) -> str:
239+
self.rs: AwsXRayRemoteSampler = AwsXRayRemoteSampler(resource=Resource.create({"service.name": "dummy_name"}))
240+
self.assertEqual(
241+
self.rs.get_description(),
242+
"AwsXRayRemoteSampler{root:ParentBased{root:_AwsXRayRemoteSampler{remote sampling with AWS X-Ray},remoteParentSampled:AlwaysOnSampler,remoteParentNotSampled:AlwaysOffSampler,localParentSampled:AlwaysOnSampler,localParentNotSampled:AlwaysOffSampler}}", # noqa: E501
243+
)
244+
245+
@patch("requests.Session.post", side_effect=mocked_requests_get)
246+
def test_parent_based_xray_sampler_updates_statistics_once_for_one_parent_span_with_two_children(
247+
self, mock_post=None
248+
):
249+
self.rs: AwsXRayRemoteSampler = AwsXRayRemoteSampler(
250+
resource=Resource.create({"service.name": "use-default-sample-all-rule"})
251+
)
252+
time.sleep(1.0)
253+
254+
provider = TracerProvider(sampler=self.rs)
255+
tracer: Tracer = provider.get_tracer("test_tracer_1")
256+
257+
# child1 and child2 are child spans of root parent0
258+
# For AwsXRayRemoteSampler (ParentBased), expect only parent0 to update statistics
259+
with tracer.start_as_current_span("parent0") as _:
260+
with tracer.start_as_current_span("child1") as _:
261+
pass
262+
with tracer.start_as_current_span("child2") as _:
263+
pass
264+
default_rule_applier = self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
265+
self.assertEqual(
266+
default_rule_applier._SamplingRuleApplier__statistics.RequestCount,
267+
1,
268+
)
269+
self.assertEqual(
270+
default_rule_applier._SamplingRuleApplier__statistics.SampleCount,
271+
1,
272+
)
273+
274+
@patch("requests.Session.post", side_effect=mocked_requests_get)
275+
def test_non_parent_based_xray_sampler_updates_statistics_thrice_for_one_parent_span_with_two_children(
276+
self, mock_post=None
277+
):
278+
non_parent_based_xray_sampler: _AwsXRayRemoteSampler = _AwsXRayRemoteSampler(
279+
resource=Resource.create({"service.name": "use-default-sample-all-rule"})
280+
)
281+
time.sleep(1.0)
282+
283+
provider = TracerProvider(sampler=non_parent_based_xray_sampler)
284+
tracer: Tracer = provider.get_tracer("test_tracer_2")
285+
286+
# child1 and child2 are child spans of root parent0
287+
# For _AwsXRayRemoteSampler (Non-ParentBased), expect all 3 spans to update statistics
288+
with tracer.start_as_current_span("parent0") as _:
289+
with tracer.start_as_current_span("child1") as _:
290+
pass
291+
with tracer.start_as_current_span("child2") as _:
292+
pass
293+
default_rule_applier = (
294+
non_parent_based_xray_sampler._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
295+
)
296+
self.assertEqual(
297+
default_rule_applier._SamplingRuleApplier__statistics.RequestCount,
298+
3,
299+
)
300+
self.assertEqual(
301+
default_rule_applier._SamplingRuleApplier__statistics.SampleCount,
302+
3,
303+
)
304+
305+
non_parent_based_xray_sampler._rules_timer.cancel()
306+
non_parent_based_xray_sampler._targets_timer.cancel()

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,15 @@ def test_update_sampling_targets(self):
136136
# quota should be 1 because of borrowing=true until targets are updated
137137
rule_applier_0 = rule_cache._RuleCache__rule_appliers[0]
138138
self.assertEqual(
139-
rule_applier_0._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 1
139+
rule_applier_0._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, 1
140140
)
141-
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._root._rate, sampling_rule_2.FixedRate)
141+
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._rate, sampling_rule_2.FixedRate)
142142

143143
rule_applier_1 = rule_cache._RuleCache__rule_appliers[1]
144144
self.assertEqual(
145-
rule_applier_1._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 1
145+
rule_applier_1._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, 1
146146
)
147-
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._root._rate, sampling_rule_1.FixedRate)
147+
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._rate, sampling_rule_1.FixedRate)
148148

149149
target_1 = {
150150
"FixedRate": 0.05,
@@ -179,17 +179,17 @@ def test_update_sampling_targets(self):
179179
# borrowing=false, use quota from targets
180180
rule_applier_0 = rule_cache._RuleCache__rule_appliers[0]
181181
self.assertEqual(
182-
rule_applier_0._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
182+
rule_applier_0._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
183183
target_2["ReservoirQuota"],
184184
)
185-
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._root._rate, target_2["FixedRate"])
185+
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._rate, target_2["FixedRate"])
186186

187187
rule_applier_1 = rule_cache._RuleCache__rule_appliers[1]
188188
self.assertEqual(
189-
rule_applier_1._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
189+
rule_applier_1._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
190190
target_1["ReservoirQuota"],
191191
)
192-
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._root._rate, target_1["FixedRate"])
192+
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._rate, target_1["FixedRate"])
193193

194194
# Test target response modified after Rule cache's last modified date
195195
target_response.LastRuleModification = mock_clock.now().timestamp() + 1

0 commit comments

Comments
 (0)