Skip to content

Commit 3fea406

Browse files
committed
context var version
1 parent 15b4b02 commit 3fea406

File tree

6 files changed

+136
-3
lines changed

6 files changed

+136
-3
lines changed

docs/guides/onboarding-checklist/add-manual-tracing.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,15 @@ my_function(3, 4)
234234
# Logs: Applying my_function to x=3 and y=4
235235
```
236236

237+
Access the created span inside the function using `logfire.current_span()`:
238+
239+
```python
240+
@logfire.instrument('Processing')
241+
def process_data(user_dict: dict):
242+
user_id = user_dict.get("id")
243+
logfire.current_span().message = f'Processing User: {user_id}'
244+
```
245+
237246
!!! note
238247

239248
- The [`@logfire.instrument`][logfire.Logfire.instrument] decorator MUST be applied first, i.e., UNDER any other decorators.

logfire/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
DEFAULT_LOGFIRE_INSTANCE: Logfire = Logfire()
2525
span = DEFAULT_LOGFIRE_INSTANCE.span
2626
instrument = DEFAULT_LOGFIRE_INSTANCE.instrument
27+
current_span = DEFAULT_LOGFIRE_INSTANCE.current_span
2728
force_flush = DEFAULT_LOGFIRE_INSTANCE.force_flush
2829
log_slow_async_callbacks = DEFAULT_LOGFIRE_INSTANCE.log_slow_async_callbacks
2930
install_auto_tracing = DEFAULT_LOGFIRE_INSTANCE.install_auto_tracing
@@ -108,6 +109,7 @@ def loguru_handler() -> Any:
108109
'configure',
109110
'span',
110111
'instrument',
112+
'current_span',
111113
'log',
112114
'trace',
113115
'debug',

logfire/_internal/ast_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import functools
55
import inspect
66
import sys
7+
import textwrap
78
import types
89
import warnings
910
from abc import ABC, abstractmethod
@@ -245,6 +246,25 @@ class InspectArgumentsFailedWarning(Warning):
245246
pass
246247

247248

249+
@functools.lru_cache(maxsize=1024)
250+
def has_current_span_call(func: Any) -> bool:
251+
"""Check if a function contains calls to logfire.current_span()."""
252+
try:
253+
tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
254+
for node in ast.walk(tree):
255+
if (
256+
isinstance(node, ast.Call)
257+
and isinstance(node.func, ast.Attribute)
258+
and isinstance(node.func.value, ast.Name)
259+
and node.func.attr == 'current_span'
260+
and node.func.value.id == 'logfire'
261+
):
262+
return True
263+
return False
264+
except (OSError, SyntaxError, TypeError, AttributeError):
265+
return False
266+
267+
248268
@functools.lru_cache(maxsize=1024)
249269
def get_node_source_text(node: ast.AST, ex_source: executing.Source):
250270
"""Returns some Python source code representing `node`.

logfire/_internal/instrument.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from opentelemetry.util import types as otel_types
1212
from typing_extensions import LiteralString, ParamSpec
1313

14+
from .ast_utils import has_current_span_call
1415
from .constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY, ATTRIBUTES_TAGS_KEY
1516
from .stack_info import get_filepath_attribute
1617
from .utils import safe_repr, uniquify_sequence
@@ -61,7 +62,8 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]:
6162
)
6263

6364
attributes = get_attributes(func, msg_template, tags)
64-
open_span = get_open_span(logfire, attributes, span_name, extract_args, func)
65+
uses_current_span = has_current_span_call(func)
66+
open_span = get_open_span(logfire, attributes, span_name, extract_args, uses_current_span, func)
6567

6668
if inspect.isgeneratorfunction(func):
6769
if not allow_generator:
@@ -90,21 +92,31 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore
9092

9193
async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ignore
9294
with open_span(*func_args, **func_kwargs) as span:
95+
token = None
96+
if uses_current_span:
97+
token = logfire._current_span_var.set(span) # type: ignore[protected-access]
9398
result = await func(*func_args, **func_kwargs)
9499
if record_return:
95100
# open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types.
96101
# This isn't great because it has to parse the JSON schema.
97102
# Not sure if making get_open_span return a LogfireSpan when record_return is True
98103
# would be faster overall or if it would be worth the added complexity.
99104
set_user_attributes_on_raw_span(span._span, {'return': result})
105+
if token:
106+
logfire._current_span_var.reset(token) # type: ignore[protected-access]
100107
return result
101108
else:
102109
# Same as the above, but without the async/await
103110
def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R:
104111
with open_span(*func_args, **func_kwargs) as span:
112+
token = None
113+
if uses_current_span:
114+
token = logfire._current_span_var.set(span) # type: ignore[protected-access]
105115
result = func(*func_args, **func_kwargs)
106116
if record_return:
107117
set_user_attributes_on_raw_span(span._span, {'return': result})
118+
if token:
119+
logfire._current_span_var.reset(token) # type: ignore[protected-access]
108120
return result
109121

110122
wrapper = functools.wraps(func)(wrapper) # type: ignore
@@ -118,12 +130,15 @@ def get_open_span(
118130
attributes: dict[str, otel_types.AttributeValue],
119131
span_name: str | None,
120132
extract_args: bool | Iterable[str],
133+
uses_current_span: bool,
121134
func: Callable[P, R],
122135
) -> Callable[P, AbstractContextManager[Any]]:
123136
final_span_name: str = span_name or attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore
124137

125138
# This is the fast case for when there are no arguments to extract
126139
def open_span(*_: P.args, **__: P.kwargs): # type: ignore
140+
if uses_current_span:
141+
return logfire._span(final_span_name, attributes) # type: ignore[protected-access]
127142
return logfire._fast_span(final_span_name, attributes) # type: ignore
128143

129144
if extract_args is True:
@@ -134,6 +149,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
134149
bound = sig.bind(*func_args, **func_kwargs)
135150
bound.apply_defaults()
136151
args_dict = bound.arguments
152+
if uses_current_span:
153+
return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore[protected-access]
154+
137155
return logfire._instrument_span_with_args( # type: ignore
138156
final_span_name, attributes, args_dict
139157
)
@@ -165,6 +183,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
165183
# This line is the only difference from the extract_args=True case
166184
args_dict = {k: args_dict[k] for k in extract_args_final}
167185

186+
if uses_current_span:
187+
return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore[protected-access]
188+
168189
return logfire._instrument_span_with_args( # type: ignore
169190
final_span_name, attributes, args_dict
170191
)

logfire/_internal/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import warnings
88
from collections.abc import Iterable, Sequence
99
from contextlib import AbstractContextManager
10-
from contextvars import Token
10+
from contextvars import ContextVar, Token
1111
from enum import Enum
1212
from functools import cached_property
1313
from time import time
@@ -138,6 +138,7 @@ def __init__(
138138
self._sample_rate = sample_rate
139139
self._console_log = console_log
140140
self._otel_scope = otel_scope
141+
self._current_span_var = ContextVar('logfire_current_span', default=NoopSpan())
141142

142143
@property
143144
def config(self) -> LogfireConfig:
@@ -171,6 +172,9 @@ def _get_tracer(self, *, is_span_tracer: bool) -> Tracer: # pragma: no cover
171172
is_span_tracer=is_span_tracer,
172173
)
173174

175+
def current_span(self) -> LogfireSpan:
176+
return self._current_span_var.get() # type: ignore[return-value]
177+
174178
# If any changes are made to this method, they may need to be reflected in `_fast_span` as well.
175179
def _span(
176180
self,

tests/test_logfire.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
LevelName,
3939
)
4040
from logfire._internal.formatter import FormattingFailedWarning
41-
from logfire._internal.main import NoopSpan
41+
from logfire._internal.main import LogfireSpan, NoopSpan
4242
from logfire._internal.tracer import record_exception
4343
from logfire._internal.utils import SeededRandomIdGenerator, is_instrumentation_suppressed
4444
from logfire.integrations.logging import LogfireLoggingHandler
@@ -1344,6 +1344,83 @@ def run(a: str) -> None:
13441344
)
13451345

13461346

1347+
def test_current_span_default_is_noop():
1348+
span = logfire.DEFAULT_LOGFIRE_INSTANCE.current_span()
1349+
assert isinstance(span, NoopSpan)
1350+
1351+
1352+
def test_current_span_nested(exporter: TestExporter):
1353+
@logfire.instrument('outer')
1354+
def outer():
1355+
s = logfire.current_span()
1356+
s.message = 'Starting outer operation'
1357+
inner()
1358+
assert logfire.current_span() == s
1359+
s = logfire.current_span()
1360+
s.message = 'Completing outer operation'
1361+
1362+
@logfire.instrument('inner')
1363+
def inner():
1364+
s = logfire.current_span()
1365+
s.message = 'Processing inner operation'
1366+
1367+
outer()
1368+
assert isinstance(logfire.current_span(), NoopSpan)
1369+
1370+
spans = exporter.exported_spans_as_dict(_strip_function_qualname=False)
1371+
1372+
assert len(spans) == 2
1373+
assert spans[0]['name'] == 'inner'
1374+
assert spans[0]['attributes']['logfire.msg'] == 'Processing inner operation'
1375+
assert spans[1]['name'] == 'outer'
1376+
assert spans[1]['attributes']['logfire.msg'] == 'Completing outer operation'
1377+
1378+
1379+
@pytest.mark.anyio
1380+
async def test_current_span_async(exporter: TestExporter):
1381+
print('Testing current_span_async - functions should auto-detect current_span usage')
1382+
1383+
@logfire.instrument('async outer')
1384+
async def outer():
1385+
s = logfire.current_span()
1386+
assert isinstance(s, LogfireSpan)
1387+
s.message = 'Starting async outer operation'
1388+
await inner()
1389+
s = logfire.current_span()
1390+
assert isinstance(s, LogfireSpan)
1391+
s.message = 'Completing async outer operation'
1392+
1393+
@logfire.instrument('async inner')
1394+
async def inner():
1395+
s = logfire.current_span()
1396+
assert isinstance(s, LogfireSpan)
1397+
s.message = 'Processing async inner operation'
1398+
1399+
await outer()
1400+
assert isinstance(logfire.current_span(), NoopSpan)
1401+
1402+
spans = exporter.exported_spans_as_dict(_strip_function_qualname=False)
1403+
1404+
assert len(spans) == 2
1405+
assert spans[0]['name'] == 'async inner'
1406+
assert spans[0]['attributes']['logfire.msg'] == 'Processing async inner operation'
1407+
assert spans[1]['name'] == 'async outer'
1408+
assert spans[1]['attributes']['logfire.msg'] == 'Completing async outer operation'
1409+
1410+
1411+
def test_fast_span_when_current_span_not_called(exporter: TestExporter):
1412+
"""Test that when current_span() is not called, the exported span is a fast span"""
1413+
1414+
@logfire.instrument
1415+
def fast_operation(): ...
1416+
1417+
fast_operation()
1418+
spans = exporter.exported_spans_as_dict(_strip_function_qualname=False)
1419+
1420+
# pending_span would be a LogfireSpan, just span is FastLogfireSpan
1421+
assert spans[0]['attributes']['logfire.span_type'] == 'span'
1422+
1423+
13471424
@dataclass
13481425
class Foo:
13491426
x: int

0 commit comments

Comments
 (0)