diff --git a/CHANGELOG.md b/CHANGELOG.md index 024990c91d..0ff85523ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- `opentelemetry-sdk-extension-aws` Add caching, matching, and targets logic to complete AWS X-Ray Remote Sampler implementation + ([#3366](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3761)) + ## Version 1.38.0/0.59b0 (2025-10-16) ### Fixed diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/README.rst b/sdk-extension/opentelemetry-sdk-extension-aws/README.rst index 529a80868e..2028dcd0a0 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/README.rst +++ b/sdk-extension/opentelemetry-sdk-extension-aws/README.rst @@ -74,6 +74,29 @@ populate `resource` attributes by creating a `TraceProvider` using the `AwsEc2Re Refer to each detectors' docstring to determine any possible requirements for that detector. + +Usage (AWS X-Ray Remote Sampler) +-------------------------------- + +Use the provided AWS X-Ray Remote Sampler by setting this sampler in your instrumented application: + +.. code-block:: python + + from opentelemetry.sdk.extension.aws.trace.sampler import AwsXRayRemoteSampler + from opentelemetry import trace + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.semconv.resource import ResourceAttributes + from opentelemetry.util.types import Attributes + + resource = Resource.create(attributes={ + ResourceAttributes.SERVICE_NAME: "myService", + ResourceAttributes.CLOUD_PLATFORM: "aws_ec2", + }) + xraySampler = AwsXRayRemoteSampler(resource=resource, polling_interval=300) + trace.set_tracer_provider(TracerProvider(sampler=xraySampler)) + + References ---------- diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/pyproject.toml b/sdk-extension/opentelemetry-sdk-extension-aws/pyproject.toml index 1dc8a2d68b..087ec12f43 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/pyproject.toml +++ b/sdk-extension/opentelemetry-sdk-extension-aws/pyproject.toml @@ -42,9 +42,8 @@ aws_eks = "opentelemetry.sdk.extension.aws.resource.eks:AwsEksResourceDetector" aws_elastic_beanstalk = "opentelemetry.sdk.extension.aws.resource.beanstalk:AwsBeanstalkResourceDetector" aws_lambda = "opentelemetry.sdk.extension.aws.resource._lambda:AwsLambdaResourceDetector" -# TODO: Uncomment this when Sampler implementation is complete -# [project.entry-points.opentelemetry_sampler] -# aws_xray_remote_sampler = "opentelemetry.sdk.extension.aws.trace.sampler:AwsXRayRemoteSampler" +[project.entry-points.opentelemetry_sampler] +aws_xray_remote_sampler = "opentelemetry.sdk.extension.aws.trace.sampler:AwsXRayRemoteSampler" [project.urls] Homepage = "https://github.com/open-telemetry/opentelemetry-python-contrib/tree/main/sdk-extension/opentelemetry-sdk-extension-aws" diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/__init__.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/__init__.py index 623122fc28..e40a213ad4 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/__init__.py +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/__init__.py @@ -14,7 +14,7 @@ # pylint: disable=no-name-in-module from opentelemetry.sdk.extension.aws.trace.sampler.aws_xray_remote_sampler import ( - _AwsXRayRemoteSampler, + AwsXRayRemoteSampler, ) -__all__ = ["_AwsXRayRemoteSampler"] +__all__ = ["AwsXRayRemoteSampler"] diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_fallback_sampler.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_fallback_sampler.py new file mode 100644 index 0000000000..3693ca5bee --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_fallback_sampler.py @@ -0,0 +1,79 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Sequence + +# pylint: disable=no-name-in-module +from opentelemetry.context import Context +from opentelemetry.sdk.extension.aws.trace.sampler._clock import _Clock +from opentelemetry.sdk.extension.aws.trace.sampler._rate_limiting_sampler import ( + _RateLimitingSampler, +) +from opentelemetry.sdk.trace.sampling import ( + Decision, + Sampler, + SamplingResult, + TraceIdRatioBased, +) +from opentelemetry.trace import Link, SpanKind +from opentelemetry.trace.span import TraceState +from opentelemetry.util.types import Attributes + + +class _FallbackSampler(Sampler): + def __init__(self, clock: _Clock): + self.__rate_limiting_sampler = _RateLimitingSampler(1, clock) + self.__fixed_rate_sampler = TraceIdRatioBased(0.05) + + def should_sample( + self, + parent_context: Context | None, + trace_id: int, + name: str, + kind: SpanKind | None = None, + attributes: Attributes | None = None, + links: Sequence["Link"] | None = None, + trace_state: TraceState | None = None, + ) -> "SamplingResult": + sampling_result = self.__rate_limiting_sampler.should_sample( + parent_context, + trace_id, + name, + kind=kind, + attributes=attributes, + links=links, + trace_state=trace_state, + ) + if sampling_result.decision is not Decision.DROP: + return sampling_result + return self.__fixed_rate_sampler.should_sample( + parent_context, + trace_id, + name, + kind=kind, + attributes=attributes, + links=links, + trace_state=trace_state, + ) + + # pylint: disable=no-self-use + def get_description(self) -> str: + description = "FallbackSampler{fallback sampling with sampling config of 1 req/sec and 5% of additional requests}" + return description diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_matcher.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_matcher.py new file mode 100644 index 0000000000..a14a1be5bf --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_matcher.py @@ -0,0 +1,97 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import re + +from opentelemetry.semconv.resource import CloudPlatformValues +from opentelemetry.util.types import Attributes, AttributeValue + +cloud_platform_mapping = { + CloudPlatformValues.AWS_LAMBDA.value: "AWS::Lambda::Function", + CloudPlatformValues.AWS_ELASTIC_BEANSTALK.value: "AWS::ElasticBeanstalk::Environment", + CloudPlatformValues.AWS_EC2.value: "AWS::EC2::Instance", + CloudPlatformValues.AWS_ECS.value: "AWS::ECS::Container", + CloudPlatformValues.AWS_EKS.value: "AWS::EKS::Container", +} + + +class _Matcher: + @staticmethod + def wild_card_match( + text: AttributeValue | None = None, pattern: str | None = None + ) -> bool: + if pattern == "*": + return True + if not isinstance(text, str) or pattern is None: + return False + if len(pattern) == 0: + return len(text) == 0 + for char in pattern: + if char in ("*", "?"): + return ( + re.fullmatch(_Matcher.to_regex_pattern(pattern), text) + is not None + ) + return pattern == text + + @staticmethod + def to_regex_pattern(rule_pattern: str) -> str: + token_start = -1 + regex_pattern = "" + for index, char in enumerate(rule_pattern): + char = rule_pattern[index] + if char in ("*", "?"): + if token_start != -1: + regex_pattern += re.escape(rule_pattern[token_start:index]) + token_start = -1 + if char == "*": + regex_pattern += ".*" + else: + regex_pattern += "." + else: + if token_start == -1: + token_start = index + if token_start != -1: + regex_pattern += re.escape(rule_pattern[token_start:]) + return regex_pattern + + @staticmethod + def attribute_match( + attributes: Attributes | None = None, + rule_attributes: dict[str, str] | None = None, + ) -> bool: + if rule_attributes is None or len(rule_attributes) == 0: + return True + if ( + attributes is None + or len(attributes) == 0 + or len(rule_attributes) > len(attributes) + ): + return False + + matched_count = 0 + for key, val in attributes.items(): + text_to_match = val + pattern = rule_attributes.get(key, None) + if pattern is None: + continue + if _Matcher.wild_card_match(text_to_match, pattern): + matched_count += 1 + return matched_count == len(rule_attributes) diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_rate_limiter.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_rate_limiter.py new file mode 100644 index 0000000000..d9b84cfc4e --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_rate_limiter.py @@ -0,0 +1,69 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from decimal import Decimal +from threading import Lock + +# pylint: disable=no-name-in-module +from opentelemetry.sdk.extension.aws.trace.sampler._clock import _Clock + + +class _RateLimiter: + def __init__(self, max_balance_in_seconds: int, quota: int, clock: _Clock): + # max_balance_in_seconds is usually 1 + # pylint: disable=invalid-name + self.MAX_BALANCE_MILLIS = Decimal(max_balance_in_seconds * 1000.0) + self._clock = clock + + self._quota = Decimal(quota) + self.__wallet_floor_millis = Decimal( + self._clock.now().timestamp() * 1000.0 + ) + # current "wallet_balance" would be ceiling - floor + + self.__lock = Lock() + + def try_spend(self, cost: float) -> bool: + if self._quota == 0: + return False + + quota_per_millis = self._quota / Decimal(1000.0) + + # assume divide by zero not possible + cost_in_millis = Decimal(cost) / quota_per_millis + + with self.__lock: + wallet_ceiling_millis = Decimal( + self._clock.now().timestamp() * 1000.0 + ) + current_balance_millis = ( + wallet_ceiling_millis - self.__wallet_floor_millis + ) + current_balance_millis = min( + current_balance_millis, self.MAX_BALANCE_MILLIS + ) + pending_remaining_balance_millis = ( + current_balance_millis - cost_in_millis + ) + if pending_remaining_balance_millis >= 0: + self.__wallet_floor_millis = ( + wallet_ceiling_millis - pending_remaining_balance_millis + ) + return True + # No changes to the wallet state + return False diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_rate_limiting_sampler.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_rate_limiting_sampler.py new file mode 100644 index 0000000000..1055b91b9c --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_rate_limiting_sampler.py @@ -0,0 +1,64 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Sequence + +# pylint: disable=no-name-in-module +from opentelemetry.context import Context +from opentelemetry.sdk.extension.aws.trace.sampler._clock import _Clock +from opentelemetry.sdk.extension.aws.trace.sampler._rate_limiter import ( + _RateLimiter, +) +from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult +from opentelemetry.trace import Link, SpanKind +from opentelemetry.trace.span import TraceState +from opentelemetry.util.types import Attributes + + +class _RateLimitingSampler(Sampler): + def __init__(self, quota: int, clock: _Clock): + self.__quota = quota + self.__reservoir = _RateLimiter(1, quota, clock) + + def should_sample( + self, + parent_context: Context | None, + trace_id: int, + name: str, + kind: SpanKind | None = None, + attributes: Attributes | None = None, + links: Sequence["Link"] | None = None, + trace_state: TraceState | None = None, + ) -> "SamplingResult": + if self.__reservoir.try_spend(1): + return SamplingResult( + decision=Decision.RECORD_AND_SAMPLE, + attributes=attributes, + trace_state=trace_state, + ) + return SamplingResult( + decision=Decision.DROP, + attributes=attributes, + trace_state=trace_state, + ) + + def get_description(self) -> str: + description = f"RateLimitingSampler{{rate limiting sampling with sampling config of {self.__quota} req/sec and 0% of additional requests}}" + return description diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_rule_cache.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_rule_cache.py new file mode 100644 index 0000000000..99f3957d1e --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_rule_cache.py @@ -0,0 +1,202 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from logging import getLogger +from threading import Lock +from typing import Dict, List, Sequence + +# pylint: disable=no-name-in-module +from opentelemetry.context import Context +from opentelemetry.sdk.extension.aws.trace.sampler._clock import _Clock +from opentelemetry.sdk.extension.aws.trace.sampler._fallback_sampler import ( + _FallbackSampler, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_rule import ( + _SamplingRule, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_rule_applier import ( + _SamplingRuleApplier, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_target import ( + _SamplingTarget, + _SamplingTargetResponse, +) +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.sampling import SamplingResult +from opentelemetry.trace import Link, SpanKind +from opentelemetry.trace.span import TraceState +from opentelemetry.util.types import Attributes + +_logger = getLogger(__name__) + +CACHE_TTL_SECONDS = 3600 +DEFAULT_TARGET_POLLING_INTERVAL_SECONDS = 10 + + +class _RuleCache: + def __init__( + self, + resource: Resource, + fallback_sampler: _FallbackSampler, + client_id: str, + clock: _Clock, + lock: Lock, + ): + self.__client_id = client_id + self.__rule_appliers: List[_SamplingRuleApplier] = [] + self.__cache_lock = lock + self.__resource = resource + self._fallback_sampler = fallback_sampler + self._clock = clock + self._last_modified = self._clock.now() + + def should_sample( + self, + parent_context: Context | None, + trace_id: int, + name: str, + kind: SpanKind | None = None, + attributes: Attributes | None = None, + links: Sequence["Link"] | None = None, + trace_state: TraceState | None = None, + ) -> "SamplingResult": + rule_applier: _SamplingRuleApplier + for rule_applier in self.__rule_appliers: + if rule_applier.matches(self.__resource, attributes): + return rule_applier.should_sample( + parent_context, + trace_id, + name, + kind=kind, + attributes=attributes, + links=links, + trace_state=trace_state, + ) + + _logger.debug("No sampling rules were matched") + # Should not ever reach fallback sampler as default rule is able to match + return self._fallback_sampler.should_sample( + parent_context, + trace_id, + name, + kind=kind, + attributes=attributes, + links=links, + trace_state=trace_state, + ) + + def update_sampling_rules( + self, new_sampling_rules: List[_SamplingRule] + ) -> None: + new_sampling_rules.sort() + temp_rule_appliers: List[_SamplingRuleApplier] = [] + for sampling_rule in new_sampling_rules: + if sampling_rule.RuleName == "": + _logger.debug( + "sampling rule without rule name is not supported" + ) + continue + if sampling_rule.Version != 1: + _logger.debug( + "sampling rule without Version 1 is not supported: RuleName: %s", + sampling_rule.RuleName, + ) + continue + temp_rule_appliers.append( + _SamplingRuleApplier( + sampling_rule, self.__client_id, self._clock + ) + ) + + with self.__cache_lock: + # map list of rule appliers by each applier's sampling_rule name + rule_applier_map: Dict[str, _SamplingRuleApplier] = { + applier.sampling_rule.RuleName: applier + for applier in self.__rule_appliers + } + + # If a sampling rule has not changed, keep its respective applier in the cache. + new_applier: _SamplingRuleApplier + for index, new_applier in enumerate(temp_rule_appliers): + rule_name_to_check = new_applier.sampling_rule.RuleName + if rule_name_to_check in rule_applier_map: + old_applier = rule_applier_map[rule_name_to_check] + if new_applier.sampling_rule == old_applier.sampling_rule: + temp_rule_appliers[index] = old_applier + self.__rule_appliers = temp_rule_appliers + self._last_modified = self._clock.now() + + def update_sampling_targets( + self, sampling_targets_response: _SamplingTargetResponse + ): + targets: List[_SamplingTarget] = ( + sampling_targets_response.SamplingTargetDocuments + ) + + with self.__cache_lock: + next_polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS + min_polling_interval = None + + target_map: Dict[str, _SamplingTarget] = { + target.RuleName: target for target in targets + } + + new_appliers: List[_SamplingRuleApplier] = [] + applier: _SamplingRuleApplier + for applier in self.__rule_appliers: + if applier.sampling_rule.RuleName in target_map: + target = target_map[applier.sampling_rule.RuleName] + new_appliers.append(applier.with_target(target)) + + if target.Interval is not None: + if ( + min_polling_interval is None + or min_polling_interval > target.Interval + ): + min_polling_interval = target.Interval + else: + new_appliers.append(applier) + + self.__rule_appliers = new_appliers + + if min_polling_interval is not None: + next_polling_interval = min_polling_interval + + last_rule_modification = self._clock.from_timestamp( + sampling_targets_response.LastRuleModification + ) + refresh_rules = last_rule_modification > self._last_modified + + return (refresh_rules, next_polling_interval) + + def get_all_statistics(self): + all_statistics: list[dict[str, "str | float | int"]] = [] + applier: _SamplingRuleApplier + for applier in self.__rule_appliers: + all_statistics.append(applier.get_then_reset_statistics()) + return all_statistics + + def expired(self) -> bool: + with self.__cache_lock: + return ( + self._clock.now() + > self._last_modified + + self._clock.time_delta(seconds=CACHE_TTL_SECONDS) + ) diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_sampling_rule_applier.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_sampling_rule_applier.py index 332b274de2..79b38e27c7 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_sampling_rule_applier.py +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/_sampling_rule_applier.py @@ -18,8 +18,20 @@ from __future__ import annotations +from threading import Lock +from typing import Sequence +from urllib.parse import urlparse + # pylint: disable=no-name-in-module +from opentelemetry.context import Context from opentelemetry.sdk.extension.aws.trace.sampler._clock import _Clock +from opentelemetry.sdk.extension.aws.trace.sampler._matcher import ( + _Matcher, + cloud_platform_mapping, +) +from opentelemetry.sdk.extension.aws.trace.sampler._rate_limiting_sampler import ( + _RateLimitingSampler, +) from opentelemetry.sdk.extension.aws.trace.sampler._sampling_rule import ( _SamplingRule, ) @@ -29,6 +41,21 @@ from opentelemetry.sdk.extension.aws.trace.sampler._sampling_target import ( _SamplingTarget, ) +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.sampling import ( + Decision, + Sampler, + SamplingResult, + TraceIdRatioBased, +) +from opentelemetry.semconv.resource import ( + CloudPlatformValues, + ResourceAttributes, +) +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.trace import Link, SpanKind +from opentelemetry.trace.span import TraceState +from opentelemetry.util.types import Attributes, AttributeValue class _SamplingRuleApplier: @@ -40,8 +67,250 @@ def __init__( statistics: _SamplingStatisticsDocument | None = None, target: _SamplingTarget | None = None, ): - self.__client_id = client_id # pylint: disable=W0238 + self.__client_id = client_id self._clock = clock self.sampling_rule = sampling_rule - # (TODO) Just store Sampling Rules for now, rest of implementation for later + if statistics is None: + self.__statistics = _SamplingStatisticsDocument( + self.__client_id, self.sampling_rule.RuleName + ) + else: + self.__statistics = statistics + self.__statistics_lock = Lock() + + self.__borrowing = False + + if target is None: + self.__fixed_rate_sampler = TraceIdRatioBased( + self.sampling_rule.FixedRate + ) + # Until targets are fetched, initialize as borrowing=True if there will be a quota > 0 + if self.sampling_rule.ReservoirSize > 0: + self.__reservoir_sampler = self.__create_reservoir_sampler( + quota=1 + ) + self.__borrowing = True + else: + self.__reservoir_sampler = self.__create_reservoir_sampler( + quota=0 + ) + # No targets are present, borrow until the end of time if there is any quota + self.__reservoir_expiry = self._clock.max() + else: + new_quota = ( + target.ReservoirQuota + if target.ReservoirQuota is not None + else 0 + ) + new_fixed_rate = target.FixedRate + self.__reservoir_sampler = self.__create_reservoir_sampler( + quota=new_quota + ) + self.__fixed_rate_sampler = TraceIdRatioBased(new_fixed_rate) + if target.ReservoirQuotaTTL is not None: + self.__reservoir_expiry = self._clock.from_timestamp( + target.ReservoirQuotaTTL + ) + else: + # assume expired if no TTL + self.__reservoir_expiry = self._clock.now() + + def should_sample( + self, + parent_context: Context | None, + trace_id: int, + name: str, + kind: SpanKind | None = None, + attributes: Attributes | None = None, + links: Sequence["Link"] | None = None, + trace_state: TraceState | None = None, + ) -> "SamplingResult": + has_borrowed = False + has_sampled = False + sampling_result = SamplingResult( + decision=Decision.DROP, + attributes=attributes, + trace_state=trace_state, + ) + + reservoir_expired: bool = self._clock.now() >= self.__reservoir_expiry + if not reservoir_expired: + sampling_result = self.__reservoir_sampler.should_sample( + parent_context, + trace_id, + name, + kind=kind, + attributes=attributes, + links=links, + trace_state=trace_state, + ) + + if sampling_result.decision is not Decision.DROP: + has_borrowed = self.__borrowing + has_sampled = True + else: + sampling_result = self.__fixed_rate_sampler.should_sample( + parent_context, + trace_id, + name, + kind=kind, + attributes=attributes, + links=links, + trace_state=trace_state, + ) + if sampling_result.decision is not Decision.DROP: + has_sampled = True + + with self.__statistics_lock: + self.__statistics.RequestCount += 1 + self.__statistics.BorrowCount += 1 if has_borrowed else 0 + self.__statistics.SampleCount += 1 if has_sampled else 0 + + return sampling_result + + def get_then_reset_statistics(self): + with self.__statistics_lock: + old_stats = self.__statistics + self.__statistics = _SamplingStatisticsDocument( + self.__client_id, self.sampling_rule.RuleName + ) + + return old_stats.snapshot(self._clock) + + def with_target(self, target: _SamplingTarget) -> "_SamplingRuleApplier": + new_applier = _SamplingRuleApplier( + self.sampling_rule, + self.__client_id, + self._clock, + self.__statistics, + target, + ) + return new_applier + + def matches(self, resource: Resource, attributes: Attributes) -> bool: + url_path: AttributeValue | None = None + url_full: AttributeValue | None = None + http_request_method: AttributeValue | None = None + server_address: AttributeValue | None = None + service_name: AttributeValue | None = None + + if attributes is not None: + # If `URL_PATH/URL_FULL/HTTP_REQUEST_METHOD/SERVER_ADDRESS` are not populated + # also check `HTTP_TARGET/HTTP_URL/HTTP_METHOD/HTTP_HOST` respectively as backup + url_path = attributes.get( + SpanAttributes.URL_PATH, + attributes.get(SpanAttributes.HTTP_TARGET, None), + ) + url_full = attributes.get( + SpanAttributes.URL_FULL, + attributes.get(SpanAttributes.HTTP_URL, None), + ) + http_request_method = attributes.get( + SpanAttributes.HTTP_REQUEST_METHOD, + attributes.get(SpanAttributes.HTTP_METHOD, None), + ) + server_address = attributes.get( + SpanAttributes.SERVER_ADDRESS, + attributes.get(SpanAttributes.HTTP_HOST, None), + ) + + # Resource shouldn't be none as it should default to empty resource + if resource is not None: + service_name = resource.attributes.get( + ResourceAttributes.SERVICE_NAME, "" + ) + + # target may be in url + if url_path is None and isinstance(url_full, str): + scheme_end_index = url_full.find("://") + # For network calls, URL usually has `scheme://host[:port][path][?query][#fragment]` format + # Per spec, url.full is always populated with scheme:// + # If scheme is not present, assume it's bad instrumentation and ignore. + if scheme_end_index > -1: + # urlparse("scheme://netloc/path;parameters?query#fragment") + url_path = urlparse(url_full).path + if url_path == "": + url_path = "/" + elif url_path is None and url_full is None: + # When missing, the URL Path is assumed to be / + url_path = "/" + + return ( + _Matcher.attribute_match(attributes, self.sampling_rule.Attributes) + and _Matcher.wild_card_match(url_path, self.sampling_rule.URLPath) + and _Matcher.wild_card_match( + http_request_method, self.sampling_rule.HTTPMethod + ) + and _Matcher.wild_card_match( + server_address, self.sampling_rule.Host + ) + and _Matcher.wild_card_match( + service_name, self.sampling_rule.ServiceName + ) + and _Matcher.wild_card_match( + self.__get_service_type(resource), + self.sampling_rule.ServiceType, + ) + and _Matcher.wild_card_match( + self.__get_arn(resource, attributes), + self.sampling_rule.ResourceARN, + ) + ) + + def __create_reservoir_sampler(self, quota: int) -> Sampler: + return _RateLimitingSampler(quota, self._clock) + + # pylint: disable=no-self-use + def __get_service_type(self, resource: Resource) -> str: + if resource is None: + return "" + + cloud_platform = resource.attributes.get( + ResourceAttributes.CLOUD_PLATFORM, None + ) + if not isinstance(cloud_platform, str): + return "" + + return cloud_platform_mapping.get(cloud_platform, "") + + def __get_arn( + self, resource: Resource, attributes: Attributes + ) -> AttributeValue: + if resource is not None: + arn = resource.attributes.get( + ResourceAttributes.AWS_ECS_CONTAINER_ARN, None + ) + if arn is not None: + return arn + if ( + resource is not None + and resource.attributes.get(ResourceAttributes.CLOUD_PLATFORM) + == CloudPlatformValues.AWS_LAMBDA.value + ): + return self.__get_lambda_arn(resource, attributes) + return "" + + def __get_lambda_arn( + self, resource: Resource, attributes: Attributes + ) -> AttributeValue: + arn = resource.attributes.get( + ResourceAttributes.CLOUD_RESOURCE_ID, + resource.attributes.get(ResourceAttributes.FAAS_ID, None), + ) + if arn is not None: + return arn + + if attributes is None: + return "" + + # Note from `SpanAttributes.CLOUD_RESOURCE_ID`: + # "On some cloud providers, it may not be possible to determine the full ID at startup, + # so it may be necessary to set cloud.resource_id as a span attribute instead." + arn = attributes.get( + SpanAttributes.CLOUD_RESOURCE_ID, attributes.get("faas.id", None) + ) + if arn is not None: + return arn + + return "" diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/aws_xray_remote_sampler.py b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/aws_xray_remote_sampler.py index 0f34a83be3..002a40e55b 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/aws_xray_remote_sampler.py +++ b/sdk-extension/opentelemetry-sdk-extension-aws/src/opentelemetry/sdk/extension/aws/trace/sampler/aws_xray_remote_sampler.py @@ -20,7 +20,7 @@ import random from logging import getLogger -from threading import Timer +from threading import Lock, Timer from typing import Sequence from typing_extensions import override @@ -31,9 +31,16 @@ DEFAULT_SAMPLING_PROXY_ENDPOINT, _AwsXRaySamplingClient, ) +from opentelemetry.sdk.extension.aws.trace.sampler._clock import _Clock +from opentelemetry.sdk.extension.aws.trace.sampler._fallback_sampler import ( + _FallbackSampler, +) +from opentelemetry.sdk.extension.aws.trace.sampler._rule_cache import ( + DEFAULT_TARGET_POLLING_INTERVAL_SECONDS, + _RuleCache, +) from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace.sampling import ( - Decision, ParentBased, Sampler, SamplingResult, @@ -47,12 +54,9 @@ DEFAULT_RULES_POLLING_INTERVAL_SECONDS = 300 -# WORK IN PROGRESS -# TODO: Rename to AwsXRayRemoteSampler when the implementation is complete and is ready to use -# # Wrapper class to ensure that all XRay Sampler Functionality in _InternalAwsXRayRemoteSampler # uses ParentBased logic to respect the parent span's sampling decision -class _AwsXRayRemoteSampler(Sampler): +class AwsXRayRemoteSampler(Sampler): def __init__( self, resource: Resource, @@ -90,14 +94,11 @@ def should_sample( trace_state=trace_state, ) - # pylint: disable=no-self-use @override def get_description(self) -> str: return f"AwsXRayRemoteSampler{{root:{self._root.get_description()}}}" -# WORK IN PROGRESS -# # _InternalAwsXRayRemoteSampler contains all core XRay Sampler Functionality, # however it is NOT Parent-based (e.g. Sample logic runs for each span) # Not intended for external use, use Parent-based `AwsXRayRemoteSampler` instead. @@ -136,20 +137,36 @@ def __init__( ) polling_interval = DEFAULT_RULES_POLLING_INTERVAL_SECONDS + self.__client_id = self.__generate_client_id() + self._clock = _Clock() self.__xray_client = _AwsXRaySamplingClient( endpoint, log_level=log_level ) + self.__fallback_sampler = _FallbackSampler(self._clock) self.__polling_interval = polling_interval + self.__target_polling_interval = ( + DEFAULT_TARGET_POLLING_INTERVAL_SECONDS + ) self.__rule_polling_jitter = random.uniform(0.0, 5.0) + self.__target_polling_jitter = random.uniform(0.0, 0.1) if resource is not None: - self.__resource = resource # pylint: disable=W0238 + self.__resource = resource else: _logger.warning( "OTel Resource provided is `None`. Defaulting to empty resource" ) - self.__resource = Resource.get_empty() # pylint: disable=W0238 + self.__resource = Resource.get_empty() + + self.__rule_cache_lock = Lock() + self.__rule_cache = _RuleCache( + self.__resource, + self.__fallback_sampler, + self.__client_id, + self._clock, + self.__rule_cache_lock, + ) # Schedule the next rule poll now # Python Timers only run once, so they need to be recreated for every poll @@ -157,9 +174,14 @@ def __init__( self._rules_timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed self._rules_timer.start() - # (TODO) set up the target poller to go off once after the default interval. Subsequent polls may use new intervals. + # set up the target poller to go off once after the default interval. Subsequent polls may use new intervals. + self._targets_timer = Timer( + self.__target_polling_interval + self.__target_polling_jitter, + self.__start_sampling_target_poller, + ) + self._targets_timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed + self._targets_timer.start() - # pylint: disable=no-self-use @override def should_sample( self, @@ -171,9 +193,27 @@ def should_sample( links: Sequence["Link"] | None = None, trace_state: TraceState | None = None, ) -> "SamplingResult": - return SamplingResult( - decision=Decision.DROP, + if self.__rule_cache.expired(): + _logger.debug( + "Rule cache is expired so using fallback sampling strategy" + ) + return self.__fallback_sampler.should_sample( + parent_context, + trace_id, + name, + kind=kind, + attributes=attributes, + links=links, + trace_state=trace_state, + ) + + return self.__rule_cache.should_sample( + parent_context, + trace_id, + name, + kind=kind, attributes=attributes, + links=links, trace_state=trace_state, ) @@ -186,8 +226,8 @@ def get_description(self) -> str: return description def __get_and_update_sampling_rules(self) -> None: - sampling_rules = self.__xray_client.get_sampling_rules() # pylint: disable=W0612 # noqa: F841 - # (TODO) update rules cache with sampling rules + sampling_rules = self.__xray_client.get_sampling_rules() + self.__rule_cache.update_sampling_rules(sampling_rules) def __start_sampling_rule_poller(self) -> None: self.__get_and_update_sampling_rules() @@ -198,3 +238,35 @@ def __start_sampling_rule_poller(self) -> None: ) self._rules_timer.daemon = True self._rules_timer.start() + + def __get_and_update_sampling_targets(self) -> None: + all_statistics = self.__rule_cache.get_all_statistics() + sampling_targets_response = self.__xray_client.get_sampling_targets( + all_statistics + ) + refresh_rules, min_polling_interval = ( + self.__rule_cache.update_sampling_targets( + sampling_targets_response + ) + ) + if refresh_rules: + self.__get_and_update_sampling_rules() + if min_polling_interval is not None: # type: ignore + self.__target_polling_interval = min_polling_interval + + def __start_sampling_target_poller(self) -> None: + self.__get_and_update_sampling_targets() + # Schedule the next sampling targets poll + self._targets_timer = Timer( + self.__target_polling_interval + self.__target_polling_jitter, + self.__start_sampling_target_poller, + ) + self._targets_timer.daemon = True + self._targets_timer.start() + + def __generate_client_id(self) -> str: + hex_chars = "0123456789abcdef" + client_id_array: list[str] = [] + for _ in range(0, 24): + client_id_array.append(random.choice(hex_chars)) + return "".join(client_id_array) diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_aws_xray_remote_sampler.py b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_aws_xray_remote_sampler.py index 2eb236b3cd..e7290fa351 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_aws_xray_remote_sampler.py +++ b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_aws_xray_remote_sampler.py @@ -18,17 +18,25 @@ import json import os +import threading +import time from logging import DEBUG from unittest import TestCase from unittest.mock import patch +from pytest import mark + # pylint: disable=no-name-in-module from opentelemetry.sdk.extension.aws.trace.sampler.aws_xray_remote_sampler import ( - _AwsXRayRemoteSampler, + AwsXRayRemoteSampler, + _InternalAwsXRayRemoteSampler, ) from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import Tracer, TracerProvider from opentelemetry.sdk.trace.sampling import Decision +from ._mock_clock import MockClock + TEST_DIR = os.path.dirname(os.path.realpath(__file__)) DATA_DIR = os.path.join(TEST_DIR, "data") @@ -84,15 +92,10 @@ def tearDown(self): # Clean up timers if self.rs is not None: self.rs._root._root._rules_timer.cancel() + self.rs._root._root._targets_timer.cancel() - @patch( - "opentelemetry.sdk.extension.aws.trace.sampler._aws_xray_sampling_client._AwsXRaySamplingClient.get_sampling_rules", - return_value=None, - ) - def test_create_remote_sampler_with_empty_resource( - self, mocked_get_sampling_rules - ): - self.rs = _AwsXRayRemoteSampler(resource=Resource.get_empty()) + def test_create_remote_sampler_with_empty_resource(self): + self.rs = AwsXRayRemoteSampler(resource=Resource.get_empty()) self.assertIsNotNone(self.rs._root._root._rules_timer) self.assertEqual( self.rs._root._root._InternalAwsXRayRemoteSampler__polling_interval, @@ -104,15 +107,13 @@ def test_create_remote_sampler_with_empty_resource( self.assertIsNotNone( self.rs._root._root._InternalAwsXRayRemoteSampler__resource ) + self.assertTrue( + len(self.rs._root._root._InternalAwsXRayRemoteSampler__client_id), + 24, + ) - @patch( - "opentelemetry.sdk.extension.aws.trace.sampler._aws_xray_sampling_client._AwsXRaySamplingClient.get_sampling_rules", - return_value=None, - ) - def test_create_remote_sampler_with_populated_resource( - self, mocked_get_sampling_rules - ): - self.rs = _AwsXRayRemoteSampler( + def test_create_remote_sampler_with_populated_resource(self): + self.rs = AwsXRayRemoteSampler( resource=Resource.create( { "service.name": "test-service-name", @@ -144,14 +145,8 @@ def test_create_remote_sampler_with_populated_resource( "test-cloud-platform", ) - @patch( - "opentelemetry.sdk.extension.aws.trace.sampler._aws_xray_sampling_client._AwsXRaySamplingClient.get_sampling_rules", - return_value=None, - ) - def test_create_remote_sampler_with_all_fields_populated( - self, mocked_get_sampling_rules - ): - self.rs = _AwsXRayRemoteSampler( + def test_create_remote_sampler_with_all_fields_populated(self): + self.rs = AwsXRayRemoteSampler( resource=Resource.create( { "service.name": "test-service-name", @@ -190,15 +185,295 @@ def test_create_remote_sampler_with_all_fields_populated( "test-cloud-platform", ) + @patch("requests.Session.post", side_effect=mocked_requests_get) + @patch( + "opentelemetry.sdk.extension.aws.trace.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", + 2, + ) + def test_update_sampling_rules_and_targets_with_pollers_and_should_sample( + self, mock_post=None + ): + self.rs = AwsXRayRemoteSampler( + resource=Resource.create( + { + "service.name": "test-service-name", + "cloud.platform": "test-cloud-platform", + } + ) + ) + self.assertEqual( + self.rs._root._root._InternalAwsXRayRemoteSampler__target_polling_interval, + 2, + ) + + time.sleep(1.0) + self.assertEqual( + self.rs._root._root._InternalAwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[ + 0 + ].sampling_rule.RuleName, + "test", + ) + self.assertEqual( + self.rs.should_sample( + None, 0, "name", attributes={"abc": "1234"} + ).decision, + Decision.DROP, + ) + + # wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s) + time.sleep(2.0) + self.assertEqual( + self.rs._root._root._InternalAwsXRayRemoteSampler__target_polling_interval, + 1000, + ) + self.assertEqual( + self.rs.should_sample( + None, 0, "name", attributes={"abc": "1234"} + ).decision, + Decision.RECORD_AND_SAMPLE, + ) + self.assertEqual( + self.rs.should_sample( + None, 0, "name", attributes={"abc": "1234"} + ).decision, + Decision.RECORD_AND_SAMPLE, + ) + self.assertEqual( + self.rs.should_sample( + None, 0, "name", attributes={"abc": "1234"} + ).decision, + Decision.RECORD_AND_SAMPLE, + ) + + @mark.skip( + reason="Uses sleep in test, which could be flaky. Remove this skip for validation locally." + ) + @patch("requests.Session.post", side_effect=mocked_requests_get) + @patch( + "opentelemetry.sdk.extension.aws.trace.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", + 3, + ) + def test_multithreading_with_large_reservoir_with_otel_sdk( + self, mock_post=None + ): + self.rs = AwsXRayRemoteSampler( + resource=Resource.create( + { + "service.name": "test-service-name", + "cloud.platform": "test-cloud-platform", + } + ) + ) + attributes = {"abc": "1234"} + + time.sleep(2.0) + self.assertEqual( + self.rs.should_sample( + None, 0, "name", attributes=attributes + ).decision, + Decision.DROP, + ) + + # wait 3 more seconds since targets polling was patched to 2 seconds (rather than 10s) + time.sleep(3.0) + + number_of_spans = 100 + thread_count = 1000 + sampled_array = [] + threads = [] + + for idx in range(0, thread_count): + sampled_array.append(0) + threads.append( + threading.Thread( + target=create_spans, + name="thread_" + str(idx), + daemon=True, + args=( + sampled_array, + idx, + attributes, + self.rs, + number_of_spans, + ), + ) + ) + threads[idx].start() + sum_sampled = 0 + + for idx in range(0, thread_count): + threads[idx].join() + sum_sampled += sampled_array[idx] + + test_rule_applier = self.rs._root._root._InternalAwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[ + 0 + ] + self.assertEqual( + test_rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, + 100000, + ) + self.assertEqual(sum_sampled, 100000) + + # pylint: disable=no-member + @mark.skip( + reason="Uses sleep in test, which could be flaky. Remove this skip for validation locally." + ) + @patch("requests.Session.post", side_effect=mocked_requests_get) + @patch( + "opentelemetry.sdk.extension.aws.trace.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", + 2, + ) @patch( - "opentelemetry.sdk.extension.aws.trace.sampler._aws_xray_sampling_client._AwsXRaySamplingClient.get_sampling_rules", - return_value=None, + "opentelemetry.sdk.extension.aws.trace.sampler.aws_xray_remote_sampler._Clock", + MockClock, ) - def test_get_description(self, mocked_get_sampling_rules) -> str: - self.rs: _AwsXRayRemoteSampler = _AwsXRayRemoteSampler( + def test_multithreading_with_some_reservoir_with_otel_sdk( + self, mock_post=None + ): + self.rs = AwsXRayRemoteSampler( + resource=Resource.create( + { + "service.name": "test-service-name", + "cloud.platform": "test-cloud-platform", + } + ) + ) + attributes = {"abc": "non-matching attribute value, use default rule"} + + # Using normal clock, finishing all thread jobs will take more than a second, + # which will eat up more than 1 second of reservoir. Using MockClock we can freeze time + # and pretend all thread jobs start and end at the exact same time, + # assume and test exactly 1 second of reservoir (100 quota) only + mock_clock: MockClock = self.rs._root._root._clock + + time.sleep(1.0) + mock_clock.add_time(1.0) + self.assertEqual(mock_clock.now(), self.rs._root._root._clock.now()) + self.assertEqual( + self.rs.should_sample( + None, 0, "name", attributes=attributes + ).decision, + Decision.RECORD_AND_SAMPLE, + ) + + # wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s) + time.sleep(2.0) + mock_clock.add_time(2.0) + self.assertEqual(mock_clock.now(), self.rs._root._root._clock.now()) + + number_of_spans = 100 + thread_count = 1000 + sampled_array = [] + threads = [] + + for idx in range(0, thread_count): + sampled_array.append(0) + threads.append( + threading.Thread( + target=create_spans, + name="thread_" + str(idx), + daemon=True, + args=( + sampled_array, + idx, + attributes, + self.rs, + number_of_spans, + ), + ) + ) + threads[idx].start() + + sum_sampled = 0 + for idx in range(0, thread_count): + threads[idx].join() + sum_sampled += sampled_array[idx] + + default_rule_applier = self.rs._root._root._InternalAwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[ + 1 + ] + self.assertEqual( + default_rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, + 100, + ) + self.assertEqual(sum_sampled, 100) + + def test_get_description(self) -> str: + self.rs: AwsXRayRemoteSampler = AwsXRayRemoteSampler( resource=Resource.create({"service.name": "dummy_name"}) ) self.assertEqual( self.rs.get_description(), "AwsXRayRemoteSampler{root:ParentBased{root:_InternalAwsXRayRemoteSampler{remote sampling with AWS X-Ray},remoteParentSampled:AlwaysOnSampler,remoteParentNotSampled:AlwaysOffSampler,localParentSampled:AlwaysOnSampler,localParentNotSampled:AlwaysOffSampler}}", # noqa: E501 ) + + @patch("requests.Session.post", side_effect=mocked_requests_get) + def test_parent_based_xray_sampler_updates_statistics_once_for_one_parent_span_with_two_children( + self, mock_post=None + ): + self.rs: AwsXRayRemoteSampler = AwsXRayRemoteSampler( + resource=Resource.create( + {"service.name": "use-default-sample-all-rule"} + ) + ) + time.sleep(1.0) + + provider = TracerProvider(sampler=self.rs) + tracer: Tracer = provider.get_tracer("test_tracer_1") + + # child1 and child2 are child spans of root parent0 + # For AwsXRayRemoteSampler (ParentBased), expect only parent0 to update statistics + with tracer.start_as_current_span("parent0") as _: + with tracer.start_as_current_span("child1") as _: + pass + with tracer.start_as_current_span("child2") as _: + pass + default_rule_applier = self.rs._root._root._InternalAwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[ + 1 + ] + self.assertEqual( + default_rule_applier._SamplingRuleApplier__statistics.RequestCount, + 1, + ) + self.assertEqual( + default_rule_applier._SamplingRuleApplier__statistics.SampleCount, + 1, + ) + + @patch("requests.Session.post", side_effect=mocked_requests_get) + def test_non_parent_based_xray_sampler_updates_statistics_thrice_for_one_parent_span_with_two_children( + self, mock_post=None + ): + non_parent_based_xray_sampler: _InternalAwsXRayRemoteSampler = ( + _InternalAwsXRayRemoteSampler( + resource=Resource.create( + {"service.name": "use-default-sample-all-rule"} + ) + ) + ) + time.sleep(1.0) + + provider = TracerProvider(sampler=non_parent_based_xray_sampler) + tracer: Tracer = provider.get_tracer("test_tracer_2") + + # child1 and child2 are child spans of root parent0 + # For _InternalAwsXRayRemoteSampler (Non-ParentBased), expect all 3 spans to update statistics + with tracer.start_as_current_span("parent0") as _: + with tracer.start_as_current_span("child1") as _: + pass + with tracer.start_as_current_span("child2") as _: + pass + default_rule_applier = non_parent_based_xray_sampler._InternalAwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[ + 1 + ] + self.assertEqual( + default_rule_applier._SamplingRuleApplier__statistics.RequestCount, + 3, + ) + self.assertEqual( + default_rule_applier._SamplingRuleApplier__statistics.SampleCount, + 3, + ) + + non_parent_based_xray_sampler._rules_timer.cancel() + non_parent_based_xray_sampler._targets_timer.cancel() diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_fallback_sampler.py b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_fallback_sampler.py new file mode 100644 index 0000000000..de2c1af490 --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_fallback_sampler.py @@ -0,0 +1,129 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import datetime +from unittest import TestCase + +# pylint: disable=no-name-in-module +from opentelemetry.sdk.extension.aws.trace.sampler._fallback_sampler import ( + _FallbackSampler, +) +from opentelemetry.sdk.trace.sampling import ALWAYS_OFF, Decision + +from ._mock_clock import MockClock + + +class TestRateLimitingSampler(TestCase): + # pylint: disable=too-many-branches + def test_should_sample(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + clock = MockClock(time_now) + sampler = _FallbackSampler(clock) + # Ignore testing TraceIdRatioBased + sampler._FallbackSampler__fixed_rate_sampler = ALWAYS_OFF + + sampler.should_sample(None, 1234, "name") + + # Essentially the same tests as test_rate_limiter.py + + # 0 seconds passed, 0 quota available + sampled = 0 + for _ in range(0, 30): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 0) + + # 0.4 seconds passed, 0.4 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 0) + + # 0.8 seconds passed, 0.8 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 0) + + # 1.2 seconds passed, 1 quota consumed, 0 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 1) + + # 1.6 seconds passed, 0.4 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 0) + + # 2.0 seconds passed, 0.8 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 0) + + # 2.4 seconds passed, one more quota consumed, 0 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 1) + + # 30 seconds passed, only one quota can be consumed + sampled = 0 + clock.add_time(100) + for _ in range(0, 30): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 1) diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_matcher.py b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_matcher.py new file mode 100644 index 0000000000..2dc5f4081d --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_matcher.py @@ -0,0 +1,87 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from unittest import TestCase + +# pylint: disable=no-name-in-module +from opentelemetry.sdk.extension.aws.trace.sampler._matcher import _Matcher +from opentelemetry.util.types import Attributes + + +class TestMatcher(TestCase): + def test_wild_card_match(self): + test_cases = [ + [None, "*"], + ["", "*"], + ["HelloWorld", "*"], + ["HelloWorld", "HelloWorld"], + ["HelloWorld", "Hello*"], + ["HelloWorld", "*World"], + ["HelloWorld", "?ello*"], + ["HelloWorld", "Hell?W*d"], + ["Hello.World", "*.World"], + ["Bye.World", "*.World"], + ] + for test_case in test_cases: + self.assertTrue( + _Matcher.wild_card_match( + text=test_case[0], pattern=test_case[1] + ) + ) + + def test_wild_card_not_match(self): + test_cases = [[None, "Hello*"], ["HelloWorld", None]] + for test_case in test_cases: + self.assertFalse( + _Matcher.wild_card_match( + text=test_case[0], pattern=test_case[1] + ) + ) + + def test_attribute_matching(self): + attributes: Attributes = { + "dog": "bark", + "cat": "meow", + "cow": "mooo", + } + rule_attributes = { + "dog": "bar?", + "cow": "mooo", + } + + self.assertTrue(_Matcher.attribute_match(attributes, rule_attributes)) + + def test_attribute_matching_without_rule_attributes(self): + attributes = { + "dog": "bark", + "cat": "meow", + "cow": "mooo", + } + rule_attributes = {} + print("LENGTH %s", len(rule_attributes)) + + self.assertTrue(_Matcher.attribute_match(attributes, rule_attributes)) + + def test_attribute_matching_without_span_attributes(self): + attributes = {} + rule_attributes = { + "dog": "bar?", + "cow": "mooo", + } + + self.assertFalse(_Matcher.attribute_match(attributes, rule_attributes)) diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_rate_limiter.py b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_rate_limiter.py new file mode 100644 index 0000000000..bbb129434c --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_rate_limiter.py @@ -0,0 +1,54 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import datetime +from unittest import TestCase + +# pylint: disable=no-name-in-module +from opentelemetry.sdk.extension.aws.trace.sampler._rate_limiter import ( + _RateLimiter, +) + +from ._mock_clock import MockClock + + +class TestRateLimiter(TestCase): + def test_try_spend(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + clock = MockClock(time_now) + rate_limiter = _RateLimiter(1, 30, clock) + + spent = 0 + for _ in range(0, 100): + if rate_limiter.try_spend(1): + spent += 1 + self.assertEqual(spent, 0) + + spent = 0 + clock.add_time(0.5) + for _ in range(0, 100): + if rate_limiter.try_spend(1): + spent += 1 + self.assertEqual(spent, 15) + + spent = 0 + clock.add_time(1000) + for _ in range(0, 100): + if rate_limiter.try_spend(1): + spent += 1 + self.assertEqual(spent, 30) diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_rate_limiting_sampler.py b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_rate_limiting_sampler.py new file mode 100644 index 0000000000..5a36b66302 --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_rate_limiting_sampler.py @@ -0,0 +1,129 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import datetime +from unittest import TestCase + +# pylint: disable=no-name-in-module +from opentelemetry.sdk.extension.aws.trace.sampler._rate_limiting_sampler import ( + _RateLimitingSampler, +) +from opentelemetry.sdk.trace.sampling import Decision + +from ._mock_clock import MockClock + + +class TestRateLimitingSampler(TestCase): + def test_should_sample(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + clock = MockClock(time_now) + sampler = _RateLimitingSampler(30, clock) + + # Essentially the same tests as test_rate_limiter.py + sampled = 0 + for _ in range(0, 100): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 0) + + sampled = 0 + clock.add_time(0.5) + for _ in range(0, 100): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 15) + + sampled = 0 + clock.add_time(1.0) + for _ in range(0, 100): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 30) + + sampled = 0 + clock.add_time(2.5) + for _ in range(0, 100): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 30) + + sampled = 0 + clock.add_time(1000) + for _ in range(0, 100): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 30) + + def test_should_sample_with_quota_of_one(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + clock = MockClock(time_now) + sampler = _RateLimitingSampler(1, clock) + + sampled = 0 + for _ in range(0, 50): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 0) + + sampled = 0 + clock.add_time(0.5) + for _ in range(0, 50): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 0) + + sampled = 0 + clock.add_time(0.5) + for _ in range(0, 50): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 1) + + sampled = 0 + clock.add_time(1000) + for _ in range(0, 50): + if ( + sampler.should_sample(None, 1234, "name").decision + != Decision.DROP + ): + sampled += 1 + self.assertEqual(sampled, 1) diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_rule_cache.py b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_rule_cache.py new file mode 100644 index 0000000000..3380ff52fa --- /dev/null +++ b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_rule_cache.py @@ -0,0 +1,335 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import datetime +from threading import Lock +from unittest import TestCase + +# pylint: disable=no-name-in-module +from opentelemetry.sdk.extension.aws.trace.sampler._clock import _Clock +from opentelemetry.sdk.extension.aws.trace.sampler._rule_cache import ( + CACHE_TTL_SECONDS, + _RuleCache, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_rule import ( + _SamplingRule, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_rule_applier import ( + _SamplingRuleApplier, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_statistics_document import ( + _SamplingStatisticsDocument, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_target import ( + _SamplingTargetResponse, +) +from opentelemetry.sdk.resources import Resource + +from ._mock_clock import MockClock + +CLIENT_ID = "12345678901234567890abcd" + + +# pylint: disable=no-member disable=C0103 +class TestRuleCache(TestCase): + def test_cache_update_rules_and_sorts_rules(self): + cache = _RuleCache(None, None, CLIENT_ID, _Clock(), Lock()) + self.assertTrue(len(cache._RuleCache__rule_appliers) == 0) + + rule1 = _SamplingRule( + Priority=200, RuleName="only_one_rule", Version=1 + ) + rules = [rule1] + cache.update_sampling_rules(rules) + self.assertTrue(len(cache._RuleCache__rule_appliers) == 1) + + rule1 = _SamplingRule(Priority=200, RuleName="abcdef", Version=1) + rule2 = _SamplingRule(Priority=100, RuleName="abc", Version=1) + rule3 = _SamplingRule(Priority=100, RuleName="Abc", Version=1) + rule4 = _SamplingRule(Priority=100, RuleName="ab", Version=1) + rule5 = _SamplingRule(Priority=100, RuleName="A", Version=1) + rule6 = _SamplingRule(Priority=1, RuleName="abcdef", Version=1) + rules = [rule1, rule2, rule3, rule4, rule5, rule6] + cache.update_sampling_rules(rules) + + self.assertTrue(len(cache._RuleCache__rule_appliers) == 6) + self.assertEqual( + cache._RuleCache__rule_appliers[0].sampling_rule.RuleName, "abcdef" + ) + self.assertEqual( + cache._RuleCache__rule_appliers[1].sampling_rule.RuleName, "A" + ) + self.assertEqual( + cache._RuleCache__rule_appliers[2].sampling_rule.RuleName, "Abc" + ) + self.assertEqual( + cache._RuleCache__rule_appliers[3].sampling_rule.RuleName, "ab" + ) + self.assertEqual( + cache._RuleCache__rule_appliers[4].sampling_rule.RuleName, "abc" + ) + self.assertEqual( + cache._RuleCache__rule_appliers[5].sampling_rule.RuleName, "abcdef" + ) + + def test_rule_cache_expiration_logic(self): + dt = datetime + cache = _RuleCache( + None, Resource.get_empty(), CLIENT_ID, _Clock(), Lock() + ) + self.assertFalse(cache.expired()) + cache._last_modified = dt.datetime.now() - dt.timedelta( + seconds=CACHE_TTL_SECONDS - 5 + ) + self.assertFalse(cache.expired()) + cache._last_modified = dt.datetime.now() - dt.timedelta( + seconds=CACHE_TTL_SECONDS + 1 + ) + self.assertTrue(cache.expired()) + + def test_update_cache_with_only_one_rule_changed(self): + cache = _RuleCache( + None, Resource.get_empty(), CLIENT_ID, _Clock(), Lock() + ) + rule1 = _SamplingRule(Priority=1, RuleName="abcdef", Version=1) + rule2 = _SamplingRule(Priority=10, RuleName="ab", Version=1) + rule3 = _SamplingRule(Priority=100, RuleName="Abc", Version=1) + rules = [rule1, rule2, rule3] + cache.update_sampling_rules(rules) + + cache_rules_copy = cache._RuleCache__rule_appliers + + new_rule3 = _SamplingRule(Priority=5, RuleName="Abc", Version=1) + rules = [rule1, rule2, new_rule3] + cache.update_sampling_rules(rules) + + self.assertTrue(len(cache._RuleCache__rule_appliers) == 3) + self.assertEqual( + cache._RuleCache__rule_appliers[0].sampling_rule.RuleName, "abcdef" + ) + self.assertEqual( + cache._RuleCache__rule_appliers[1].sampling_rule.RuleName, "Abc" + ) + self.assertEqual( + cache._RuleCache__rule_appliers[2].sampling_rule.RuleName, "ab" + ) + + # Compare that only rule1 and rule2 objects have not changed due to new_rule3 even after sorting + self.assertTrue( + cache_rules_copy[0] is cache._RuleCache__rule_appliers[0] + ) + self.assertTrue( + cache_rules_copy[1] is cache._RuleCache__rule_appliers[2] + ) + self.assertTrue( + cache_rules_copy[2] is not cache._RuleCache__rule_appliers[1] + ) + + def test_update_rules_removes_older_rule(self): + cache = _RuleCache(None, None, CLIENT_ID, _Clock(), Lock()) + self.assertTrue(len(cache._RuleCache__rule_appliers) == 0) + + rule1 = _SamplingRule(Priority=200, RuleName="first_rule", Version=1) + rules = [rule1] + cache.update_sampling_rules(rules) + self.assertTrue(len(cache._RuleCache__rule_appliers) == 1) + self.assertEqual( + cache._RuleCache__rule_appliers[0].sampling_rule.RuleName, + "first_rule", + ) + + rule1 = _SamplingRule(Priority=200, RuleName="second_rule", Version=1) + rules = [rule1] + cache.update_sampling_rules(rules) + self.assertTrue(len(cache._RuleCache__rule_appliers) == 1) + self.assertEqual( + cache._RuleCache__rule_appliers[0].sampling_rule.RuleName, + "second_rule", + ) + + def test_update_sampling_targets(self): + sampling_rule_1 = _SamplingRule( + Attributes={}, + FixedRate=0.05, + HTTPMethod="*", + Host="*", + Priority=10000, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/default", + RuleName="default", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + sampling_rule_2 = _SamplingRule( + Attributes={}, + FixedRate=0.20, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=10, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + time_now = datetime.datetime.fromtimestamp(1707551387.0) + mock_clock = MockClock(time_now) + + rule_cache = _RuleCache( + Resource.get_empty(), None, "", mock_clock, Lock() + ) + rule_cache.update_sampling_rules([sampling_rule_1, sampling_rule_2]) + + # quota should be 1 because of borrowing=true until targets are updated + rule_applier_0 = rule_cache._RuleCache__rule_appliers[0] + self.assertEqual( + rule_applier_0._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, + 1, + ) + self.assertEqual( + rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._rate, + sampling_rule_2.FixedRate, + ) + + rule_applier_1 = rule_cache._RuleCache__rule_appliers[1] + self.assertEqual( + rule_applier_1._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, + 1, + ) + self.assertEqual( + rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._rate, + sampling_rule_1.FixedRate, + ) + + target_1 = { + "FixedRate": 0.05, + "Interval": 15, + "ReservoirQuota": 1, + "ReservoirQuotaTTL": mock_clock.now().timestamp() + 10, + "RuleName": "default", + } + target_2 = { + "FixedRate": 0.15, + "Interval": 12, + "ReservoirQuota": 5, + "ReservoirQuotaTTL": mock_clock.now().timestamp() + 10, + "RuleName": "test", + } + target_3 = { + "FixedRate": 0.15, + "Interval": 3, + "ReservoirQuota": 5, + "ReservoirQuotaTTL": mock_clock.now().timestamp() + 10, + "RuleName": "associated rule does not exist", + } + target_response = _SamplingTargetResponse( + mock_clock.now().timestamp() - 10, + [target_1, target_2, target_3], + [], + ) + refresh_rules, min_polling_interval = ( + rule_cache.update_sampling_targets(target_response) + ) + self.assertFalse(refresh_rules) + # target_3 Interval is ignored since it's not associated with a Rule Applier + self.assertEqual(min_polling_interval, target_2["Interval"]) + + # still only 2 rule appliers should exist if for some reason 3 targets are obtained + self.assertEqual(len(rule_cache._RuleCache__rule_appliers), 2) + + # borrowing=false, use quota from targets + rule_applier_0 = rule_cache._RuleCache__rule_appliers[0] + self.assertEqual( + rule_applier_0._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, + target_2["ReservoirQuota"], + ) + self.assertEqual( + rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._rate, + target_2["FixedRate"], + ) + + rule_applier_1 = rule_cache._RuleCache__rule_appliers[1] + self.assertEqual( + rule_applier_1._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, + target_1["ReservoirQuota"], + ) + self.assertEqual( + rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._rate, + target_1["FixedRate"], + ) + + # Test target response modified after Rule cache's last modified date + target_response.LastRuleModification = mock_clock.now().timestamp() + 1 + refresh_rules, _ = rule_cache.update_sampling_targets(target_response) + self.assertTrue(refresh_rules) + + # pylint:disable=C0103 + def test_get_all_statistics(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + mock_clock = MockClock(time_now) + rule_applier_1 = _SamplingRuleApplier( + _SamplingRule(RuleName="test"), CLIENT_ID, mock_clock + ) + rule_applier_2 = _SamplingRuleApplier( + _SamplingRule(RuleName="default"), CLIENT_ID, mock_clock + ) + + rule_applier_1._SamplingRuleApplier__statistics = ( + _SamplingStatisticsDocument(CLIENT_ID, "test", 4, 2, 2) + ) + rule_applier_2._SamplingRuleApplier__statistics = ( + _SamplingStatisticsDocument(CLIENT_ID, "default", 5, 5, 5) + ) + + rule_cache = _RuleCache( + Resource.get_empty(), None, "", mock_clock, Lock() + ) + rule_cache._RuleCache__rule_appliers = [rule_applier_1, rule_applier_2] + + mock_clock.add_time(10) + statistics = rule_cache.get_all_statistics() + + self.assertEqual( + statistics, + [ + { + "ClientID": CLIENT_ID, + "RuleName": "test", + "Timestamp": mock_clock.now().timestamp(), + "RequestCount": 4, + "BorrowCount": 2, + "SampleCount": 2, + }, + { + "ClientID": CLIENT_ID, + "RuleName": "default", + "Timestamp": mock_clock.now().timestamp(), + "RequestCount": 5, + "BorrowCount": 5, + "SampleCount": 5, + }, + ], + ) diff --git a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_sampling_rule_applier.py b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_sampling_rule_applier.py index 1097301b9a..d8f4d31e58 100644 --- a/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_sampling_rule_applier.py +++ b/sdk-extension/opentelemetry-sdk-extension-aws/tests/trace/sampler/test_sampling_rule_applier.py @@ -12,8 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Includes work from: +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import datetime +import json import os from unittest import TestCase +from unittest.mock import patch + +# pylint: disable=no-name-in-module +from opentelemetry.sdk.extension.aws.trace.sampler._clock import _Clock +from opentelemetry.sdk.extension.aws.trace.sampler._rate_limiting_sampler import ( + _RateLimitingSampler, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_rule import ( + _SamplingRule, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_rule_applier import ( + _SamplingRuleApplier, +) +from opentelemetry.sdk.extension.aws.trace.sampler._sampling_target import ( + _SamplingTarget, +) +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.sampling import ( + Decision, + SamplingResult, + TraceIdRatioBased, +) +from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.util.types import Attributes + +from ._mock_clock import MockClock TEST_DIR = os.path.dirname(os.path.realpath(__file__)) DATA_DIR = os.path.join(TEST_DIR, "data") @@ -23,4 +56,527 @@ # pylint: disable=no-member class TestSamplingRuleApplier(TestCase): - pass + def test_applier_attribute_matching_from_xray_response(self): + default_rule = None + with open( + f"{DATA_DIR}/get-sampling-rules-response-sample-2.json", + encoding="UTF-8", + ) as file: + sample_response = json.load(file) + print(sample_response) + all_rules = sample_response["SamplingRuleRecords"] + default_rule = _SamplingRule(**all_rules[0]["SamplingRule"]) + file.close() + + res = Resource.create( + attributes={ + ResourceAttributes.SERVICE_NAME: "test_service_name", + ResourceAttributes.CLOUD_PLATFORM: "test_cloud_platform", + } + ) + attr: Attributes = { + SpanAttributes.URL_PATH: "target", + SpanAttributes.HTTP_REQUEST_METHOD: "method", + SpanAttributes.URL_FULL: "url", + SpanAttributes.SERVER_ADDRESS: "host", + "foo": "bar", + "abc": "1234", + } + + rule_applier = _SamplingRuleApplier(default_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(res, attr)) + + # Test again using deprecated Span Attributes + attr: Attributes = { + SpanAttributes.HTTP_TARGET: "target", + SpanAttributes.HTTP_METHOD: "method", + SpanAttributes.HTTP_URL: "url", + SpanAttributes.HTTP_HOST: "host", + "foo": "bar", + "abc": "1234", + } + self.assertTrue(rule_applier.matches(res, attr)) + + def test_applier_matches_with_all_attributes(self): + sampling_rule = _SamplingRule( + Attributes={"abc": "123", "def": "4?6", "ghi": "*89"}, + FixedRate=0.11, + HTTPMethod="GET", + Host="localhost", + Priority=20, + ReservoirSize=1, + # Note that ResourceARN is usually only able to be "*" + # See: https://docs.aws.amazon.com/xray/latest/devguide/xray-console-sampling.html#xray-console-sampling-options # noqa: E501 + ResourceARN="arn:aws:lambda:us-west-2:123456789012:function:my-function", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="myServiceName", + ServiceType="AWS::Lambda::Function", + URLPath="/helloworld", + Version=1, + ) + + attributes: Attributes = { + "server.address": "localhost", + SpanAttributes.HTTP_REQUEST_METHOD: "GET", + SpanAttributes.CLOUD_RESOURCE_ID: "arn:aws:lambda:us-west-2:123456789012:function:my-function", + "url.full": "http://127.0.0.1:5000/helloworld", + "abc": "123", + "def": "456", + "ghi": "789", + # Test that deprecated attributes are not used in matching when above new attributes are set + "http.host": "deprecated and will not be used in matching", + SpanAttributes.HTTP_METHOD: "deprecated and will not be used in matching", + "faas.id": "deprecated and will not be used in matching", + "http.url": "deprecated and will not be used in matching", + } + + resource_attr = { + ResourceAttributes.SERVICE_NAME: "myServiceName", + ResourceAttributes.CLOUD_PLATFORM: "aws_lambda", # CloudPlatformValues.AWS_LAMBDA.value + } + resource = Resource.create(attributes=resource_attr) + + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(resource, attributes)) + + # Test using deprecated Span Attributes + attributes: Attributes = { + "http.host": "localhost", + SpanAttributes.HTTP_METHOD: "GET", + "faas.id": "arn:aws:lambda:us-west-2:123456789012:function:my-function", + "http.url": "http://127.0.0.1:5000/helloworld", + "abc": "123", + "def": "456", + "ghi": "789", + } + self.assertTrue(rule_applier.matches(resource, attributes)) + + def test_applier_wild_card_attributes_matches_span_attributes(self): + sampling_rule = _SamplingRule( + Attributes={ + "attr1": "*", + "attr2": "*", + "attr3": "HelloWorld", + "attr4": "Hello*", + "attr5": "*World", + "attr6": "?ello*", + "attr7": "Hell?W*d", + "attr8": "*.World", + "attr9": "*.World", + }, + FixedRate=0.11, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + attributes: Attributes = { + "attr1": "", + "attr2": "HelloWorld", + "attr3": "HelloWorld", + "attr4": "HelloWorld", + "attr5": "HelloWorld", + "attr6": "HelloWorld", + "attr7": "HelloWorld", + "attr8": "Hello.World", + "attr9": "Bye.World", + } + + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(Resource.get_empty(), attributes)) + + def test_applier_wild_card_attributes_matches_http_span_attributes(self): + sampling_rule = _SamplingRule( + Attributes={}, + FixedRate=0.11, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + attributes: Attributes = { + SpanAttributes.SERVER_ADDRESS: "localhost", + SpanAttributes.HTTP_REQUEST_METHOD: "GET", + SpanAttributes.URL_FULL: "http://127.0.0.1:5000/helloworld", + } + + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(Resource.get_empty(), attributes)) + + # Test using deprecated Span Attributes + attributes: Attributes = { + SpanAttributes.HTTP_HOST: "localhost", + SpanAttributes.HTTP_METHOD: "GET", + SpanAttributes.HTTP_URL: "http://127.0.0.1:5000/helloworld", + } + + self.assertTrue(rule_applier.matches(Resource.get_empty(), attributes)) + + def test_applier_wild_card_attributes_matches_with_empty_attributes(self): + sampling_rule = _SamplingRule( + Attributes={}, + FixedRate=0.11, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + attributes: Attributes = {} + resource_attr: Resource = { + ResourceAttributes.SERVICE_NAME: "myServiceName", + ResourceAttributes.CLOUD_PLATFORM: "aws_ec2", + } + resource = Resource.create(attributes=resource_attr) + + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(resource, attributes)) + self.assertTrue(rule_applier.matches(resource, None)) + self.assertTrue(rule_applier.matches(Resource.get_empty(), attributes)) + self.assertTrue(rule_applier.matches(Resource.get_empty(), None)) + self.assertTrue(rule_applier.matches(None, attributes)) + self.assertTrue(rule_applier.matches(None, None)) + + def test_applier_does_not_match_without_http_target(self): + sampling_rule = _SamplingRule( + Attributes={}, + FixedRate=0.11, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="/helloworld", + Version=1, + ) + + attributes: Attributes = {} + resource_attr: Resource = { + ResourceAttributes.SERVICE_NAME: "myServiceName", + ResourceAttributes.CLOUD_PLATFORM: "aws_ec2", + } + resource = Resource.create(attributes=resource_attr) + + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertFalse(rule_applier.matches(resource, attributes)) + + def test_applier_matches_with_http_target(self): + sampling_rule = _SamplingRule( + Attributes={}, + FixedRate=0.11, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="/hello*", + Version=1, + ) + + attributes: Attributes = {SpanAttributes.URL_PATH: "/helloworld"} + resource_attr: Resource = { + ResourceAttributes.SERVICE_NAME: "myServiceName", + ResourceAttributes.CLOUD_PLATFORM: "aws_ec2", + } + resource = Resource.create(attributes=resource_attr) + + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(resource, attributes)) + + # Test again using deprecated Span Attributes + attributes: Attributes = {SpanAttributes.HTTP_TARGET: "/helloworld"} + self.assertTrue(rule_applier.matches(resource, attributes)) + + def test_applier_matches_with_span_attributes(self): + sampling_rule = _SamplingRule( + Attributes={"abc": "123", "def": "456", "ghi": "789"}, + FixedRate=0.11, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + attributes: Attributes = { + "server.address": "localhost", + SpanAttributes.HTTP_REQUEST_METHOD: "GET", + "url.full": "http://127.0.0.1:5000/helloworld", + "abc": "123", + "def": "456", + "ghi": "789", + } + + resource_attr: Resource = { + ResourceAttributes.SERVICE_NAME: "myServiceName", + ResourceAttributes.CLOUD_PLATFORM: "aws_eks", + } + resource = Resource.create(attributes=resource_attr) + + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(resource, attributes)) + + # Test again using deprecated Span Attributes + attributes: Attributes = { + "http.host": "localhost", + SpanAttributes.HTTP_METHOD: "GET", + "http.url": "http://127.0.0.1:5000/helloworld", + "abc": "123", + "def": "456", + "ghi": "789", + } + self.assertTrue(rule_applier.matches(resource, attributes)) + + def test_applier_does_not_match_with_less_span_attributes(self): + sampling_rule = _SamplingRule( + Attributes={"abc": "123", "def": "456", "ghi": "789"}, + FixedRate=0.11, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + attributes: Attributes = { + "http.host": "localhost", + SpanAttributes.HTTP_REQUEST_METHOD: "GET", + "url.full": "http://127.0.0.1:5000/helloworld", + "abc": "123", + } + + resource_attr: Resource = { + ResourceAttributes.SERVICE_NAME: "myServiceName", + ResourceAttributes.CLOUD_PLATFORM: "aws_eks", + } + resource = Resource.create(attributes=resource_attr) + + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertFalse(rule_applier.matches(resource, attributes)) + + def test_update_sampling_applier(self): + sampling_rule = _SamplingRule( + Attributes={}, + FixedRate=0.11, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + time_now = datetime.datetime.fromtimestamp(1707551387.0) + mock_clock = MockClock(time_now) + + rule_applier = _SamplingRuleApplier( + sampling_rule, CLIENT_ID, mock_clock + ) + + self.assertEqual( + rule_applier._SamplingRuleApplier__fixed_rate_sampler._rate, 0.11 + ) + self.assertEqual( + rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, + 1, + ) + self.assertEqual( + rule_applier._SamplingRuleApplier__reservoir_expiry, + datetime.datetime.max, + ) + + target = _SamplingTarget( + FixedRate=1.0, + Interval=10, + ReservoirQuota=30, + ReservoirQuotaTTL=1707764006.0, + RuleName="test", + ) + # Update rule applier + rule_applier = rule_applier.with_target(target) + + time_now = datetime.datetime.fromtimestamp(target.ReservoirQuotaTTL) + mock_clock.set_time(time_now) + + self.assertEqual( + rule_applier._SamplingRuleApplier__fixed_rate_sampler._rate, 1.0 + ) + self.assertEqual( + rule_applier._SamplingRuleApplier__reservoir_sampler._RateLimitingSampler__reservoir._quota, + 30, + ) + self.assertEqual( + rule_applier._SamplingRuleApplier__reservoir_expiry, + mock_clock.now(), + ) + + @staticmethod + def fake_reservoir_do_sample(*args, **kwargs): + return SamplingResult( + decision=Decision.RECORD_AND_SAMPLE, + attributes=None, + trace_state=None, + ) + + @staticmethod + def fake_ratio_do_sample(*args, **kwargs): + return SamplingResult( + decision=Decision.RECORD_AND_SAMPLE, + attributes=None, + trace_state=None, + ) + + @staticmethod + def fake_ratio_do_not_sample(*args, **kwargs): + return SamplingResult( + decision=Decision.RECORD_AND_SAMPLE, + attributes=None, + trace_state=None, + ) + + @patch.object(TraceIdRatioBased, "should_sample", fake_ratio_do_sample) + @patch.object( + _RateLimitingSampler, "should_sample", fake_reservoir_do_sample + ) + def test_populate_and_get_then_reset_statistics(self): + mock_clock = MockClock() + rule_applier = _SamplingRuleApplier( + _SamplingRule(RuleName="test", ReservoirSize=10), + CLIENT_ID, + mock_clock, + ) + rule_applier.should_sample(None, 0, "name") + rule_applier.should_sample(None, 0, "name") + rule_applier.should_sample(None, 0, "name") + + statistics = rule_applier.get_then_reset_statistics() + + self.assertEqual(statistics["ClientID"], CLIENT_ID) + self.assertEqual(statistics["RuleName"], "test") + self.assertEqual(statistics["Timestamp"], mock_clock.now().timestamp()) + self.assertEqual(statistics["RequestCount"], 3) + self.assertEqual(statistics["BorrowCount"], 3) + self.assertEqual(statistics["SampleCount"], 3) + self.assertEqual( + rule_applier._SamplingRuleApplier__statistics.RequestCount, 0 + ) + self.assertEqual( + rule_applier._SamplingRuleApplier__statistics.BorrowCount, 0 + ) + self.assertEqual( + rule_applier._SamplingRuleApplier__statistics.SampleCount, 0 + ) + + def test_should_sample_logic_from_reservoir(self): + reservoir_size = 10 + time_now = datetime.datetime.fromtimestamp(1707551387.0) + mock_clock = MockClock(time_now) + rule_applier = _SamplingRuleApplier( + _SamplingRule( + RuleName="test", ReservoirSize=reservoir_size, FixedRate=0.0 + ), + CLIENT_ID, + mock_clock, + ) + + mock_clock.add_time(seconds=2.0) + sampled_count = 0 + for _ in range(0, reservoir_size + 10): + if ( + rule_applier.should_sample(None, 0, "name").decision + != Decision.DROP + ): + sampled_count += 1 + self.assertEqual(sampled_count, 1) + # borrow means only 1 sampled + + target = _SamplingTarget( + FixedRate=0.0, + Interval=10, + ReservoirQuota=10, + ReservoirQuotaTTL=mock_clock.now().timestamp() + 10, + RuleName="test", + ) + rule_applier = rule_applier.with_target(target) + + # Use only 100% of quota (10 out of 10), even if 2 seconds have passed + mock_clock.add_time(seconds=2.0) + sampled_count = 0 + for _ in range(0, reservoir_size + 10): + if ( + rule_applier.should_sample(None, 0, "name").decision + != Decision.DROP + ): + sampled_count += 1 + self.assertEqual(sampled_count, reservoir_size) + + # Use only 50% of quota (5 out of 10) + mock_clock.add_time(seconds=0.5) + sampled_count = 0 + for _ in range(0, reservoir_size + 10): + if ( + rule_applier.should_sample(None, 0, "name").decision + != Decision.DROP + ): + sampled_count += 1 + self.assertEqual(sampled_count, 5) + + # Expired at 10s, do not sample + mock_clock.add_time(seconds=7.5) + sampled_count = 0 + for _ in range(0, reservoir_size + 10): + if ( + rule_applier.should_sample(None, 0, "name").decision + != Decision.DROP + ): + sampled_count += 1 + self.assertEqual(sampled_count, 0)