Skip to content

Commit 232ca3e

Browse files
authored
XRay Sampler Enhancements - Use Request 'session', 'with' locks, add jitter for initial Targets Poll (#68)
*Issue #, if available:* N/A *Description of changes:* - Use request Session object to be reused for getting sampling rules/targets - Use `with` locks to be consistent with lock usage - Add jitter for initial Targets Poll By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent aba9ce3 commit 232ca3e

File tree

5 files changed

+34
-36
lines changed

5 files changed

+34
-36
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ def __init__(self, endpoint: str = None, log_level: str = None):
2222
self.__get_sampling_rules_endpoint = endpoint + "/GetSamplingRules"
2323
self.__get_sampling_targets_endpoint = endpoint + "/SamplingTargets"
2424

25+
self.__session = requests.Session()
26+
2527
def get_sampling_rules(self) -> [_SamplingRule]:
2628
sampling_rules = []
2729
headers = {"content-type": "application/json"}
2830

2931
try:
30-
xray_response = requests.post(url=self.__get_sampling_rules_endpoint, headers=headers, timeout=20)
32+
xray_response = self.__session.post(url=self.__get_sampling_rules_endpoint, headers=headers, timeout=20)
3133
if xray_response is None:
3234
_logger.error("GetSamplingRules response is None")
3335
return []
@@ -60,7 +62,7 @@ def get_sampling_targets(self, statistics: [dict]) -> _SamplingTargetResponse:
6062
)
6163
headers = {"content-type": "application/json"}
6264
try:
63-
xray_response = requests.post(
65+
xray_response = self.__session.post(
6466
url=self.__get_sampling_targets_endpoint,
6567
headers=headers,
6668
timeout=20,

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule_cache.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -75,25 +75,22 @@ def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None:
7575
continue
7676
temp_rule_appliers.append(_SamplingRuleApplier(sampling_rule, self.__client_id, self._clock))
7777

78-
self.__cache_lock.acquire()
79-
80-
# map list of rule appliers by each applier's sampling_rule name
81-
rule_applier_map: Dict[str, _SamplingRuleApplier] = {
82-
applier.sampling_rule.RuleName: applier for applier in self.__rule_appliers
83-
}
84-
85-
# If a sampling rule has not changed, keep its respective applier in the cache.
86-
new_applier: _SamplingRuleApplier
87-
for index, new_applier in enumerate(temp_rule_appliers):
88-
rule_name_to_check = new_applier.sampling_rule.RuleName
89-
if rule_name_to_check in rule_applier_map:
90-
old_applier = rule_applier_map[rule_name_to_check]
91-
if new_applier.sampling_rule == old_applier.sampling_rule:
92-
temp_rule_appliers[index] = old_applier
93-
self.__rule_appliers = temp_rule_appliers
94-
self._last_modified = self._clock.now()
95-
96-
self.__cache_lock.release()
78+
with self.__cache_lock:
79+
# map list of rule appliers by each applier's sampling_rule name
80+
rule_applier_map: Dict[str, _SamplingRuleApplier] = {
81+
applier.sampling_rule.RuleName: applier for applier in self.__rule_appliers
82+
}
83+
84+
# If a sampling rule has not changed, keep its respective applier in the cache.
85+
new_applier: _SamplingRuleApplier
86+
for index, new_applier in enumerate(temp_rule_appliers):
87+
rule_name_to_check = new_applier.sampling_rule.RuleName
88+
if rule_name_to_check in rule_applier_map:
89+
old_applier = rule_applier_map[rule_name_to_check]
90+
if new_applier.sampling_rule == old_applier.sampling_rule:
91+
temp_rule_appliers[index] = old_applier
92+
self.__rule_appliers = temp_rule_appliers
93+
self._last_modified = self._clock.now()
9794

9895
def update_sampling_targets(self, sampling_targets_response: _SamplingTargetResponse) -> (bool, int):
9996
targets: [_SamplingTarget] = sampling_targets_response.SamplingTargetDocuments
@@ -135,8 +132,5 @@ def get_all_statistics(self) -> [dict]:
135132
return all_statistics
136133

137134
def expired(self) -> bool:
138-
self.__cache_lock.acquire()
139-
try:
135+
with self.__cache_lock:
140136
return self._clock.now() > self._last_modified + self._clock.time_delta(seconds=CACHE_TTL_SECONDS)
141-
finally:
142-
self.__cache_lock.release()

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def __init__(
8383
self._rules_timer.start()
8484

8585
# set up the target poller to go off once after the default interval. Subsequent polls may use new intervals.
86-
self._targets_timer = Timer(DEFAULT_TARGET_POLLING_INTERVAL_SECONDS, self.__start_sampling_target_poller)
86+
self._targets_timer = Timer(
87+
self.__target_polling_interval + self.__target_polling_jitter, self.__start_sampling_target_poller
88+
)
8789
self._targets_timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed
8890
self._targets_timer.start()
8991

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test_create_remote_sampler_with_all_fields_populated(self):
8686
self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name")
8787
self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform")
8888

89-
@patch("requests.post", side_effect=mocked_requests_get)
89+
@patch("requests.Session.post", side_effect=mocked_requests_get)
9090
@patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", 2)
9191
def test_update_sampling_rules_and_targets_with_pollers_and_should_sample(self, mock_post=None):
9292
rs = AwsXRayRemoteSampler(
@@ -113,7 +113,7 @@ def test_update_sampling_rules_and_targets_with_pollers_and_should_sample(self,
113113
rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision, Decision.RECORD_AND_SAMPLE
114114
)
115115

116-
@patch("requests.post", side_effect=mocked_requests_get)
116+
@patch("requests.Session.post", side_effect=mocked_requests_get)
117117
@patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", 3)
118118
def test_multithreading_with_large_reservoir_with_otel_sdk(self, mock_post=None):
119119
rs = AwsXRayRemoteSampler(
@@ -157,7 +157,7 @@ def test_multithreading_with_large_reservoir_with_otel_sdk(self, mock_post=None)
157157
self.assertEqual(sum_sampled, 100000)
158158

159159
# pylint: disable=no-member
160-
@patch("requests.post", side_effect=mocked_requests_get)
160+
@patch("requests.Session.post", side_effect=mocked_requests_get)
161161
@patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", 2)
162162
@patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler._Clock", MockClock)
163163
def test_multithreading_with_some_reservoir_with_otel_sdk(self, mock_post=None):

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,30 @@
1616

1717

1818
class TestAwsXRaySamplingClient(TestCase):
19-
@patch("requests.post")
19+
@patch("requests.Session.post")
2020
def test_get_no_sampling_rules(self, mock_post=None):
2121
mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": []}})
2222
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
2323
sampling_rules = client.get_sampling_rules()
2424
self.assertTrue(len(sampling_rules) == 0)
2525

26-
@patch("requests.post")
26+
@patch("requests.Session.post")
2727
def test_get_invalid_responses(self, mock_post=None):
2828
mock_post.return_value.configure_mock(**{"json.return_value": {}})
2929
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
3030
with self.assertLogs(_logger, level="ERROR"):
3131
sampling_rules = client.get_sampling_rules()
3232
self.assertTrue(len(sampling_rules) == 0)
3333

34-
@patch("requests.post")
34+
@patch("requests.Session.post")
3535
def test_get_sampling_rule_missing_in_records(self, mock_post=None):
3636
mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{}]}})
3737
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
3838
with self.assertLogs(_logger, level="ERROR"):
3939
sampling_rules = client.get_sampling_rules()
4040
self.assertTrue(len(sampling_rules) == 0)
4141

42-
@patch("requests.post")
42+
@patch("requests.Session.post")
4343
def test_default_values_used_when_missing_properties_in_sampling_rule(self, mock_post=None):
4444
mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{"SamplingRule": {}}]}})
4545
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
@@ -61,7 +61,7 @@ def test_default_values_used_when_missing_properties_in_sampling_rule(self, mock
6161
self.assertEqual(sampling_rule.URLPath, "")
6262
self.assertEqual(sampling_rule.Version, 0)
6363

64-
@patch("requests.post")
64+
@patch("requests.Session.post")
6565
def test_get_correct_number_of_sampling_rules(self, mock_post=None):
6666
sampling_records = []
6767
with open(f"{DATA_DIR}/get-sampling-rules-response-sample.json", encoding="UTF-8") as file:
@@ -104,7 +104,7 @@ def validate_match_sampling_rules_properties_with_records(self, sampling_rules,
104104
self.assertIsNotNone(sampling_rule.Version)
105105
self.assertEqual(sampling_rule.Version, sampling_record["SamplingRule"]["Version"])
106106

107-
@patch("requests.post")
107+
@patch("requests.Session.post")
108108
def test_get_sampling_targets(self, mock_post=None):
109109
with open(f"{DATA_DIR}/get-sampling-targets-response-sample.json", encoding="UTF-8") as file:
110110
sample_response = json.load(file)
@@ -116,7 +116,7 @@ def test_get_sampling_targets(self, mock_post=None):
116116
self.assertEqual(len(sampling_targets_response.UnprocessedStatistics), 0)
117117
self.assertEqual(sampling_targets_response.LastRuleModification, 1707551387.0)
118118

119-
@patch("requests.post")
119+
@patch("requests.Session.post")
120120
def test_get_invalid_sampling_targets(self, mock_post=None):
121121
mock_post.return_value.configure_mock(
122122
**{

0 commit comments

Comments
 (0)