diff --git a/docs/guides/onboarding-checklist/add-manual-tracing.md b/docs/guides/onboarding-checklist/add-manual-tracing.md index 8a0b62c8d..f69192dcc 100644 --- a/docs/guides/onboarding-checklist/add-manual-tracing.md +++ b/docs/guides/onboarding-checklist/add-manual-tracing.md @@ -234,6 +234,15 @@ my_function(3, 4) # Logs: Applying my_function to x=3 and y=4 ``` +Access the created span inside the function using `logfire.current_span()`: + +```python +@logfire.instrument('Processing') +def process_data(user_dict: dict): + user_id = user_dict.get("id") + logfire.current_span().message = f'Processing User: {user_id}' +``` + !!! note - The [`@logfire.instrument`][logfire.Logfire.instrument] decorator MUST be applied first, i.e., UNDER any other decorators. diff --git a/logfire-api/logfire_api/__init__.py b/logfire-api/logfire_api/__init__.py index 4509704d2..62c90af62 100644 --- a/logfire-api/logfire_api/__init__.py +++ b/logfire-api/logfire_api/__init__.py @@ -96,6 +96,9 @@ def __init__(self, *args, **kwargs) -> None: ... def span(self, *args, **kwargs) -> LogfireSpan: return LogfireSpan() + def current_span(self) -> LogfireSpan: + return LogfireSpan() + def log(self, *args, **kwargs) -> None: ... def trace(self, *args, **kwargs) -> None: ... @@ -201,6 +204,7 @@ def shutdown(self, *args, **kwargs) -> None: ... DEFAULT_LOGFIRE_INSTANCE = Logfire() span = DEFAULT_LOGFIRE_INSTANCE.span + current_span = DEFAULT_LOGFIRE_INSTANCE.current_span log = DEFAULT_LOGFIRE_INSTANCE.log trace = DEFAULT_LOGFIRE_INSTANCE.trace debug = DEFAULT_LOGFIRE_INSTANCE.debug diff --git a/logfire-api/logfire_api/__init__.pyi b/logfire-api/logfire_api/__init__.pyi index fe9eba578..31bc99e8c 100644 --- a/logfire-api/logfire_api/__init__.pyi +++ b/logfire-api/logfire_api/__init__.pyi @@ -15,10 +15,11 @@ from logfire.propagate import attach_context as attach_context, get_context as g from logfire.sampling import SamplingOptions as SamplingOptions from typing import Any -__all__ = ['Logfire', 'LogfireSpan', 'LevelName', 'AdvancedOptions', 'ConsoleOptions', 'CodeSource', 'PydanticPlugin', 'configure', 'span', 'instrument', 'log', 'trace', 'debug', 'notice', 'info', 'warn', 'warning', 'error', 'exception', 'fatal', 'force_flush', 'log_slow_async_callbacks', 'install_auto_tracing', 'instrument_asgi', 'instrument_wsgi', 'instrument_pydantic', 'instrument_pydantic_ai', 'instrument_fastapi', 'instrument_openai', 'instrument_openai_agents', 'instrument_anthropic', 'instrument_google_genai', 'instrument_litellm', 'instrument_print', 'instrument_asyncpg', 'instrument_httpx', 'instrument_celery', 'instrument_requests', 'instrument_psycopg', 'instrument_django', 'instrument_flask', 'instrument_starlette', 'instrument_aiohttp_client', 'instrument_aiohttp_server', 'instrument_sqlalchemy', 'instrument_sqlite3', 'instrument_aws_lambda', 'instrument_redis', 'instrument_pymongo', 'instrument_mysql', 'instrument_system_metrics', 'instrument_mcp', 'AutoTraceModule', 'with_tags', 'with_settings', 'suppress_scopes', 'shutdown', 'no_auto_trace', 'ScrubMatch', 'ScrubbingOptions', 'VERSION', 'add_non_user_code_prefix', 'suppress_instrumentation', 'StructlogProcessor', 'LogfireLoggingHandler', 'loguru_handler', 'SamplingOptions', 'MetricsOptions', 'logfire_info', 'get_baggage', 'set_baggage', 'get_context', 'attach_context'] +__all__ = ['Logfire', 'LogfireSpan', 'LevelName', 'AdvancedOptions', 'ConsoleOptions', 'CodeSource', 'PydanticPlugin', 'configure', 'span', 'instrument', 'current_span', 'log', 'trace', 'debug', 'notice', 'info', 'warn', 'warning', 'error', 'exception', 'fatal', 'force_flush', 'log_slow_async_callbacks', 'install_auto_tracing', 'instrument_asgi', 'instrument_wsgi', 'instrument_pydantic', 'instrument_pydantic_ai', 'instrument_fastapi', 'instrument_openai', 'instrument_openai_agents', 'instrument_anthropic', 'instrument_google_genai', 'instrument_litellm', 'instrument_print', 'instrument_asyncpg', 'instrument_httpx', 'instrument_celery', 'instrument_requests', 'instrument_psycopg', 'instrument_django', 'instrument_flask', 'instrument_starlette', 'instrument_aiohttp_client', 'instrument_aiohttp_server', 'instrument_sqlalchemy', 'instrument_sqlite3', 'instrument_aws_lambda', 'instrument_redis', 'instrument_pymongo', 'instrument_mysql', 'instrument_system_metrics', 'instrument_mcp', 'AutoTraceModule', 'with_tags', 'with_settings', 'suppress_scopes', 'shutdown', 'no_auto_trace', 'ScrubMatch', 'ScrubbingOptions', 'VERSION', 'add_non_user_code_prefix', 'suppress_instrumentation', 'StructlogProcessor', 'LogfireLoggingHandler', 'loguru_handler', 'SamplingOptions', 'MetricsOptions', 'logfire_info', 'get_baggage', 'set_baggage', 'get_context', 'attach_context'] DEFAULT_LOGFIRE_INSTANCE = Logfire() span = DEFAULT_LOGFIRE_INSTANCE.span +current_span = DEFAULT_LOGFIRE_INSTANCE.current_span instrument = DEFAULT_LOGFIRE_INSTANCE.instrument force_flush = DEFAULT_LOGFIRE_INSTANCE.force_flush log_slow_async_callbacks = DEFAULT_LOGFIRE_INSTANCE.log_slow_async_callbacks diff --git a/logfire-api/logfire_api/_internal/main.pyi b/logfire-api/logfire_api/_internal/main.pyi index 5cbe73a12..6204f7449 100644 --- a/logfire-api/logfire_api/_internal/main.pyi +++ b/logfire-api/logfire_api/_internal/main.pyi @@ -237,6 +237,10 @@ class Logfire: attributes: The arguments to include in the span and format the message template with. Attributes starting with an underscore are not allowed. """ + def current_span(self) -> LogfireSpan: + """Get the current span. + Useful for accessing spans created by @logfire.instrument. + """ @overload def instrument(self, msg_template: LiteralString | None = None, *, span_name: str | None = None, extract_args: bool | Iterable[str] = True, record_return: bool = False, allow_generator: bool = False) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator for instrumenting a function as a span. diff --git a/logfire/__init__.py b/logfire/__init__.py index 2badb3b4d..6c7c7fc06 100644 --- a/logfire/__init__.py +++ b/logfire/__init__.py @@ -24,6 +24,7 @@ DEFAULT_LOGFIRE_INSTANCE: Logfire = Logfire() span = DEFAULT_LOGFIRE_INSTANCE.span instrument = DEFAULT_LOGFIRE_INSTANCE.instrument +current_span = DEFAULT_LOGFIRE_INSTANCE.current_span force_flush = DEFAULT_LOGFIRE_INSTANCE.force_flush log_slow_async_callbacks = DEFAULT_LOGFIRE_INSTANCE.log_slow_async_callbacks install_auto_tracing = DEFAULT_LOGFIRE_INSTANCE.install_auto_tracing @@ -108,6 +109,7 @@ def loguru_handler() -> Any: 'configure', 'span', 'instrument', + 'current_span', 'log', 'trace', 'debug', diff --git a/logfire/_internal/ast_utils.py b/logfire/_internal/ast_utils.py index 22ed9b6c2..c37608f48 100644 --- a/logfire/_internal/ast_utils.py +++ b/logfire/_internal/ast_utils.py @@ -4,6 +4,7 @@ import functools import inspect import sys +import textwrap import types import warnings from abc import ABC, abstractmethod @@ -245,6 +246,25 @@ class InspectArgumentsFailedWarning(Warning): pass +@functools.lru_cache(maxsize=1024) +def has_current_span_call(func: Any) -> bool: + """Check if a function contains calls to logfire.current_span().""" + try: + tree = ast.parse(textwrap.dedent(inspect.getsource(func))) + for node in ast.walk(tree): + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.attr == 'current_span' + and node.func.value.id == 'logfire' + ): + return True + return False + except (OSError, SyntaxError, TypeError, AttributeError): + return False + + @functools.lru_cache(maxsize=1024) def get_node_source_text(node: ast.AST, ex_source: executing.Source): """Returns some Python source code representing `node`. diff --git a/logfire/_internal/instrument.py b/logfire/_internal/instrument.py index 3d2b3fd98..253b9aa65 100644 --- a/logfire/_internal/instrument.py +++ b/logfire/_internal/instrument.py @@ -11,6 +11,7 @@ from opentelemetry.util import types as otel_types from typing_extensions import LiteralString, ParamSpec +from .ast_utils import has_current_span_call from .constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY, ATTRIBUTES_TAGS_KEY from .stack_info import get_filepath_attribute from .utils import safe_repr, uniquify_sequence @@ -61,7 +62,8 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]: ) attributes = get_attributes(func, msg_template, tags) - open_span = get_open_span(logfire, attributes, span_name, extract_args, func) + uses_current_span = has_current_span_call(func) + open_span = get_open_span(logfire, attributes, span_name, extract_args, uses_current_span, func) if inspect.isgeneratorfunction(func): if not allow_generator: @@ -90,6 +92,9 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ignore with open_span(*func_args, **func_kwargs) as span: + token = None + if uses_current_span: + token = logfire._current_span_var.set(span) # type: ignore[protected-access] result = await func(*func_args, **func_kwargs) if record_return: # open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types. @@ -97,14 +102,21 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ig # Not sure if making get_open_span return a LogfireSpan when record_return is True # would be faster overall or if it would be worth the added complexity. set_user_attributes_on_raw_span(span._span, {'return': result}) + if token: + logfire._current_span_var.reset(token) # type: ignore[protected-access] return result else: # Same as the above, but without the async/await def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: with open_span(*func_args, **func_kwargs) as span: + token = None + if uses_current_span: + token = logfire._current_span_var.set(span) # type: ignore[protected-access] result = func(*func_args, **func_kwargs) if record_return: set_user_attributes_on_raw_span(span._span, {'return': result}) + if token: + logfire._current_span_var.reset(token) # type: ignore[protected-access] return result wrapper = functools.wraps(func)(wrapper) # type: ignore @@ -118,12 +130,15 @@ def get_open_span( attributes: dict[str, otel_types.AttributeValue], span_name: str | None, extract_args: bool | Iterable[str], + uses_current_span: bool, func: Callable[P, R], ) -> Callable[P, AbstractContextManager[Any]]: final_span_name: str = span_name or attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore # This is the fast case for when there are no arguments to extract def open_span(*_: P.args, **__: P.kwargs): # type: ignore + if uses_current_span: + return logfire._span(final_span_name, attributes) # type: ignore[protected-access] return logfire._fast_span(final_span_name, attributes) # type: ignore if extract_args is True: @@ -134,6 +149,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): bound = sig.bind(*func_args, **func_kwargs) bound.apply_defaults() args_dict = bound.arguments + if uses_current_span: + return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore[protected-access] + return logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) @@ -165,6 +183,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): # This line is the only difference from the extract_args=True case args_dict = {k: args_dict[k] for k in extract_args_final} + if uses_current_span: + return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore[protected-access] + return logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 120499f2f..559cbbcdf 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -7,7 +7,7 @@ import warnings from collections.abc import Iterable, Sequence from contextlib import AbstractContextManager -from contextvars import Token +from contextvars import ContextVar, Token from enum import Enum from functools import cached_property from time import time @@ -138,6 +138,7 @@ def __init__( self._sample_rate = sample_rate self._console_log = console_log self._otel_scope = otel_scope + self._current_span_var = ContextVar('logfire_current_span', default=NoopSpan()) @property def config(self) -> LogfireConfig: @@ -171,6 +172,9 @@ def _get_tracer(self, *, is_span_tracer: bool) -> Tracer: # pragma: no cover is_span_tracer=is_span_tracer, ) + def current_span(self) -> LogfireSpan: + return self._current_span_var.get() # type: ignore[return-value] + # If any changes are made to this method, they may need to be reflected in `_fast_span` as well. def _span( self, diff --git a/tests/test_logfire.py b/tests/test_logfire.py index a8647b5e3..3ba5cce03 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -38,7 +38,7 @@ LevelName, ) from logfire._internal.formatter import FormattingFailedWarning -from logfire._internal.main import NoopSpan +from logfire._internal.main import LogfireSpan, NoopSpan from logfire._internal.tracer import record_exception from logfire._internal.utils import SeededRandomIdGenerator, is_instrumentation_suppressed from logfire.integrations.logging import LogfireLoggingHandler @@ -1344,6 +1344,111 @@ def run(a: str) -> None: ) +def test_current_span_default_is_noop(): + span = logfire.DEFAULT_LOGFIRE_INSTANCE.current_span() + assert isinstance(span, NoopSpan) + + +def test_current_span_nested(exporter: TestExporter): + @logfire.instrument('outer') + def outer(): + s = logfire.current_span() + s.message = 'Starting outer operation' + inner() + assert logfire.current_span() == s + s = logfire.current_span() + s.message = 'Completing outer operation' + + @logfire.instrument('inner') + def inner(): + s = logfire.current_span() + s.message = 'Processing inner operation' + + outer() + assert isinstance(logfire.current_span(), NoopSpan) + + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + + assert len(spans) == 2 + assert spans[0]['name'] == 'inner' + assert spans[0]['attributes']['logfire.msg'] == 'Processing inner operation' + assert spans[1]['name'] == 'outer' + assert spans[1]['attributes']['logfire.msg'] == 'Completing outer operation' + + +@pytest.mark.anyio +async def test_current_span_async(exporter: TestExporter): + print('Testing current_span_async - functions should auto-detect current_span usage') + + @logfire.instrument('async outer') + async def outer(): + s = logfire.current_span() + assert isinstance(s, LogfireSpan) + s.message = 'Starting async outer operation' + await inner() + s = logfire.current_span() + assert isinstance(s, LogfireSpan) + s.message = 'Completing async outer operation' + + @logfire.instrument('async inner') + async def inner(): + s = logfire.current_span() + assert isinstance(s, LogfireSpan) + s.message = 'Processing async inner operation' + + await outer() + assert isinstance(logfire.current_span(), NoopSpan) + + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + + assert len(spans) == 2 + assert spans[0]['name'] == 'async inner' + assert spans[0]['attributes']['logfire.msg'] == 'Processing async inner operation' + assert spans[1]['name'] == 'async outer' + assert spans[1]['attributes']['logfire.msg'] == 'Completing async outer operation' + + +def test_fast_span_when_current_span_not_called(exporter: TestExporter): + """Test that when current_span() is not called, the exported span is a fast span""" + + @logfire.instrument + def fast_operation(): ... + + fast_operation() + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + + # pending_span would be a LogfireSpan, just span is FastLogfireSpan + assert spans[0]['attributes']['logfire.span_type'] == 'span' + + +def test_instrument_extract_args_true_with_current_span(exporter: TestExporter): + """Test instrument with extract_args=True and current_span() usage.""" + + @logfire.instrument(extract_args=True, record_return=True) + def foo(a: int, b: str) -> None: + logfire.current_span() + + foo(42, 'test') + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + assert len(spans) == 1 + assert spans[0]['attributes']['a'] == 42 + assert spans[0]['attributes']['b'] == 'test' + + +def test_instrument_extract_args_list_with_current_span(exporter: TestExporter): + """Test instrument with extract_args list and current_span() usage.""" + + @logfire.instrument(extract_args=['a', 'c']) + def foo(a: int, b: str, c: float) -> None: + logfire.current_span() + + foo(42, 'test', 3.14) + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + assert len(spans) == 1 + assert spans[0]['attributes']['a'] == 42 + assert spans[0]['attributes']['c'] == 3.14 + + @dataclass class Foo: x: int diff --git a/tests/test_logfire_api.py b/tests/test_logfire_api.py index fd07900f8..7920db480 100644 --- a/tests/test_logfire_api.py +++ b/tests/test_logfire_api.py @@ -273,6 +273,10 @@ def func() -> None: ... pass logfire__all__.remove('attach_context') + assert hasattr(logfire_api, 'current_span') + logfire_api.current_span() + logfire__all__.remove('current_span') + # If it's not empty, it means that some of the __all__ members are not tested. assert logfire__all__ == set(), logfire__all__