Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/guides/onboarding-checklist/add-manual-tracing.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,15 @@ my_function(3, 4)
# Logs: Applying my_function to x=3 and y=4
```

Access the created span inside the function using `logfire.current_span()`:

```python
@logfire.instrument('Processing')
def process_data(user_dict: dict):
user_id = user_dict.get("id")
logfire.current_span().message = f'Processing User: {user_id}'
```

!!! note

- The [`@logfire.instrument`][logfire.Logfire.instrument] decorator MUST be applied first, i.e., UNDER any other decorators.
Expand Down
4 changes: 4 additions & 0 deletions logfire-api/logfire_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def __init__(self, *args, **kwargs) -> None: ...
def span(self, *args, **kwargs) -> LogfireSpan:
return LogfireSpan()

def current_span(self) -> LogfireSpan:
return LogfireSpan()

def log(self, *args, **kwargs) -> None: ...

def trace(self, *args, **kwargs) -> None: ...
Expand Down Expand Up @@ -201,6 +204,7 @@ def shutdown(self, *args, **kwargs) -> None: ...

DEFAULT_LOGFIRE_INSTANCE = Logfire()
span = DEFAULT_LOGFIRE_INSTANCE.span
current_span = DEFAULT_LOGFIRE_INSTANCE.current_span
log = DEFAULT_LOGFIRE_INSTANCE.log
trace = DEFAULT_LOGFIRE_INSTANCE.trace
debug = DEFAULT_LOGFIRE_INSTANCE.debug
Expand Down
3 changes: 2 additions & 1 deletion logfire-api/logfire_api/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ from logfire.propagate import attach_context as attach_context, get_context as g
from logfire.sampling import SamplingOptions as SamplingOptions
from typing import Any

__all__ = ['Logfire', 'LogfireSpan', 'LevelName', 'AdvancedOptions', 'ConsoleOptions', 'CodeSource', 'PydanticPlugin', 'configure', 'span', 'instrument', 'log', 'trace', 'debug', 'notice', 'info', 'warn', 'warning', 'error', 'exception', 'fatal', 'force_flush', 'log_slow_async_callbacks', 'install_auto_tracing', 'instrument_asgi', 'instrument_wsgi', 'instrument_pydantic', 'instrument_pydantic_ai', 'instrument_fastapi', 'instrument_openai', 'instrument_openai_agents', 'instrument_anthropic', 'instrument_google_genai', 'instrument_litellm', 'instrument_print', 'instrument_asyncpg', 'instrument_httpx', 'instrument_celery', 'instrument_requests', 'instrument_psycopg', 'instrument_django', 'instrument_flask', 'instrument_starlette', 'instrument_aiohttp_client', 'instrument_aiohttp_server', 'instrument_sqlalchemy', 'instrument_sqlite3', 'instrument_aws_lambda', 'instrument_redis', 'instrument_pymongo', 'instrument_mysql', 'instrument_system_metrics', 'instrument_mcp', 'AutoTraceModule', 'with_tags', 'with_settings', 'suppress_scopes', 'shutdown', 'no_auto_trace', 'ScrubMatch', 'ScrubbingOptions', 'VERSION', 'add_non_user_code_prefix', 'suppress_instrumentation', 'StructlogProcessor', 'LogfireLoggingHandler', 'loguru_handler', 'SamplingOptions', 'MetricsOptions', 'logfire_info', 'get_baggage', 'set_baggage', 'get_context', 'attach_context']
__all__ = ['Logfire', 'LogfireSpan', 'LevelName', 'AdvancedOptions', 'ConsoleOptions', 'CodeSource', 'PydanticPlugin', 'configure', 'span', 'instrument', 'current_span', 'log', 'trace', 'debug', 'notice', 'info', 'warn', 'warning', 'error', 'exception', 'fatal', 'force_flush', 'log_slow_async_callbacks', 'install_auto_tracing', 'instrument_asgi', 'instrument_wsgi', 'instrument_pydantic', 'instrument_pydantic_ai', 'instrument_fastapi', 'instrument_openai', 'instrument_openai_agents', 'instrument_anthropic', 'instrument_google_genai', 'instrument_litellm', 'instrument_print', 'instrument_asyncpg', 'instrument_httpx', 'instrument_celery', 'instrument_requests', 'instrument_psycopg', 'instrument_django', 'instrument_flask', 'instrument_starlette', 'instrument_aiohttp_client', 'instrument_aiohttp_server', 'instrument_sqlalchemy', 'instrument_sqlite3', 'instrument_aws_lambda', 'instrument_redis', 'instrument_pymongo', 'instrument_mysql', 'instrument_system_metrics', 'instrument_mcp', 'AutoTraceModule', 'with_tags', 'with_settings', 'suppress_scopes', 'shutdown', 'no_auto_trace', 'ScrubMatch', 'ScrubbingOptions', 'VERSION', 'add_non_user_code_prefix', 'suppress_instrumentation', 'StructlogProcessor', 'LogfireLoggingHandler', 'loguru_handler', 'SamplingOptions', 'MetricsOptions', 'logfire_info', 'get_baggage', 'set_baggage', 'get_context', 'attach_context']

DEFAULT_LOGFIRE_INSTANCE = Logfire()
span = DEFAULT_LOGFIRE_INSTANCE.span
current_span = DEFAULT_LOGFIRE_INSTANCE.current_span
instrument = DEFAULT_LOGFIRE_INSTANCE.instrument
force_flush = DEFAULT_LOGFIRE_INSTANCE.force_flush
log_slow_async_callbacks = DEFAULT_LOGFIRE_INSTANCE.log_slow_async_callbacks
Expand Down
4 changes: 4 additions & 0 deletions logfire-api/logfire_api/_internal/main.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ class Logfire:
attributes: The arguments to include in the span and format the message template with.
Attributes starting with an underscore are not allowed.
"""
def current_span(self) -> LogfireSpan:
"""Get the current span.
Useful for accessing spans created by @logfire.instrument.
"""
@overload
def instrument(self, msg_template: LiteralString | None = None, *, span_name: str | None = None, extract_args: bool | Iterable[str] = True, record_return: bool = False, allow_generator: bool = False) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator for instrumenting a function as a span.
Expand Down
2 changes: 2 additions & 0 deletions logfire/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
DEFAULT_LOGFIRE_INSTANCE: Logfire = Logfire()
span = DEFAULT_LOGFIRE_INSTANCE.span
instrument = DEFAULT_LOGFIRE_INSTANCE.instrument
current_span = DEFAULT_LOGFIRE_INSTANCE.current_span
force_flush = DEFAULT_LOGFIRE_INSTANCE.force_flush
log_slow_async_callbacks = DEFAULT_LOGFIRE_INSTANCE.log_slow_async_callbacks
install_auto_tracing = DEFAULT_LOGFIRE_INSTANCE.install_auto_tracing
Expand Down Expand Up @@ -108,6 +109,7 @@ def loguru_handler() -> Any:
'configure',
'span',
'instrument',
'current_span',
'log',
'trace',
'debug',
Expand Down
20 changes: 20 additions & 0 deletions logfire/_internal/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import inspect
import sys
import textwrap
import types
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -245,6 +246,25 @@ class InspectArgumentsFailedWarning(Warning):
pass


@functools.lru_cache(maxsize=1024)
def has_current_span_call(func: Any) -> bool:
"""Check if a function contains calls to logfire.current_span()."""
try:
tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
for node in ast.walk(tree):
if (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.attr == 'current_span'
and node.func.value.id == 'logfire'
):
return True
return False
except (OSError, SyntaxError, TypeError, AttributeError):
return False


@functools.lru_cache(maxsize=1024)
def get_node_source_text(node: ast.AST, ex_source: executing.Source):
"""Returns some Python source code representing `node`.
Expand Down
23 changes: 22 additions & 1 deletion logfire/_internal/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from opentelemetry.util import types as otel_types
from typing_extensions import LiteralString, ParamSpec

from .ast_utils import has_current_span_call
from .constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY, ATTRIBUTES_TAGS_KEY
from .stack_info import get_filepath_attribute
from .utils import safe_repr, uniquify_sequence
Expand Down Expand Up @@ -61,7 +62,8 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]:
)

attributes = get_attributes(func, msg_template, tags)
open_span = get_open_span(logfire, attributes, span_name, extract_args, func)
uses_current_span = has_current_span_call(func)
open_span = get_open_span(logfire, attributes, span_name, extract_args, uses_current_span, func)

if inspect.isgeneratorfunction(func):
if not allow_generator:
Expand Down Expand Up @@ -90,21 +92,31 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore

async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ignore
with open_span(*func_args, **func_kwargs) as span:
token = None
if uses_current_span:
token = logfire._current_span_var.set(span) # type: ignore[protected-access]
result = await func(*func_args, **func_kwargs)
if record_return:
# open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types.
# This isn't great because it has to parse the JSON schema.
# Not sure if making get_open_span return a LogfireSpan when record_return is True
# would be faster overall or if it would be worth the added complexity.
set_user_attributes_on_raw_span(span._span, {'return': result})
if token:
logfire._current_span_var.reset(token) # type: ignore[protected-access]
return result
else:
# Same as the above, but without the async/await
def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R:
with open_span(*func_args, **func_kwargs) as span:
token = None
if uses_current_span:
token = logfire._current_span_var.set(span) # type: ignore[protected-access]
result = func(*func_args, **func_kwargs)
if record_return:
set_user_attributes_on_raw_span(span._span, {'return': result})
if token:
logfire._current_span_var.reset(token) # type: ignore[protected-access]
return result

wrapper = functools.wraps(func)(wrapper) # type: ignore
Expand All @@ -118,12 +130,15 @@ def get_open_span(
attributes: dict[str, otel_types.AttributeValue],
span_name: str | None,
extract_args: bool | Iterable[str],
uses_current_span: bool,
func: Callable[P, R],
) -> Callable[P, AbstractContextManager[Any]]:
final_span_name: str = span_name or attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore

# This is the fast case for when there are no arguments to extract
def open_span(*_: P.args, **__: P.kwargs): # type: ignore
if uses_current_span:
return logfire._span(final_span_name, attributes) # type: ignore[protected-access]
return logfire._fast_span(final_span_name, attributes) # type: ignore

if extract_args is True:
Expand All @@ -134,6 +149,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
bound = sig.bind(*func_args, **func_kwargs)
bound.apply_defaults()
args_dict = bound.arguments
if uses_current_span:
return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore[protected-access]

return logfire._instrument_span_with_args( # type: ignore
final_span_name, attributes, args_dict
)
Expand Down Expand Up @@ -165,6 +183,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
# This line is the only difference from the extract_args=True case
args_dict = {k: args_dict[k] for k in extract_args_final}

if uses_current_span:
return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore[protected-access]

return logfire._instrument_span_with_args( # type: ignore
final_span_name, attributes, args_dict
)
Expand Down
6 changes: 5 additions & 1 deletion logfire/_internal/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
from collections.abc import Iterable, Sequence
from contextlib import AbstractContextManager
from contextvars import Token
from contextvars import ContextVar, Token
from enum import Enum
from functools import cached_property
from time import time
Expand Down Expand Up @@ -138,6 +138,7 @@ def __init__(
self._sample_rate = sample_rate
self._console_log = console_log
self._otel_scope = otel_scope
self._current_span_var = ContextVar('logfire_current_span', default=NoopSpan())

@property
def config(self) -> LogfireConfig:
Expand Down Expand Up @@ -171,6 +172,9 @@ def _get_tracer(self, *, is_span_tracer: bool) -> Tracer: # pragma: no cover
is_span_tracer=is_span_tracer,
)

def current_span(self) -> LogfireSpan:
return self._current_span_var.get() # type: ignore[return-value]

# If any changes are made to this method, they may need to be reflected in `_fast_span` as well.
def _span(
self,
Expand Down
107 changes: 106 additions & 1 deletion tests/test_logfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
LevelName,
)
from logfire._internal.formatter import FormattingFailedWarning
from logfire._internal.main import NoopSpan
from logfire._internal.main import LogfireSpan, NoopSpan
from logfire._internal.tracer import record_exception
from logfire._internal.utils import SeededRandomIdGenerator, is_instrumentation_suppressed
from logfire.integrations.logging import LogfireLoggingHandler
Expand Down Expand Up @@ -1344,6 +1344,111 @@ def run(a: str) -> None:
)


def test_current_span_default_is_noop():
span = logfire.DEFAULT_LOGFIRE_INSTANCE.current_span()
assert isinstance(span, NoopSpan)


def test_current_span_nested(exporter: TestExporter):
@logfire.instrument('outer')
def outer():
s = logfire.current_span()
s.message = 'Starting outer operation'
inner()
assert logfire.current_span() == s
s = logfire.current_span()
s.message = 'Completing outer operation'

@logfire.instrument('inner')
def inner():
s = logfire.current_span()
s.message = 'Processing inner operation'

outer()
assert isinstance(logfire.current_span(), NoopSpan)

spans = exporter.exported_spans_as_dict(_strip_function_qualname=False)

assert len(spans) == 2
assert spans[0]['name'] == 'inner'
assert spans[0]['attributes']['logfire.msg'] == 'Processing inner operation'
assert spans[1]['name'] == 'outer'
assert spans[1]['attributes']['logfire.msg'] == 'Completing outer operation'


@pytest.mark.anyio
async def test_current_span_async(exporter: TestExporter):
print('Testing current_span_async - functions should auto-detect current_span usage')

@logfire.instrument('async outer')
async def outer():
s = logfire.current_span()
assert isinstance(s, LogfireSpan)
s.message = 'Starting async outer operation'
await inner()
s = logfire.current_span()
assert isinstance(s, LogfireSpan)
s.message = 'Completing async outer operation'

@logfire.instrument('async inner')
async def inner():
s = logfire.current_span()
assert isinstance(s, LogfireSpan)
s.message = 'Processing async inner operation'

await outer()
assert isinstance(logfire.current_span(), NoopSpan)

spans = exporter.exported_spans_as_dict(_strip_function_qualname=False)

assert len(spans) == 2
assert spans[0]['name'] == 'async inner'
assert spans[0]['attributes']['logfire.msg'] == 'Processing async inner operation'
assert spans[1]['name'] == 'async outer'
assert spans[1]['attributes']['logfire.msg'] == 'Completing async outer operation'


def test_fast_span_when_current_span_not_called(exporter: TestExporter):
"""Test that when current_span() is not called, the exported span is a fast span"""

@logfire.instrument
def fast_operation(): ...

fast_operation()
spans = exporter.exported_spans_as_dict(_strip_function_qualname=False)

# pending_span would be a LogfireSpan, just span is FastLogfireSpan
assert spans[0]['attributes']['logfire.span_type'] == 'span'


def test_instrument_extract_args_true_with_current_span(exporter: TestExporter):
"""Test instrument with extract_args=True and current_span() usage."""

@logfire.instrument(extract_args=True, record_return=True)
def foo(a: int, b: str) -> None:
logfire.current_span()

foo(42, 'test')
spans = exporter.exported_spans_as_dict(_strip_function_qualname=False)
assert len(spans) == 1
assert spans[0]['attributes']['a'] == 42
assert spans[0]['attributes']['b'] == 'test'


def test_instrument_extract_args_list_with_current_span(exporter: TestExporter):
"""Test instrument with extract_args list and current_span() usage."""

@logfire.instrument(extract_args=['a', 'c'])
def foo(a: int, b: str, c: float) -> None:
logfire.current_span()

foo(42, 'test', 3.14)
spans = exporter.exported_spans_as_dict(_strip_function_qualname=False)
assert len(spans) == 1
assert spans[0]['attributes']['a'] == 42
assert spans[0]['attributes']['c'] == 3.14


@dataclass
class Foo:
x: int
Expand Down
4 changes: 4 additions & 0 deletions tests/test_logfire_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def func() -> None: ...
pass
logfire__all__.remove('attach_context')

assert hasattr(logfire_api, 'current_span')
logfire_api.current_span()
logfire__all__.remove('current_span')

# If it's not empty, it means that some of the __all__ members are not tested.
assert logfire__all__ == set(), logfire__all__

Expand Down
Loading