diff --git a/CHANGELOG.md b/CHANGELOG.md index bac47c6bc01..08010131788 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +- Allow passing a custom random number generator to `trace.RandomIdGenerator`. + ([#4571](https://github.com/open-telemetry/opentelemetry-python/pull/4571)) - typecheck: add sdk/resources and drop mypy ([#4578](https://github.com/open-telemetry/opentelemetry-python/pull/4578)) - Refactor `BatchLogRecordProcessor` to simplify code and make the control flow more diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py index cd1f89bcde2..93974595734 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/id_generator.py @@ -14,6 +14,8 @@ import abc import random +from random import Random +from typing import Optional from opentelemetry import trace @@ -45,16 +47,29 @@ def generate_trace_id(self) -> int: class RandomIdGenerator(IdGenerator): """The default ID generator for TracerProvider which randomly generates all bits when generating IDs. + + Args: + rng: A random number generator. Defaults to the global random instance. + It is recommended to use a fresh `Random()` instead of the default + to avoid potential conflicts with the global random instance + (duplicate ids across multiple processes when a constant global + random seed is set). In case of a custom implementation, it should + be uniform, as some samplers rely on this randomness to make sampling decisions. """ + def __init__(self, rng: Optional[Random] = None) -> None: + if rng is None: + rng = random # Just a hack to preserve backward compatibility, otherwise it does not quite match the type hint. + self._rng = rng + def generate_span_id(self) -> int: - span_id = random.getrandbits(64) + span_id = self._rng.getrandbits(64) while span_id == trace.INVALID_SPAN_ID: - span_id = random.getrandbits(64) + span_id = self._rng.getrandbits(64) return span_id def generate_trace_id(self) -> int: - trace_id = random.getrandbits(128) + trace_id = self._rng.getrandbits(128) while trace_id == trace.INVALID_TRACE_ID: - trace_id = random.getrandbits(128) + trace_id = self._rng.getrandbits(128) return trace_id diff --git a/opentelemetry-sdk/tests/trace/test_trace.py b/opentelemetry-sdk/tests/trace/test_trace.py index 7b23c11fa1f..188fa3aa386 100644 --- a/opentelemetry-sdk/tests/trace/test_trace.py +++ b/opentelemetry-sdk/tests/trace/test_trace.py @@ -20,7 +20,7 @@ import unittest from importlib import reload from logging import ERROR, WARNING -from random import randint +from random import randint, Random from time import time_ns from typing import Optional from unittest import mock @@ -2168,29 +2168,29 @@ class TestRandomIdGenerator(unittest.TestCase): _TRACE_ID_MAX_VALUE = 2**128 - 1 _SPAN_ID_MAX_VALUE = 2**64 - 1 - @patch( - "random.getrandbits", - side_effect=[trace_api.INVALID_SPAN_ID, 0x00000000DEADBEF0], - ) - def test_generate_span_id_avoids_invalid(self, mock_getrandbits): - generator = RandomIdGenerator() - span_id = generator.generate_span_id() - - self.assertNotEqual(span_id, trace_api.INVALID_SPAN_ID) - mock_getrandbits.assert_any_call(64) - self.assertEqual(mock_getrandbits.call_count, 2) - - @patch( - "random.getrandbits", - side_effect=[ - trace_api.INVALID_TRACE_ID, - 0x000000000000000000000000DEADBEEF, - ], - ) - def test_generate_trace_id_avoids_invalid(self, mock_getrandbits): - generator = RandomIdGenerator() - trace_id = generator.generate_trace_id() + def setUp(self): + self.generators = { + 'rng=None': RandomIdGenerator(), + 'rng=Random()': RandomIdGenerator(rng=Random()), + 'rng=Random(42)': RandomIdGenerator(rng=Random(x=42)) + } - self.assertNotEqual(trace_id, trace_api.INVALID_TRACE_ID) - mock_getrandbits.assert_any_call(128) - self.assertEqual(mock_getrandbits.call_count, 2) + def test_generate_span_id_avoids_invalid(self): + for msg, generator in self.generators.items(): + with self.subTest(msg=msg), \ + patch.object(generator._rng, "getrandbits", side_effect=[trace_api.INVALID_SPAN_ID, 0x00000000DEADBEF0]) as mock_getrandbits: + span_id = generator.generate_span_id() + + self.assertNotEqual(span_id, trace_api.INVALID_SPAN_ID) + mock_getrandbits.assert_any_call(64) + self.assertEqual(mock_getrandbits.call_count, 2) + + def test_generate_trace_id_avoids_invalid(self): + for name, generator in self.generators.items(): + with self.subTest(msg=name), \ + patch.object(generator._rng, "getrandbits", side_effect=[trace_api.INVALID_TRACE_ID, 0x000000000000000000000000DEADBEEF]) as mock_getrandbits: + trace_id = generator.generate_trace_id() + + self.assertNotEqual(trace_id, trace_api.INVALID_TRACE_ID) + mock_getrandbits.assert_any_call(128) + self.assertEqual(mock_getrandbits.call_count, 2)