Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTarget
from opentelemetry.context import Context
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import Decision, ParentBased, Sampler, SamplingResult, TraceIdRatioBased
from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult, TraceIdRatioBased
from opentelemetry.semconv.resource import CloudPlatformValues, ResourceAttributes
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Link, SpanKind
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
self.__borrowing = False

if target is None:
self.__fixed_rate_sampler = ParentBased(TraceIdRatioBased(self.sampling_rule.FixedRate))
self.__fixed_rate_sampler = TraceIdRatioBased(self.sampling_rule.FixedRate)
# Until targets are fetched, initialize as borrowing=True if there will be a quota > 0
if self.sampling_rule.ReservoirSize > 0:
self.__reservoir_sampler = self.__create_reservoir_sampler(quota=1)
Expand All @@ -55,7 +55,7 @@ def __init__(
new_quota = target.ReservoirQuota if target.ReservoirQuota is not None else 0
new_fixed_rate = target.FixedRate if target.FixedRate is not None else 0
self.__reservoir_sampler = self.__create_reservoir_sampler(quota=new_quota)
self.__fixed_rate_sampler = ParentBased(TraceIdRatioBased(new_fixed_rate))
self.__fixed_rate_sampler = TraceIdRatioBased(new_fixed_rate)
if target.ReservoirQuotaTTL is not None:
self.__reservoir_expiry = self._clock.from_timestamp(target.ReservoirQuotaTTL)
else:
Expand Down Expand Up @@ -159,7 +159,7 @@ def matches(self, resource: Resource, attributes: Attributes) -> bool:
)

def __create_reservoir_sampler(self, quota: int) -> Sampler:
return ParentBased(_RateLimitingSampler(quota, self._clock))
return _RateLimitingSampler(quota, self._clock)

# pylint: disable=no-self-use
def __get_service_type(self, resource: Resource) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,48 @@
DEFAULT_SAMPLING_PROXY_ENDPOINT = "http://127.0.0.1:2000"


# Wrapper class to ensure that all XRay Sampler Functionality in _AwsXRayRemoteSampler
# uses ParentBased logic to respect the parent span's sampling decision
class AwsXRayRemoteSampler(Sampler):
def __init__(
self,
resource: Resource,
endpoint: str = None,
polling_interval: int = None,
log_level=None,
):
self._root = ParentBased(
_AwsXRayRemoteSampler(
resource=resource, endpoint=endpoint, polling_interval=polling_interval, log_level=log_level
)
)

# pylint: disable=no-self-use
@override
def should_sample(
self,
parent_context: Optional[Context],
trace_id: int,
name: str,
kind: SpanKind = None,
attributes: Attributes = None,
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
return self._root.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

# pylint: disable=no-self-use
@override
def get_description(self) -> str:
return f"AwsXRayRemoteSampler{{root:{self._root.get_description()}}}"


# _AwsXRayRemoteSampler contains all core XRay Sampler Functionality,
# however it is NOT Parent-based (e.g. Sample logic runs for each span)
# Not intended for external use, use Parent-based `AwsXRayRemoteSampler` instead.
class _AwsXRayRemoteSampler(Sampler):
"""
Remote Sampler for OpenTelemetry that gets sampling configurations from AWS X-Ray

Expand Down Expand Up @@ -58,7 +99,7 @@ def __init__(
self.__client_id = self.__generate_client_id()
self._clock = _Clock()
self.__xray_client = _AwsXRaySamplingClient(endpoint, log_level=log_level)
self.__fallback_sampler = ParentBased(_FallbackSampler(self._clock))
self.__fallback_sampler = _FallbackSampler(self._clock)

self.__polling_interval = polling_interval
self.__target_polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS
Expand Down Expand Up @@ -114,7 +155,7 @@ def should_sample(
# pylint: disable=no-self-use
@override
def get_description(self) -> str:
description = "AwsXRayRemoteSampler{remote sampling with AWS X-Ray}"
description = "_AwsXRayRemoteSampler{remote sampling with AWS X-Ray}"
return description

def __get_and_update_sampling_rules(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

from mock_clock import MockClock

from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler
from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler, _AwsXRayRemoteSampler
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import Tracer, TracerProvider
from opentelemetry.sdk.trace.sampling import Decision

TEST_DIR = os.path.dirname(os.path.realpath(__file__))
Expand Down Expand Up @@ -55,27 +56,31 @@ def setUp(self):
def tearDown(self):
# Clean up timers
if self.rs is not None:
self.rs._rules_timer.cancel()
self.rs._targets_timer.cancel()
self.rs._root._root._rules_timer.cancel()
self.rs._root._root._targets_timer.cancel()

def test_create_remote_sampler_with_empty_resource(self):
self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty())
self.assertIsNotNone(self.rs._rules_timer)
self.assertEqual(self.rs._AwsXRayRemoteSampler__polling_interval, 300)
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__xray_client)
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__resource)
self.assertTrue(len(self.rs._AwsXRayRemoteSampler__client_id), 24)
self.assertIsNotNone(self.rs._root._root._rules_timer)
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 300)
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__xray_client)
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource)
self.assertTrue(len(self.rs._root._root._AwsXRayRemoteSampler__client_id), 24)

def test_create_remote_sampler_with_populated_resource(self):
self.rs = AwsXRayRemoteSampler(
resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
)
self.assertIsNotNone(self.rs._rules_timer)
self.assertEqual(self.rs._AwsXRayRemoteSampler__polling_interval, 300)
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__xray_client)
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__resource)
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name")
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform")
self.assertIsNotNone(self.rs._root._root._rules_timer)
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 300)
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__xray_client)
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource)
self.assertEqual(
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name"
)
self.assertEqual(
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform"
)

def test_create_remote_sampler_with_all_fields_populated(self):
self.rs = AwsXRayRemoteSampler(
Expand All @@ -84,35 +89,39 @@ def test_create_remote_sampler_with_all_fields_populated(self):
polling_interval=120,
log_level=DEBUG,
)
self.assertIsNotNone(self.rs._rules_timer)
self.assertEqual(self.rs._AwsXRayRemoteSampler__polling_interval, 120)
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__xray_client)
self.assertIsNotNone(self.rs._AwsXRayRemoteSampler__resource)
self.assertIsNotNone(self.rs._root._root._rules_timer)
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__polling_interval, 120)
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__xray_client)
self.assertIsNotNone(self.rs._root._root._AwsXRayRemoteSampler__resource)
self.assertEqual(
self.rs._AwsXRayRemoteSampler__xray_client._AwsXRaySamplingClient__get_sampling_rules_endpoint,
self.rs._root._root._AwsXRayRemoteSampler__xray_client._AwsXRaySamplingClient__get_sampling_rules_endpoint,
"http://abc.com/GetSamplingRules",
)
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name")
self.assertEqual(self.rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform")
self.assertEqual(
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name"
)
self.assertEqual(
self.rs._root._root._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform"
)

@patch("requests.Session.post", side_effect=mocked_requests_get)
@patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", 2)
def test_update_sampling_rules_and_targets_with_pollers_and_should_sample(self, mock_post=None):
self.rs = AwsXRayRemoteSampler(
resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
)
self.assertEqual(self.rs._AwsXRayRemoteSampler__target_polling_interval, 2)
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__target_polling_interval, 2)

time.sleep(1.0)
self.assertEqual(
self.rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0].sampling_rule.RuleName,
self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0].sampling_rule.RuleName,
"test",
)
self.assertEqual(self.rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision, Decision.DROP)

# wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s)
time.sleep(2.0)
self.assertEqual(self.rs._AwsXRayRemoteSampler__target_polling_interval, 1000)
self.assertEqual(self.rs._root._root._AwsXRayRemoteSampler__target_polling_interval, 1000)
self.assertEqual(
self.rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision,
Decision.RECORD_AND_SAMPLE,
Expand Down Expand Up @@ -162,9 +171,9 @@ def test_multithreading_with_large_reservoir_with_otel_sdk(self, mock_post=None)
threads[idx].join()
sum_sampled += sampled_array[idx]

test_rule_applier = self.rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0]
test_rule_applier = self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0]
self.assertEqual(
test_rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
test_rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
100000,
)
self.assertEqual(sum_sampled, 100000)
Expand All @@ -183,19 +192,19 @@ def test_multithreading_with_some_reservoir_with_otel_sdk(self, mock_post=None):
# which will eat up more than 1 second of reservoir. Using MockClock we can freeze time
# and pretend all thread jobs start and end at the exact same time,
# assume and test exactly 1 second of reservoir (100 quota) only
mock_clock: MockClock = self.rs._clock
mock_clock: MockClock = self.rs._root._root._clock

time.sleep(1.0)
mock_clock.add_time(1.0)
self.assertEqual(mock_clock.now(), self.rs._clock.now())
self.assertEqual(mock_clock.now(), self.rs._root._root._clock.now())
self.assertEqual(
self.rs.should_sample(None, 0, "name", attributes=attributes).decision, Decision.RECORD_AND_SAMPLE
)

# wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s)
time.sleep(2.0)
mock_clock.add_time(2.0)
self.assertEqual(mock_clock.now(), self.rs._clock.now())
self.assertEqual(mock_clock.now(), self.rs._root._root._clock.now())

number_of_spans = 100
thread_count = 1000
Expand All @@ -219,9 +228,79 @@ def test_multithreading_with_some_reservoir_with_otel_sdk(self, mock_post=None):
threads[idx].join()
sum_sampled += sampled_array[idx]

default_rule_applier = self.rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
default_rule_applier = self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
self.assertEqual(
default_rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
default_rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
100,
)
self.assertEqual(sum_sampled, 100)

def test_get_description(self) -> str:
self.rs: AwsXRayRemoteSampler = AwsXRayRemoteSampler(resource=Resource.create({"service.name": "dummy_name"}))
self.assertEqual(
self.rs.get_description(),
"AwsXRayRemoteSampler{root:ParentBased{root:_AwsXRayRemoteSampler{remote sampling with AWS X-Ray},remoteParentSampled:AlwaysOnSampler,remoteParentNotSampled:AlwaysOffSampler,localParentSampled:AlwaysOnSampler,localParentNotSampled:AlwaysOffSampler}}", # noqa: E501
)

@patch("requests.Session.post", side_effect=mocked_requests_get)
def test_parent_based_xray_sampler_updates_statistics_once_for_one_parent_span_with_two_children(
self, mock_post=None
):
self.rs: AwsXRayRemoteSampler = AwsXRayRemoteSampler(
resource=Resource.create({"service.name": "use-default-sample-all-rule"})
)
time.sleep(1.0)

provider = TracerProvider(sampler=self.rs)
tracer: Tracer = provider.get_tracer("test_tracer_1")

# child1 and child2 are child spans of root parent0
# For AwsXRayRemoteSampler (ParentBased), expect only parent0 to update statistics
with tracer.start_as_current_span("parent0") as _:
with tracer.start_as_current_span("child1") as _:
pass
with tracer.start_as_current_span("child2") as _:
pass
default_rule_applier = self.rs._root._root._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
self.assertEqual(
default_rule_applier._SamplingRuleApplier__statistics.RequestCount,
1,
)
self.assertEqual(
default_rule_applier._SamplingRuleApplier__statistics.SampleCount,
1,
)

@patch("requests.Session.post", side_effect=mocked_requests_get)
def test_non_parent_based_xray_sampler_updates_statistics_thrice_for_one_parent_span_with_two_children(
self, mock_post=None
):
non_parent_based_xray_sampler: _AwsXRayRemoteSampler = _AwsXRayRemoteSampler(
resource=Resource.create({"service.name": "use-default-sample-all-rule"})
)
time.sleep(1.0)

provider = TracerProvider(sampler=non_parent_based_xray_sampler)
tracer: Tracer = provider.get_tracer("test_tracer_2")

# child1 and child2 are child spans of root parent0
# For _AwsXRayRemoteSampler (Non-ParentBased), expect all 3 spans to update statistics
with tracer.start_as_current_span("parent0") as _:
with tracer.start_as_current_span("child1") as _:
pass
with tracer.start_as_current_span("child2") as _:
pass
default_rule_applier = (
non_parent_based_xray_sampler._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1]
)
self.assertEqual(
default_rule_applier._SamplingRuleApplier__statistics.RequestCount,
3,
)
self.assertEqual(
default_rule_applier._SamplingRuleApplier__statistics.SampleCount,
3,
)

non_parent_based_xray_sampler._rules_timer.cancel()
non_parent_based_xray_sampler._targets_timer.cancel()
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,15 @@ def test_update_sampling_targets(self):
# quota should be 1 because of borrowing=true until targets are updated
rule_applier_0 = rule_cache._RuleCache__rule_appliers[0]
self.assertEqual(
rule_applier_0._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 1
rule_applier_0._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, 1
)
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._root._rate, sampling_rule_2.FixedRate)
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._rate, sampling_rule_2.FixedRate)

rule_applier_1 = rule_cache._RuleCache__rule_appliers[1]
self.assertEqual(
rule_applier_1._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 1
rule_applier_1._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, 1
)
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._root._rate, sampling_rule_1.FixedRate)
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._rate, sampling_rule_1.FixedRate)

target_1 = {
"FixedRate": 0.05,
Expand Down Expand Up @@ -179,17 +179,17 @@ def test_update_sampling_targets(self):
# borrowing=false, use quota from targets
rule_applier_0 = rule_cache._RuleCache__rule_appliers[0]
self.assertEqual(
rule_applier_0._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
rule_applier_0._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
target_2["ReservoirQuota"],
)
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._root._rate, target_2["FixedRate"])
self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._rate, target_2["FixedRate"])

rule_applier_1 = rule_cache._RuleCache__rule_appliers[1]
self.assertEqual(
rule_applier_1._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota,
rule_applier_1._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota,
target_1["ReservoirQuota"],
)
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._root._rate, target_1["FixedRate"])
self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._rate, target_1["FixedRate"])

# Test target response modified after Rule cache's last modified date
target_response.LastRuleModification = mock_clock.now().timestamp() + 1
Expand Down
Loading
Loading