Skip to content

Commit ec5f440

Browse files
authored
Register atexit callback once per configure to prevent slowdown with each span (#378)
1 parent a5c5ad8 commit ec5f440

File tree

3 files changed

+30
-15
lines changed

3 files changed

+30
-15
lines changed

logfire/_internal/config.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations as _annotations
22

3+
import atexit
34
import dataclasses
45
import functools
56
import json
@@ -13,9 +14,10 @@
1314
from functools import cached_property
1415
from pathlib import Path
1516
from threading import RLock, Thread
16-
from typing import Any, Callable, Literal, Sequence, cast
17+
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence, cast
1718
from urllib.parse import urljoin
1819
from uuid import uuid4
20+
from weakref import WeakSet
1921

2022
import requests
2123
from opentelemetry import metrics, trace
@@ -81,6 +83,12 @@
8183
from .tracer import PendingSpanProcessor, ProxyTracerProvider
8284
from .utils import UnexpectedResponse, ensure_data_dir_exists, get_version, read_toml_file, suppress_instrumentation
8385

86+
if TYPE_CHECKING:
87+
from .main import FastLogfireSpan, LogfireSpan
88+
89+
# NOTE: this WeakSet is the reason that FastLogfireSpan.__slots__ has a __weakref__ slot.
90+
OPEN_SPANS: WeakSet[LogfireSpan | FastLogfireSpan] = WeakSet()
91+
8492
CREDENTIALS_FILENAME = 'logfire_credentials.json'
8593
"""Default base URL for the Logfire API."""
8694
COMMON_REQUEST_HEADERS = {'User-Agent': f'logfire/{VERSION}'}
@@ -767,6 +775,15 @@ def check_token():
767775
trace.set_tracer_provider(self._tracer_provider)
768776
metrics.set_meter_provider(self._meter_provider)
769777

778+
@atexit.register
779+
def _exit_open_spans(): # type: ignore[reportUnusedFunction] # pragma: no cover
780+
# Ensure that all open spans are closed when the program exits.
781+
# OTEL registers its own atexit callback in the tracer/meter providers to shut them down.
782+
# Registering this callback here after the OTEL one means that this runs first.
783+
# Otherwise OTEL would log an error "Already shutdown, dropping span."
784+
for span in list(OPEN_SPANS):
785+
span.__exit__(None, None, None)
786+
770787
self._initialized = True
771788

772789
# set up context propagation for ThreadPoolExecutor and ProcessPoolExecutor

logfire/_internal/main.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from __future__ import annotations
22

3-
import atexit
43
import inspect
54
import sys
65
import traceback
76
import warnings
8-
from functools import cached_property, partial
7+
from functools import cached_property
98
from time import time
109
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Literal, Sequence, TypeVar, Union, cast
1110

@@ -21,7 +20,7 @@
2120
from ..version import VERSION
2221
from . import async_
2322
from .auto_trace import AutoTraceModule, install_auto_tracing
24-
from .config import GLOBAL_CONFIG, LogfireConfig
23+
from .config import GLOBAL_CONFIG, OPEN_SPANS, LogfireConfig
2524
from .constants import (
2625
ATTRIBUTES_JSON_SCHEMA_KEY,
2726
ATTRIBUTES_LOG_LEVEL_NUM_KEY,
@@ -1577,20 +1576,20 @@ def shutdown(self, timeout_millis: int = 30_000, flush: bool = True) -> bool: #
15771576
class FastLogfireSpan:
15781577
"""A simple version of `LogfireSpan` optimized for auto-tracing."""
15791578

1580-
__slots__ = ('_span', '_token', '_atexit')
1579+
# __weakref__ is needed for the OPEN_SPANS WeakSet.
1580+
__slots__ = ('_span', '_token', '__weakref__')
15811581

15821582
def __init__(self, span: trace_api.Span) -> None:
15831583
self._span = span
15841584
self._token = context_api.attach(trace_api.set_span_in_context(self._span))
1585-
self._atexit = partial(self.__exit__, None, None, None)
1586-
atexit.register(self._atexit)
1585+
OPEN_SPANS.add(self)
15871586

15881587
def __enter__(self) -> FastLogfireSpan:
15891588
return self
15901589

15911590
@handle_internal_errors()
15921591
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
1593-
atexit.unregister(self._atexit)
1592+
OPEN_SPANS.remove(self)
15941593
context_api.detach(self._token)
15951594
_exit_span(self._span, exc_value)
15961595
self._span.end()
@@ -1615,7 +1614,6 @@ def __init__(
16151614
self._token: None | object = None
16161615
self._span: None | trace_api.Span = None
16171616
self.end_on_exit = True
1618-
self._atexit: Callable[[], None] | None = None
16191617

16201618
if not TYPE_CHECKING: # pragma: no branch
16211619

@@ -1633,8 +1631,7 @@ def __enter__(self) -> LogfireSpan:
16331631
if self._token is None: # pragma: no branch
16341632
self._token = context_api.attach(trace_api.set_span_in_context(self._span))
16351633

1636-
self._atexit = partial(self.__exit__, None, None, None)
1637-
atexit.register(self._atexit)
1634+
OPEN_SPANS.add(self)
16381635

16391636
return self
16401637

@@ -1643,8 +1640,7 @@ def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseExceptio
16431640
if self._token is None: # pragma: no cover
16441641
return
16451642

1646-
if self._atexit: # pragma: no branch
1647-
atexit.unregister(self._atexit)
1643+
OPEN_SPANS.remove(self)
16481644

16491645
context_api.detach(self._token)
16501646
self._token = None
@@ -1656,8 +1652,6 @@ def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseExceptio
16561652
if end_on_exit_:
16571653
self.end()
16581654

1659-
self._token = None
1660-
16611655
@property
16621656
def message_template(self) -> str | None: # pragma: no cover
16631657
return self._get_attribute(ATTRIBUTES_MESSAGE_TEMPLATE_KEY, None)

tests/otel_integrations/test_celery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import sys
23
from typing import Generator, Iterator
34

45
import pytest
@@ -13,6 +14,9 @@
1314
import logfire
1415
from logfire.testing import TestExporter
1516

17+
# TODO find a better solution
18+
pytestmark = pytest.mark.skipif(sys.version_info < (3, 9), reason='Redis testcontainers has problems in 3.8')
19+
1620

1721
@pytest.fixture(scope='module', autouse=True)
1822
def redis_container() -> Generator[RedisContainer, None, None]:

0 commit comments

Comments
 (0)