diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1f252e8290..651adc9487 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+### Added
+
+- `opentelemetry-instrumentation-celery` Add three additional worker metrics to count active and prefetched tasks, as well as prefetch duration
+  ([#3463](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3463))
+
+### Fixed
+
+- `opentelemetry-instrumentation-celery` Fix a memory leak where a reference to a task identifier is kept indefinitely
+  ([#3463](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3463))
+
+### Breaking changes
+
+- `opentelemetry-instrumentation-celery` Rename `flower.task.runtime.seconds` metric to `messaging.process.duration` according to semconv
+  ([#3463](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3463))
+
 ## Version 1.36.0/0.57b0 (2025-07-29)
 
 ### Fixed
diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py
index 908f158507..85effe242d 100644
--- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py
+++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py
@@ -76,6 +76,7 @@ def add(x, y):
 from opentelemetry.metrics import get_meter
 from opentelemetry.propagate import extract, inject
 from opentelemetry.propagators.textmap import Getter
+from opentelemetry.semconv._incubating.metrics import messaging_metrics
 from opentelemetry.semconv.trace import SpanAttributes
 from opentelemetry.trace.status import Status, StatusCode
 
@@ -96,6 +97,12 @@ def add(x, y):
 _TASK_REVOKED_TERMINATED_SIGNAL_KEY = "celery.terminated.signal"
 _TASK_NAME_KEY = "celery.task_name"
 
+# Metric names
+_TASK_COUNT_ACTIVE = "messaging.client.active_tasks"
+_TASK_COUNT_PREFETCHED = "messaging.client.prefetched_tasks"
+_TASK_PROCESSING_TIME = messaging_metrics.MESSAGING_PROCESS_DURATION
+_TASK_PREFETCH_TIME = "messaging.prefetch.duration"
+
 
 class CeleryGetter(Getter):
     def get(self, carrier, key):
@@ -113,10 +120,36 @@ def keys(self, carrier):
 
 celery_getter = CeleryGetter()
 
 
-class CeleryInstrumentor(BaseInstrumentor):
-    metrics = None
-    task_id_to_start_time = {}
+class TaskDurationTracker:
+    def __init__(self, metrics):
+        self.metrics = metrics
+        self.tracker = {}
+
+    def record_start(self, key, step):
+        self.tracker.setdefault(key, {})[step] = default_timer()
+
+    def record_finish(self, key, metric_name, attributes):
+        try:
+            time_elapsed = self._time_elapsed(key, metric_name)
+            self.metrics[metric_name].record(
+                max(0, time_elapsed), attributes=attributes
+            )
+        except KeyError:
+            logger.warning("Failed to record %s for task %s", metric_name, key)
+
+    def _time_elapsed(self, key, step):
+        end_time = default_timer()
+        try:
+            start_time = self.tracker.get(key, {}).pop(step)
+            time_elapsed = end_time - start_time
+            return time_elapsed
+        finally:
+            # Cleanup operation
+            if key in self.tracker and not self.tracker.get(key):
+                self.tracker.pop(key)
+
 
+class CeleryInstrumentor(BaseInstrumentor):
     def instrumentation_dependencies(self) -> Collection[str]:
         return _instruments
 
@@ -139,8 +172,10 @@ def _instrument(self, **kwargs):
             schema_url="https://opentelemetry.io/schemas/1.11.0",
         )
 
-        self.create_celery_metrics(meter)
-
+        self.metrics = _create_celery_worker_metrics(meter)
+        self.time_tracker = TaskDurationTracker(self.metrics)
+
+        signals.task_received.connect(self._trace_received, weak=False)
         signals.task_prerun.connect(self._trace_prerun, weak=False)
         signals.task_postrun.connect(self._trace_postrun, weak=False)
         signals.before_task_publish.connect(
@@ -153,6 +188,7 @@ def _instrument(self, **kwargs):
         signals.task_retry.connect(self._trace_retry, weak=False)
 
     def _uninstrument(self, **kwargs):
+        signals.task_received.disconnect(self._trace_received)
         signals.task_prerun.disconnect(self._trace_prerun)
         signals.task_postrun.disconnect(self._trace_postrun)
         signals.before_task_publish.disconnect(self._trace_before_publish)
@@ -160,20 +196,44 @@ def _uninstrument(self, **kwargs):
         signals.task_failure.disconnect(self._trace_failure)
         signals.task_retry.disconnect(self._trace_retry)
 
+    def _trace_received(self, *args, **kwargs):
+        """
+        On receive signal, task is prefetched and prefetch timer starts
+        """
+
+        request = utils.retrieve_request(kwargs)
+
+        metrics_attributes = utils.get_metrics_attributes_from_request(request)
+        self.metrics[_TASK_COUNT_PREFETCHED].add(
+            1, attributes=metrics_attributes
+        )
+        self.time_tracker.record_start(request.task_id, _TASK_PREFETCH_TIME)
+
     def _trace_prerun(self, *args, **kwargs):
+        """
+        On prerun signal, task is no longer prefetched, and execution timer
+        starts along with the task span
+        """
+
         task = utils.retrieve_task(kwargs)
         task_id = utils.retrieve_task_id(kwargs)
 
         if task is None or task_id is None:
             return
 
-        self.update_task_duration_time(task_id)
+        metrics_attributes = utils.get_metrics_attributes_from_task(task)
+        self.metrics[_TASK_COUNT_PREFETCHED].add(
+            -1, attributes=metrics_attributes
+        )
+        self.time_tracker.record_finish(
+            task_id, _TASK_PREFETCH_TIME, metrics_attributes
+        )
+        self.time_tracker.record_start(task_id, _TASK_PROCESSING_TIME)
+
         request = task.request
         tracectx = extract(request, getter=celery_getter) or None
         token = context_api.attach(tracectx) if tracectx is not None else None
 
-        logger.debug("prerun signal start task_id=%s", task_id)
-
         operation_name = f"{_TASK_RUN}/{task.name}"
         span = self._tracer.start_span(
             operation_name, context=tracectx, kind=trace.SpanKind.CONSUMER
@@ -183,14 +243,24 @@ def _trace_prerun(self, *args, **kwargs):
         activation.__enter__()  # pylint: disable=E1101
         utils.attach_context(task, task_id, span, activation, token)
 
+        self.metrics[_TASK_COUNT_ACTIVE].add(1, attributes=metrics_attributes)
+
     def _trace_postrun(self, *args, **kwargs):
+        """
+        On postrun signal, task is no longer being executed
+        """
+
         task = utils.retrieve_task(kwargs)
         task_id = utils.retrieve_task_id(kwargs)
 
         if task is None or task_id is None:
             return
 
-        logger.debug("postrun signal task_id=%s", task_id)
+        metrics_attributes = utils.get_metrics_attributes_from_task(task)
+        self.metrics[_TASK_COUNT_ACTIVE].add(-1, attributes=metrics_attributes)
+        self.time_tracker.record_finish(
+            task_id, _TASK_PROCESSING_TIME, metrics_attributes
+        )
 
         # retrieve and finish the Span
         ctx = utils.retrieve_context(task, task_id)
@@ -210,10 +280,8 @@ def _trace_postrun(self, *args, **kwargs):
         activation.__exit__(None, None, None)
         utils.detach_context(task, task_id)
 
-        self.update_task_duration_time(task_id)
-        labels = {"task": task.name, "worker": task.request.hostname}
-        self._record_histograms(task_id, labels)
-        # if the process sending the task is not instrumented
+
+        # If the process sending the task is not instrumented,
         # there's no incoming context and no token to detach
         if token is not None:
             context_api.detach(token)
@@ -345,29 +413,29 @@ def _trace_retry(*args, **kwargs):
             # something that isn't an `Exception`
             span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason))
 
-    def update_task_duration_time(self, task_id):
-        cur_time = default_timer()
-        task_duration_time_until_now = (
-            cur_time - self.task_id_to_start_time[task_id]
-            if task_id in self.task_id_to_start_time
-            else cur_time
-        )
-        self.task_id_to_start_time[task_id] = task_duration_time_until_now
-
-    def _record_histograms(self, task_id, metric_attributes):
-        if task_id is None:
-            return
-        self.metrics["flower.task.runtime.seconds"].record(
-            self.task_id_to_start_time.get(task_id),
-            attributes=metric_attributes,
-        )
-
-    def create_celery_metrics(self, meter) -> None:
-        self.metrics = {
-            "flower.task.runtime.seconds": meter.create_histogram(
-                name="flower.task.runtime.seconds",
-                unit="seconds",
-                description="The time it took to run the task.",
-            )
-        }
+
+def _create_celery_worker_metrics(meter) -> dict:
+    metrics = {
+        _TASK_COUNT_ACTIVE: meter.create_up_down_counter(
+            name=_TASK_COUNT_ACTIVE,
+            unit="{message}",
+            description="Number of tasks currently being executed by the worker",
+        ),
+        _TASK_COUNT_PREFETCHED: meter.create_up_down_counter(
+            name=_TASK_COUNT_PREFETCHED,
+            unit="{message}",
+            description="Number of tasks prefetched by the worker",
+        ),
+        _TASK_PREFETCH_TIME: meter.create_histogram(
+            name=_TASK_PREFETCH_TIME,
+            unit="s",
+            description="The time the task spent in prefetch mode",
+        ),
+        _TASK_PROCESSING_TIME: meter.create_histogram(
+            name=_TASK_PROCESSING_TIME,
+            unit="s",
+            description="The time it took to run the task.",
+        ),
+    }
+
+    return metrics
diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py
index d7ca77af8a..04e0611f2c 100644
--- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py
+++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/utils.py
@@ -20,6 +20,10 @@
 from celery import registry  # pylint: disable=no-name-in-module
 from celery.app.task import Task
 
+from opentelemetry.semconv._incubating.attributes.messaging_attributes import (
+    MESSAGING_CLIENT_ID,
+    MESSAGING_OPERATION_NAME,
+)
 from opentelemetry.semconv.trace import SpanAttributes
 from opentelemetry.trace import Span
 
@@ -217,6 +221,14 @@ def retrieve_task_id(kwargs):
     return task_id
 
 
+def retrieve_request(kwargs):
+    request = kwargs.get("request")
+    if request is None:
+        logger.debug("Unable to retrieve the request from signal arguments")
+
+    return request
+
+
 def retrieve_task_id_from_request(kwargs):
     # retry signal does not include task_id as argument so use request argument
     request = kwargs.get("request")
@@ -250,3 +262,17 @@ def retrieve_reason(kwargs):
     if not reason:
         logger.debug("Unable to retrieve the retry reason")
     return reason
+
+
+def get_metrics_attributes_from_request(request):
+    return {
+        MESSAGING_OPERATION_NAME: request.task.name,
+        MESSAGING_CLIENT_ID: request.hostname,
+    }
+
+
+def get_metrics_attributes_from_task(task):
+    return {
+        MESSAGING_OPERATION_NAME: task.name,
+        MESSAGING_CLIENT_ID: task.request.hostname,
+    }
diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/celery_test_tasks.py b/instrumentation/opentelemetry-instrumentation-celery/tests/celery_test_tasks.py
index af88f1d4c3..e6580639f9 100644
--- a/instrumentation/opentelemetry-instrumentation-celery/tests/celery_test_tasks.py
+++ b/instrumentation/opentelemetry-instrumentation-celery/tests/celery_test_tasks.py
@@ -12,14 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import time
+
 from celery import Celery
 
 from opentelemetry import baggage
 
 
 class Config:
-    result_backend = "rpc"
-    broker_backend = "memory"
+    result_backend = "rpc://"
+    without_gossip = True
+    without_heartbeat = True
+    without_mingle = True
 
 
 app = Celery(broker="memory:///")
@@ -31,8 +35,14 @@ class CustomError(Exception):
 
 
 @app.task
-def task_add(num_a, num_b):
-    return num_a + num_b
+def task_add(x=1, y=2):
+    return x + y
+
+
+@app.task
+def task_sleep(sleep_time):
+    time.sleep(sleep_time)
+    return 1
 
 
 @app.task
diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/test_duplicate.py b/instrumentation/opentelemetry-instrumentation-celery/tests/test_duplicate.py
index ab1f7804cf..3c2cef9d2b 100644
--- a/instrumentation/opentelemetry-instrumentation-celery/tests/test_duplicate.py
+++ b/instrumentation/opentelemetry-instrumentation-celery/tests/test_duplicate.py
@@ -26,5 +26,3 @@ def test_duplicate_instrumentaion(self):
         CeleryInstrumentor().uninstrument()
         self.assertIsNotNone(first.metrics)
         self.assertIsNotNone(second.metrics)
-        self.assertEqual(first.task_id_to_start_time, {})
-        self.assertEqual(second.task_id_to_start_time, {})
diff --git a/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics.py b/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics.py
index f83759317b..a0a862d89b 100644
--- a/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics.py
+++ b/instrumentation/opentelemetry-instrumentation-celery/tests/test_metrics.py
@@ -1,22 +1,52 @@
+import logging
 import threading
 import time
-from platform import python_implementation
-from timeit import default_timer
 
-from pytest import mark
+import pytest
+from celery.signals import task_postrun, task_prerun, task_received
 
-from opentelemetry.instrumentation.celery import CeleryInstrumentor
+from opentelemetry.instrumentation.celery import (
+    _TASK_COUNT_ACTIVE,
+    _TASK_COUNT_PREFETCHED,
+    _TASK_PREFETCH_TIME,
+    _TASK_PROCESSING_TIME,
+    CeleryInstrumentor,
+    TaskDurationTracker,
+    _create_celery_worker_metrics,
+)
+from opentelemetry.metrics import get_meter
+from opentelemetry.semconv._incubating.attributes.messaging_attributes import (
+    MESSAGING_CLIENT_ID,
+    MESSAGING_OPERATION_NAME,
+)
 from opentelemetry.test.test_base import TestBase
 
-from .celery_test_tasks import app, task_add
+from .celery_test_tasks import (
+    app,
+    task_add,
+    task_raises,
+    task_sleep,
+)
 
+EXPECTED_METRICS = 4
+
+
+class TestCeleryMetrics(TestBase):
+    WORKER_NODE = "celery@hostname"
 
-class TestMetrics(TestBase):
     def setUp(self):
         super().setUp()
         self._worker = app.Worker(
-            app=app, pool="solo", concurrency=1, hostname="celery@akochavi"
+            app=app,
+            pool="threads",
+            concurrency=1,
+            hostname=self.WORKER_NODE,
+            loglevel="INFO",
+            without_mingle=True,
+            without_heartbeat=True,
+            without_gossip=True,
         )
+
         self._thread = threading.Thread(target=self._worker.start)
         self._thread.daemon = True
         self._thread.start()
@@ -26,88 +56,360 @@ def tearDown(self):
         self._worker.stop()
         self._thread.join()
 
-    def get_metrics(self):
-        result = task_add.delay(1, 2)
+    def wait_for_tasks_to_finish(self, timeout=30):
+        """Blocks until all tasks in Celery worker are finished"""
 
-        timeout = time.time() + 60 * 1  # 1 minutes from now
-        while not result.ready():
-            if time.time() > timeout:
+        inspect = app.control.inspect()
+        start_time = time.time()
+        while True:
+            counter = 0
+            for state in (
+                inspect.scheduled(),
+                inspect.active(),
+                inspect.reserved(),
+            ):
+                if state is not None:
+                    counter += len(state[self.WORKER_NODE])
+
+            if counter == 0:
                 break
+
+            time.sleep(0.5)
+            if time.time() - start_time > timeout:
+                raise TimeoutError(
+                    "Timeout while waiting for tasks to finish."
+                )
+
+    def wait_for_metrics_until_finished(self, task_fn, *args):
+        """
+        Create a task, wait for it to finish and return metrics
+
+        This ensures that all metrics have been initialized.
+        """
+
+        result = task_fn.delay(*args)
+        give_up_time = time.time() + 30
+        while not result.ready():
+            if time.time() > give_up_time:
+                raise TimeoutError("Timeout while waiting for task to finish.")
+            time.sleep(0.05)
         return self.get_sorted_metrics()
 
-    def test_basic_metric(self):
+    def test_counters_are_correct(self):
+        """
+        Test that prefetch and execution task counters are counting correctly
+        """
+
+        CeleryInstrumentor().instrument()
+
+        task_add.delay()
+        self.wait_for_tasks_to_finish()
+
+        task_sleep.delay(2)
+        task_sleep.delay(2)
+        task_sleep.delay(2)
+
+        time.sleep(1)
+
+        exported_metrics = self.get_sorted_metrics()
+        self.assertEqual(len(exported_metrics), EXPECTED_METRICS)
+
+        # Worker is single-threaded, so we expect 1 task_sleep to be running and
+        # 2 task_sleep to be waiting in queue
+        for metric in exported_metrics:
+            data_point = metric.data.data_points[0]
+
+            if not data_point.attributes.get(
+                MESSAGING_OPERATION_NAME
+            ).endswith("task_sleep"):
+                continue
+
+            if metric.name == _TASK_COUNT_ACTIVE:
+                self.assertEqual(data_point.value, 1)
+
+            if metric.name == _TASK_COUNT_PREFETCHED:
+                self.assertEqual(data_point.value, 2)
+
+        self.memory_exporter.clear()
+        CeleryInstrumentor().uninstrument()
+
+    def test_counters_with_task_errors(self):
+        """
+        Test that counters work correctly even if the task raises errors
+        """
+
+        CeleryInstrumentor().instrument()
+
+        task_raises.delay()
+        self.wait_for_tasks_to_finish()
+
+        task_sleep.delay(2)
+        task_sleep.delay(2)
+
+        time.sleep(1)
+
+        exported_metrics = self.get_sorted_metrics()
+        self.assertEqual(len(exported_metrics), EXPECTED_METRICS)
+
+        for metric in exported_metrics:
+            data_point = metric.data.data_points[0]
+
+            if not data_point.attributes.get(
+                MESSAGING_OPERATION_NAME
+            ).endswith("task_sleep"):
+                continue
+
+            if metric.name == _TASK_COUNT_ACTIVE:
+                self.assertEqual(data_point.value, 1)
+
+            if metric.name == _TASK_COUNT_PREFETCHED:
+                self.assertEqual(data_point.value, 1)
+
+        # After processing is finished, all counters should be at 0
+        self.wait_for_tasks_to_finish()
+        exported_metrics = self.get_sorted_metrics()
+
+        for metric in exported_metrics:
+            data_point = metric.data.data_points[0]
+
+            if not data_point.attributes.get(
+                MESSAGING_OPERATION_NAME
+            ).endswith("task_sleep"):
+                continue
+
+            if metric.name == _TASK_COUNT_ACTIVE:
+                self.assertEqual(data_point.value, 0)
+
+            if metric.name == _TASK_COUNT_PREFETCHED:
+                self.assertEqual(data_point.value, 0)
+
+        self.memory_exporter.clear()
+        CeleryInstrumentor().uninstrument()
+
+    def test_counters_with_revoked_task(self):
         CeleryInstrumentor().instrument()
-        start_time = default_timer()
 
-        task_runtime_estimated = (default_timer() - start_time) * 1000
-        metrics = self.get_metrics()
+        self.wait_for_metrics_until_finished(task_add)
+
+        task_sleep.delay(2)
+        task2 = task_sleep.delay(2)
+        task2.revoke()
+
+        self.wait_for_tasks_to_finish()
+
+        exported_metrics = self.get_sorted_metrics()
+        self.assertEqual(len(exported_metrics), EXPECTED_METRICS)
+
+        for metric in exported_metrics:
+            data_point = metric.data.data_points[0]
+
+            if not data_point.attributes.get(
+                MESSAGING_OPERATION_NAME
+            ).endswith("task_sleep"):
+                continue
+
+            if metric.name == _TASK_COUNT_ACTIVE:
+                self.assertEqual(data_point.value, 0)
+
+            if metric.name == _TASK_COUNT_PREFETCHED:
+                self.assertEqual(data_point.value, 0)
+
+        self.memory_exporter.clear()
         CeleryInstrumentor().uninstrument()
 
-        self.assertEqual(len(metrics), 1)
-        task_runtime = metrics[0]
-        print(task_runtime)
-        self.assertEqual(task_runtime.name, "flower.task.runtime.seconds")
+    def test_prefetch_duration_metric(self):
+        CeleryInstrumentor().instrument()
+
+        expected_prefetch_seconds = 1
+
+        self.wait_for_metrics_until_finished(task_add)
+        task_sleep.delay(1)
+        task_sleep.delay(1)
+
+        self.wait_for_tasks_to_finish()
+
+        exported_metrics = self.get_sorted_metrics()
+        self.assertEqual(len(exported_metrics), EXPECTED_METRICS)
+
+        task_runtime = [
+            x for x in exported_metrics if x.name == _TASK_PREFETCH_TIME
+        ][0]
+        self.assertEqual(task_runtime.name, _TASK_PREFETCH_TIME)
         self.assert_metric_expected(
             task_runtime,
             [
                 self.create_histogram_data_point(
                     count=1,
-                    sum_data_point=task_runtime_estimated,
-                    max_data_point=task_runtime_estimated,
-                    min_data_point=task_runtime_estimated,
+                    sum_data_point=0,
+                    max_data_point=0,
+                    min_data_point=0,
                     attributes={
-                        "task": "tests.celery_test_tasks.task_add",
-                        "worker": "celery@akochavi",
+                        MESSAGING_OPERATION_NAME: "tests.celery_test_tasks.task_add",
+                        MESSAGING_CLIENT_ID: self.WORKER_NODE,
                     },
-                )
+                ),
+                self.create_histogram_data_point(
+                    count=2,
+                    sum_data_point=expected_prefetch_seconds,
+                    max_data_point=expected_prefetch_seconds,
+                    min_data_point=0,  # First sleep task did not have to wait
+                    attributes={
+                        MESSAGING_OPERATION_NAME: "tests.celery_test_tasks.task_sleep",
+                        MESSAGING_CLIENT_ID: self.WORKER_NODE,
+                    },
+                ),
             ],
-            est_value_delta=200,
+            est_value_delta=0.05,
         )
 
-    @mark.skipif(
-        python_implementation() == "PyPy", reason="Fails randomly in pypy"
-    )
-    def test_metric_uninstrument(self):
+        self.memory_exporter.clear()
+        CeleryInstrumentor().uninstrument()
+
+    def test_execution_duration_metric(self):
         CeleryInstrumentor().instrument()
-        self.get_metrics()
-        self.assertEqual(
-            (
-                self.memory_metrics_reader.get_metrics_data()
-                .resource_metrics[0]
-                .scope_metrics[0]
-                .metrics[0]
-                .data.data_points[0]
-                .bucket_counts[1]
-            ),
-            1,
-        )
+        expected_runtime_seconds = 2
+
+        self.wait_for_metrics_until_finished(task_add)
+        task_sleep.delay(2)
+
+        self.wait_for_tasks_to_finish()
+
+        exported_metrics = self.get_sorted_metrics()
+        self.assertEqual(len(exported_metrics), EXPECTED_METRICS)
 
-        self.get_metrics()
-        self.assertEqual(
-            (
-                self.memory_metrics_reader.get_metrics_data()
-                .resource_metrics[0]
-                .scope_metrics[0]
-                .metrics[0]
-                .data.data_points[0]
-                .bucket_counts[1]
-            ),
-            2,
+        task_runtime = [
+            x for x in exported_metrics if x.name == _TASK_PROCESSING_TIME
+        ][0]
+        self.assert_metric_expected(
+            task_runtime,
+            [
+                self.create_histogram_data_point(
+                    count=1,
+                    sum_data_point=0,
+                    max_data_point=0,
+                    min_data_point=0,
+                    attributes={
+                        MESSAGING_OPERATION_NAME: "tests.celery_test_tasks.task_add",
+                        MESSAGING_CLIENT_ID: self.WORKER_NODE,
+                    },
+                ),
+                self.create_histogram_data_point(
+                    count=1,
+                    sum_data_point=expected_runtime_seconds,
+                    max_data_point=expected_runtime_seconds,
+                    min_data_point=expected_runtime_seconds,
+                    attributes={
MESSAGING_OPERATION_NAME: "tests.celery_test_tasks.task_sleep", + MESSAGING_CLIENT_ID: self.WORKER_NODE, + }, + ), + ], + est_value_delta=0.05, ) + self.memory_exporter.clear() CeleryInstrumentor().uninstrument() - self.get_metrics() - self.assertEqual( - ( - self.memory_metrics_reader.get_metrics_data() - .resource_metrics[0] - .scope_metrics[0] - .metrics[0] - .data.data_points[0] - .bucket_counts[1] - ), - 2, - ) + def test_exported_metrics(self): + """ + Test that number of exported metrics and metrics attributes are as + expected + """ + + CeleryInstrumentor().instrument() + + expected_attributes = { + MESSAGING_OPERATION_NAME: "tests.celery_test_tasks.task_add", + MESSAGING_CLIENT_ID: self.WORKER_NODE, + } + + self.wait_for_metrics_until_finished(task_add) + + metrics = self.get_sorted_metrics() + + self.assertEqual(len(metrics), EXPECTED_METRICS) + for metric in metrics: + for data_point in metric.data.data_points: + self.assertEqual( + expected_attributes, + (dict(data_point.attributes)), + ) + + self.memory_exporter.clear() + CeleryInstrumentor().uninstrument() + + def test_uninstrument_metrics(self): + """ + Even after memory exporter gets cleared, it is still returning metrics, + so this just checks that subscribers are disconnected from Celery + events. + """ + + CeleryInstrumentor().instrument() + self.wait_for_metrics_until_finished(task_add) + + self.assertEqual(len(task_prerun.receivers), 1) + self.assertEqual(len(task_received.receivers), 1) + self.assertEqual(len(task_postrun.receivers), 1) + + self.memory_exporter.clear() + CeleryInstrumentor().uninstrument() + + time.sleep(1) + task_add.delay() + time.sleep(1) + + self.assertEqual(len(task_prerun.receivers), 0) + self.assertEqual(len(task_received.receivers), 0) + self.assertEqual(len(task_postrun.receivers), 0) + + def test_no_memory_leak_because_of_time_tracking(self): + """ + To test that time tracking helper class does not keep references to a + finished task indefinitely + """ + + celery_instrumentor = CeleryInstrumentor() + celery_instrumentor.instrument() + + for _ in range(5): + task_add.delay() + + self.wait_for_tasks_to_finish() + + exported_metrics = self.get_sorted_metrics() + self.assertEqual(len(exported_metrics), EXPECTED_METRICS) + + self.assertEqual(len(celery_instrumentor.time_tracker.tracker), 0) + + self.memory_exporter.clear() + celery_instrumentor.uninstrument() + + +class TestDurationTracker(TestBase): + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog): + self.caplog = caplog # pylint: disable=attribute-defined-outside-init + + def test_duration_tracker(self): + metrics = _create_celery_worker_metrics(get_meter(self.meter_provider)) + sample_hist_metric = _TASK_PROCESSING_TIME + + tracker = TaskDurationTracker(metrics) + tracker.record_start("task-id-123", sample_hist_metric) + + # Robustness to undefined keys + with self.caplog.at_level(logging.WARNING): + tracker.record_finish("task-id-456", sample_hist_metric, {}) + self.assertIn("Failed to record", self.caplog.text) + + with self.caplog.at_level(logging.WARNING): + tracker.record_finish("task-id-123", "non_existent_metric", {}) + self.assertIn("Failed to record", self.caplog.text) + + tracker.record_finish("task-id-123", sample_hist_metric, {}) + + exported_metrics = self.get_sorted_metrics() + self.assertEqual(exported_metrics[0].data.data_points[0].count, 1)