From 33a7b634d7a7a47a2ca37db7e94fa888daa8e101 Mon Sep 17 00:00:00 2001 From: Olivia Bahr Date: Fri, 19 Sep 2025 16:23:59 -0600 Subject: [PATCH 1/3] param version --- .../add-manual-tracing.md | 11 ++++ logfire/_internal/instrument.py | 29 +++++++-- tests/test_logfire.py | 60 +++++++++++++++++++ 3 files changed, 95 insertions(+), 5 deletions(-) diff --git a/docs/guides/onboarding-checklist/add-manual-tracing.md b/docs/guides/onboarding-checklist/add-manual-tracing.md index 8a0b62c8d..96a946010 100644 --- a/docs/guides/onboarding-checklist/add-manual-tracing.md +++ b/docs/guides/onboarding-checklist/add-manual-tracing.md @@ -234,6 +234,17 @@ my_function(3, 4) # Logs: Applying my_function to x=3 and y=4 ``` +You can also access the span directly within your instrumented function by adding a `logfire_span` parameter: + +```python +@logfire.instrument('Processing {x=}') +def process_data(x: int, logfire_span: logfire.LogfireSpan | None = None) -> int: + # Access and modify the span directly + if logfire_span: + logfire_span.message = f'Custom message for x={x}' + return x * 2 +``` + !!! note - The [`@logfire.instrument`][logfire.Logfire.instrument] decorator MUST be applied first, i.e., UNDER any other decorators. diff --git a/logfire/_internal/instrument.py b/logfire/_internal/instrument.py index 3d2b3fd98..430d1a124 100644 --- a/logfire/_internal/instrument.py +++ b/logfire/_internal/instrument.py @@ -63,19 +63,27 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]: attributes = get_attributes(func, msg_template, tags) open_span = get_open_span(logfire, attributes, span_name, extract_args, func) + # Check if function has logfire_span parameter + sig = inspect.signature(func) + has_logfire_span_param = 'logfire_span' in sig.parameters + if inspect.isgeneratorfunction(func): if not allow_generator: warnings.warn(GENERATOR_WARNING_MESSAGE, stacklevel=2) def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore - with open_span(*func_args, **func_kwargs): + with open_span(*func_args, **func_kwargs) as span: + if has_logfire_span_param: + func_kwargs['logfire_span'] = span yield from func(*func_args, **func_kwargs) elif inspect.isasyncgenfunction(func): if not allow_generator: warnings.warn(GENERATOR_WARNING_MESSAGE, stacklevel=2) async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore - with open_span(*func_args, **func_kwargs): + with open_span(*func_args, **func_kwargs) as span: + if has_logfire_span_param: + func_kwargs['logfire_span'] = span # `yield from` is invalid syntax in an async function. # This loop is not quite equivalent, because `yield from` also handles things like # sending values to the subgenerator. @@ -90,6 +98,8 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ignore with open_span(*func_args, **func_kwargs) as span: + if has_logfire_span_param: + func_kwargs['logfire_span'] = span result = await func(*func_args, **func_kwargs) if record_return: # open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types. @@ -102,6 +112,8 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ig # Same as the above, but without the async/await def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: with open_span(*func_args, **func_kwargs) as span: + if has_logfire_span_param: + func_kwargs['logfire_span'] = span result = func(*func_args, **func_kwargs) if record_return: set_user_attributes_on_raw_span(span._span, {'return': result}) @@ -122,18 +134,25 @@ def get_open_span( ) -> Callable[P, AbstractContextManager[Any]]: final_span_name: str = span_name or attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore + # Check if function has logfire_span parameter + sig = inspect.signature(func) + has_logfire_span_param = 'logfire_span' in sig.parameters + # This is the fast case for when there are no arguments to extract def open_span(*_: P.args, **__: P.kwargs): # type: ignore + if has_logfire_span_param: + return logfire._span(final_span_name, attributes) # type: ignore return logfire._fast_span(final_span_name, attributes) # type: ignore if extract_args is True: - sig = inspect.signature(func) if sig.parameters: # only extract args if there are any def open_span(*func_args: P.args, **func_kwargs: P.kwargs): bound = sig.bind(*func_args, **func_kwargs) bound.apply_defaults() args_dict = bound.arguments + if has_logfire_span_param: + return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore return logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) @@ -141,8 +160,6 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): return open_span if extract_args: # i.e. extract_args should be an iterable of argument names - sig = inspect.signature(func) - if isinstance(extract_args, str): extract_args = [extract_args] @@ -165,6 +182,8 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): # This line is the only difference from the extract_args=True case args_dict = {k: args_dict[k] for k in extract_args_final} + if has_logfire_span_param: + return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore return logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) diff --git a/tests/test_logfire.py b/tests/test_logfire.py index a8647b5e3..583475d9e 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -1273,6 +1273,66 @@ def run(a: str) -> Model: ) +def test_instrument_with_logfire_span_parameter(exporter: TestExporter): + @logfire.instrument('Calling foo with {x=}') + def foo(x: int, logfire_span: logfire.LogfireSpan | None = None) -> int: + # Test that we can access the span and modify its message + assert logfire_span is not None + logfire_span.message = f'Modified message for x={x}' + return x * 2 + + result = foo(5) + assert result == 10 + + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + assert len(spans) == 1 + span = spans[0] + assert span['attributes']['logfire.msg'] == 'Modified message for x=5' + assert span['attributes']['x'] == 5 + + +def test_instrument_with_logfire_span_parameter_async(exporter: TestExporter): + @logfire.instrument('Calling async foo with {x=}') + async def foo(x: int, logfire_span: logfire.LogfireSpan | None = None) -> int: + # Test that we can access the span and modify its message + assert logfire_span is not None + logfire_span.message = f'Async modified message for x={x}' + return x * 3 + + async def run_test(): + return await foo(7) + + import asyncio + + result = asyncio.run(run_test()) + assert result == 21 + + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + assert len(spans) == 1 + span = spans[0] + assert span['attributes']['logfire.msg'] == 'Async modified message for x=7' + assert span['attributes']['x'] == 7 + + +def test_instrument_with_logfire_span_parameter_extract_args_false(exporter: TestExporter): + @logfire.instrument('Calling foo', extract_args=False) + def foo(x: int, logfire_span: logfire.LogfireSpan | None = None) -> int: + # Test that we can access the span and modify its message + assert logfire_span is not None + logfire_span.message = f'Extract args false message for x={x}' + return x * 4 + + result = foo(3) + assert result == 12 + + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + assert len(spans) == 1 + span = spans[0] + assert span['attributes']['logfire.msg'] == 'Extract args false message for x=3' + # x should not be in attributes since extract_args=False + assert 'x' not in span['attributes'] + + def test_validation_error_on_span(exporter: TestExporter) -> None: class Model(BaseModel, plugin_settings={'logfire': {'record': 'off'}}): a: int From 15b4b026272502b5578ed0eccc0d0ec8963fc51c Mon Sep 17 00:00:00 2001 From: Olivia Bahr Date: Fri, 19 Sep 2025 16:45:20 -0600 Subject: [PATCH 2/3] Revert "param version" This reverts commit 33a7b634d7a7a47a2ca37db7e94fa888daa8e101. --- .../add-manual-tracing.md | 11 ---- logfire/_internal/instrument.py | 29 ++------- tests/test_logfire.py | 60 ------------------- 3 files changed, 5 insertions(+), 95 deletions(-) diff --git a/docs/guides/onboarding-checklist/add-manual-tracing.md b/docs/guides/onboarding-checklist/add-manual-tracing.md index 96a946010..8a0b62c8d 100644 --- a/docs/guides/onboarding-checklist/add-manual-tracing.md +++ b/docs/guides/onboarding-checklist/add-manual-tracing.md @@ -234,17 +234,6 @@ my_function(3, 4) # Logs: Applying my_function to x=3 and y=4 ``` -You can also access the span directly within your instrumented function by adding a `logfire_span` parameter: - -```python -@logfire.instrument('Processing {x=}') -def process_data(x: int, logfire_span: logfire.LogfireSpan | None = None) -> int: - # Access and modify the span directly - if logfire_span: - logfire_span.message = f'Custom message for x={x}' - return x * 2 -``` - !!! note - The [`@logfire.instrument`][logfire.Logfire.instrument] decorator MUST be applied first, i.e., UNDER any other decorators. diff --git a/logfire/_internal/instrument.py b/logfire/_internal/instrument.py index 430d1a124..3d2b3fd98 100644 --- a/logfire/_internal/instrument.py +++ b/logfire/_internal/instrument.py @@ -63,27 +63,19 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]: attributes = get_attributes(func, msg_template, tags) open_span = get_open_span(logfire, attributes, span_name, extract_args, func) - # Check if function has logfire_span parameter - sig = inspect.signature(func) - has_logfire_span_param = 'logfire_span' in sig.parameters - if inspect.isgeneratorfunction(func): if not allow_generator: warnings.warn(GENERATOR_WARNING_MESSAGE, stacklevel=2) def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore - with open_span(*func_args, **func_kwargs) as span: - if has_logfire_span_param: - func_kwargs['logfire_span'] = span + with open_span(*func_args, **func_kwargs): yield from func(*func_args, **func_kwargs) elif inspect.isasyncgenfunction(func): if not allow_generator: warnings.warn(GENERATOR_WARNING_MESSAGE, stacklevel=2) async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore - with open_span(*func_args, **func_kwargs) as span: - if has_logfire_span_param: - func_kwargs['logfire_span'] = span + with open_span(*func_args, **func_kwargs): # `yield from` is invalid syntax in an async function. # This loop is not quite equivalent, because `yield from` also handles things like # sending values to the subgenerator. @@ -98,8 +90,6 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ignore with open_span(*func_args, **func_kwargs) as span: - if has_logfire_span_param: - func_kwargs['logfire_span'] = span result = await func(*func_args, **func_kwargs) if record_return: # open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types. @@ -112,8 +102,6 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ig # Same as the above, but without the async/await def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: with open_span(*func_args, **func_kwargs) as span: - if has_logfire_span_param: - func_kwargs['logfire_span'] = span result = func(*func_args, **func_kwargs) if record_return: set_user_attributes_on_raw_span(span._span, {'return': result}) @@ -134,25 +122,18 @@ def get_open_span( ) -> Callable[P, AbstractContextManager[Any]]: final_span_name: str = span_name or attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore - # Check if function has logfire_span parameter - sig = inspect.signature(func) - has_logfire_span_param = 'logfire_span' in sig.parameters - # This is the fast case for when there are no arguments to extract def open_span(*_: P.args, **__: P.kwargs): # type: ignore - if has_logfire_span_param: - return logfire._span(final_span_name, attributes) # type: ignore return logfire._fast_span(final_span_name, attributes) # type: ignore if extract_args is True: + sig = inspect.signature(func) if sig.parameters: # only extract args if there are any def open_span(*func_args: P.args, **func_kwargs: P.kwargs): bound = sig.bind(*func_args, **func_kwargs) bound.apply_defaults() args_dict = bound.arguments - if has_logfire_span_param: - return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore return logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) @@ -160,6 +141,8 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): return open_span if extract_args: # i.e. extract_args should be an iterable of argument names + sig = inspect.signature(func) + if isinstance(extract_args, str): extract_args = [extract_args] @@ -182,8 +165,6 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): # This line is the only difference from the extract_args=True case args_dict = {k: args_dict[k] for k in extract_args_final} - if has_logfire_span_param: - return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore return logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 583475d9e..a8647b5e3 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -1273,66 +1273,6 @@ def run(a: str) -> Model: ) -def test_instrument_with_logfire_span_parameter(exporter: TestExporter): - @logfire.instrument('Calling foo with {x=}') - def foo(x: int, logfire_span: logfire.LogfireSpan | None = None) -> int: - # Test that we can access the span and modify its message - assert logfire_span is not None - logfire_span.message = f'Modified message for x={x}' - return x * 2 - - result = foo(5) - assert result == 10 - - spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) - assert len(spans) == 1 - span = spans[0] - assert span['attributes']['logfire.msg'] == 'Modified message for x=5' - assert span['attributes']['x'] == 5 - - -def test_instrument_with_logfire_span_parameter_async(exporter: TestExporter): - @logfire.instrument('Calling async foo with {x=}') - async def foo(x: int, logfire_span: logfire.LogfireSpan | None = None) -> int: - # Test that we can access the span and modify its message - assert logfire_span is not None - logfire_span.message = f'Async modified message for x={x}' - return x * 3 - - async def run_test(): - return await foo(7) - - import asyncio - - result = asyncio.run(run_test()) - assert result == 21 - - spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) - assert len(spans) == 1 - span = spans[0] - assert span['attributes']['logfire.msg'] == 'Async modified message for x=7' - assert span['attributes']['x'] == 7 - - -def test_instrument_with_logfire_span_parameter_extract_args_false(exporter: TestExporter): - @logfire.instrument('Calling foo', extract_args=False) - def foo(x: int, logfire_span: logfire.LogfireSpan | None = None) -> int: - # Test that we can access the span and modify its message - assert logfire_span is not None - logfire_span.message = f'Extract args false message for x={x}' - return x * 4 - - result = foo(3) - assert result == 12 - - spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) - assert len(spans) == 1 - span = spans[0] - assert span['attributes']['logfire.msg'] == 'Extract args false message for x=3' - # x should not be in attributes since extract_args=False - assert 'x' not in span['attributes'] - - def test_validation_error_on_span(exporter: TestExporter) -> None: class Model(BaseModel, plugin_settings={'logfire': {'record': 'off'}}): a: int From 1efd8499face6c6867293b4c856581f16bbe55d3 Mon Sep 17 00:00:00 2001 From: Olivia Bahr Date: Sun, 21 Sep 2025 20:46:23 -0600 Subject: [PATCH 3/3] context var version --- .../add-manual-tracing.md | 9 ++ logfire-api/logfire_api/__init__.py | 4 + logfire-api/logfire_api/__init__.pyi | 3 +- logfire-api/logfire_api/_internal/main.pyi | 4 + logfire/__init__.py | 2 + logfire/_internal/ast_utils.py | 20 ++++ logfire/_internal/instrument.py | 23 +++- logfire/_internal/main.py | 6 +- tests/test_logfire.py | 107 +++++++++++++++++- tests/test_logfire_api.py | 4 + 10 files changed, 178 insertions(+), 4 deletions(-) diff --git a/docs/guides/onboarding-checklist/add-manual-tracing.md b/docs/guides/onboarding-checklist/add-manual-tracing.md index 8a0b62c8d..f69192dcc 100644 --- a/docs/guides/onboarding-checklist/add-manual-tracing.md +++ b/docs/guides/onboarding-checklist/add-manual-tracing.md @@ -234,6 +234,15 @@ my_function(3, 4) # Logs: Applying my_function to x=3 and y=4 ``` +Access the created span inside the function using `logfire.current_span()`: + +```python +@logfire.instrument('Processing') +def process_data(user_dict: dict): + user_id = user_dict.get("id") + logfire.current_span().message = f'Processing User: {user_id}' +``` + !!! note - The [`@logfire.instrument`][logfire.Logfire.instrument] decorator MUST be applied first, i.e., UNDER any other decorators. diff --git a/logfire-api/logfire_api/__init__.py b/logfire-api/logfire_api/__init__.py index 4509704d2..62c90af62 100644 --- a/logfire-api/logfire_api/__init__.py +++ b/logfire-api/logfire_api/__init__.py @@ -96,6 +96,9 @@ def __init__(self, *args, **kwargs) -> None: ... def span(self, *args, **kwargs) -> LogfireSpan: return LogfireSpan() + def current_span(self) -> LogfireSpan: + return LogfireSpan() + def log(self, *args, **kwargs) -> None: ... def trace(self, *args, **kwargs) -> None: ... @@ -201,6 +204,7 @@ def shutdown(self, *args, **kwargs) -> None: ... DEFAULT_LOGFIRE_INSTANCE = Logfire() span = DEFAULT_LOGFIRE_INSTANCE.span + current_span = DEFAULT_LOGFIRE_INSTANCE.current_span log = DEFAULT_LOGFIRE_INSTANCE.log trace = DEFAULT_LOGFIRE_INSTANCE.trace debug = DEFAULT_LOGFIRE_INSTANCE.debug diff --git a/logfire-api/logfire_api/__init__.pyi b/logfire-api/logfire_api/__init__.pyi index fe9eba578..31bc99e8c 100644 --- a/logfire-api/logfire_api/__init__.pyi +++ b/logfire-api/logfire_api/__init__.pyi @@ -15,10 +15,11 @@ from logfire.propagate import attach_context as attach_context, get_context as g from logfire.sampling import SamplingOptions as SamplingOptions from typing import Any -__all__ = ['Logfire', 'LogfireSpan', 'LevelName', 'AdvancedOptions', 'ConsoleOptions', 'CodeSource', 'PydanticPlugin', 'configure', 'span', 'instrument', 'log', 'trace', 'debug', 'notice', 'info', 'warn', 'warning', 'error', 'exception', 'fatal', 'force_flush', 'log_slow_async_callbacks', 'install_auto_tracing', 'instrument_asgi', 'instrument_wsgi', 'instrument_pydantic', 'instrument_pydantic_ai', 'instrument_fastapi', 'instrument_openai', 'instrument_openai_agents', 'instrument_anthropic', 'instrument_google_genai', 'instrument_litellm', 'instrument_print', 'instrument_asyncpg', 'instrument_httpx', 'instrument_celery', 'instrument_requests', 'instrument_psycopg', 'instrument_django', 'instrument_flask', 'instrument_starlette', 'instrument_aiohttp_client', 'instrument_aiohttp_server', 'instrument_sqlalchemy', 'instrument_sqlite3', 'instrument_aws_lambda', 'instrument_redis', 'instrument_pymongo', 'instrument_mysql', 'instrument_system_metrics', 'instrument_mcp', 'AutoTraceModule', 'with_tags', 'with_settings', 'suppress_scopes', 'shutdown', 'no_auto_trace', 'ScrubMatch', 'ScrubbingOptions', 'VERSION', 'add_non_user_code_prefix', 'suppress_instrumentation', 'StructlogProcessor', 'LogfireLoggingHandler', 'loguru_handler', 'SamplingOptions', 'MetricsOptions', 'logfire_info', 'get_baggage', 'set_baggage', 'get_context', 'attach_context'] +__all__ = ['Logfire', 'LogfireSpan', 'LevelName', 'AdvancedOptions', 'ConsoleOptions', 'CodeSource', 'PydanticPlugin', 'configure', 'span', 'instrument', 'current_span', 'log', 'trace', 'debug', 'notice', 'info', 'warn', 'warning', 'error', 'exception', 'fatal', 'force_flush', 'log_slow_async_callbacks', 'install_auto_tracing', 'instrument_asgi', 'instrument_wsgi', 'instrument_pydantic', 'instrument_pydantic_ai', 'instrument_fastapi', 'instrument_openai', 'instrument_openai_agents', 'instrument_anthropic', 'instrument_google_genai', 'instrument_litellm', 'instrument_print', 'instrument_asyncpg', 'instrument_httpx', 'instrument_celery', 'instrument_requests', 'instrument_psycopg', 'instrument_django', 'instrument_flask', 'instrument_starlette', 'instrument_aiohttp_client', 'instrument_aiohttp_server', 'instrument_sqlalchemy', 'instrument_sqlite3', 'instrument_aws_lambda', 'instrument_redis', 'instrument_pymongo', 'instrument_mysql', 'instrument_system_metrics', 'instrument_mcp', 'AutoTraceModule', 'with_tags', 'with_settings', 'suppress_scopes', 'shutdown', 'no_auto_trace', 'ScrubMatch', 'ScrubbingOptions', 'VERSION', 'add_non_user_code_prefix', 'suppress_instrumentation', 'StructlogProcessor', 'LogfireLoggingHandler', 'loguru_handler', 'SamplingOptions', 'MetricsOptions', 'logfire_info', 'get_baggage', 'set_baggage', 'get_context', 'attach_context'] DEFAULT_LOGFIRE_INSTANCE = Logfire() span = DEFAULT_LOGFIRE_INSTANCE.span +current_span = DEFAULT_LOGFIRE_INSTANCE.current_span instrument = DEFAULT_LOGFIRE_INSTANCE.instrument force_flush = DEFAULT_LOGFIRE_INSTANCE.force_flush log_slow_async_callbacks = DEFAULT_LOGFIRE_INSTANCE.log_slow_async_callbacks diff --git a/logfire-api/logfire_api/_internal/main.pyi b/logfire-api/logfire_api/_internal/main.pyi index 5cbe73a12..6204f7449 100644 --- a/logfire-api/logfire_api/_internal/main.pyi +++ b/logfire-api/logfire_api/_internal/main.pyi @@ -237,6 +237,10 @@ class Logfire: attributes: The arguments to include in the span and format the message template with. Attributes starting with an underscore are not allowed. """ + def current_span(self) -> LogfireSpan: + """Get the current span. + Useful for accessing spans created by @logfire.instrument. + """ @overload def instrument(self, msg_template: LiteralString | None = None, *, span_name: str | None = None, extract_args: bool | Iterable[str] = True, record_return: bool = False, allow_generator: bool = False) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator for instrumenting a function as a span. diff --git a/logfire/__init__.py b/logfire/__init__.py index 2badb3b4d..6c7c7fc06 100644 --- a/logfire/__init__.py +++ b/logfire/__init__.py @@ -24,6 +24,7 @@ DEFAULT_LOGFIRE_INSTANCE: Logfire = Logfire() span = DEFAULT_LOGFIRE_INSTANCE.span instrument = DEFAULT_LOGFIRE_INSTANCE.instrument +current_span = DEFAULT_LOGFIRE_INSTANCE.current_span force_flush = DEFAULT_LOGFIRE_INSTANCE.force_flush log_slow_async_callbacks = DEFAULT_LOGFIRE_INSTANCE.log_slow_async_callbacks install_auto_tracing = DEFAULT_LOGFIRE_INSTANCE.install_auto_tracing @@ -108,6 +109,7 @@ def loguru_handler() -> Any: 'configure', 'span', 'instrument', + 'current_span', 'log', 'trace', 'debug', diff --git a/logfire/_internal/ast_utils.py b/logfire/_internal/ast_utils.py index 22ed9b6c2..c37608f48 100644 --- a/logfire/_internal/ast_utils.py +++ b/logfire/_internal/ast_utils.py @@ -4,6 +4,7 @@ import functools import inspect import sys +import textwrap import types import warnings from abc import ABC, abstractmethod @@ -245,6 +246,25 @@ class InspectArgumentsFailedWarning(Warning): pass +@functools.lru_cache(maxsize=1024) +def has_current_span_call(func: Any) -> bool: + """Check if a function contains calls to logfire.current_span().""" + try: + tree = ast.parse(textwrap.dedent(inspect.getsource(func))) + for node in ast.walk(tree): + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.attr == 'current_span' + and node.func.value.id == 'logfire' + ): + return True + return False + except (OSError, SyntaxError, TypeError, AttributeError): + return False + + @functools.lru_cache(maxsize=1024) def get_node_source_text(node: ast.AST, ex_source: executing.Source): """Returns some Python source code representing `node`. diff --git a/logfire/_internal/instrument.py b/logfire/_internal/instrument.py index 3d2b3fd98..253b9aa65 100644 --- a/logfire/_internal/instrument.py +++ b/logfire/_internal/instrument.py @@ -11,6 +11,7 @@ from opentelemetry.util import types as otel_types from typing_extensions import LiteralString, ParamSpec +from .ast_utils import has_current_span_call from .constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY, ATTRIBUTES_TAGS_KEY from .stack_info import get_filepath_attribute from .utils import safe_repr, uniquify_sequence @@ -61,7 +62,8 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]: ) attributes = get_attributes(func, msg_template, tags) - open_span = get_open_span(logfire, attributes, span_name, extract_args, func) + uses_current_span = has_current_span_call(func) + open_span = get_open_span(logfire, attributes, span_name, extract_args, uses_current_span, func) if inspect.isgeneratorfunction(func): if not allow_generator: @@ -90,6 +92,9 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ignore with open_span(*func_args, **func_kwargs) as span: + token = None + if uses_current_span: + token = logfire._current_span_var.set(span) # type: ignore[protected-access] result = await func(*func_args, **func_kwargs) if record_return: # open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types. @@ -97,14 +102,21 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ig # Not sure if making get_open_span return a LogfireSpan when record_return is True # would be faster overall or if it would be worth the added complexity. set_user_attributes_on_raw_span(span._span, {'return': result}) + if token: + logfire._current_span_var.reset(token) # type: ignore[protected-access] return result else: # Same as the above, but without the async/await def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: with open_span(*func_args, **func_kwargs) as span: + token = None + if uses_current_span: + token = logfire._current_span_var.set(span) # type: ignore[protected-access] result = func(*func_args, **func_kwargs) if record_return: set_user_attributes_on_raw_span(span._span, {'return': result}) + if token: + logfire._current_span_var.reset(token) # type: ignore[protected-access] return result wrapper = functools.wraps(func)(wrapper) # type: ignore @@ -118,12 +130,15 @@ def get_open_span( attributes: dict[str, otel_types.AttributeValue], span_name: str | None, extract_args: bool | Iterable[str], + uses_current_span: bool, func: Callable[P, R], ) -> Callable[P, AbstractContextManager[Any]]: final_span_name: str = span_name or attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore # This is the fast case for when there are no arguments to extract def open_span(*_: P.args, **__: P.kwargs): # type: ignore + if uses_current_span: + return logfire._span(final_span_name, attributes) # type: ignore[protected-access] return logfire._fast_span(final_span_name, attributes) # type: ignore if extract_args is True: @@ -134,6 +149,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): bound = sig.bind(*func_args, **func_kwargs) bound.apply_defaults() args_dict = bound.arguments + if uses_current_span: + return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore[protected-access] + return logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) @@ -165,6 +183,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): # This line is the only difference from the extract_args=True case args_dict = {k: args_dict[k] for k in extract_args_final} + if uses_current_span: + return logfire._span(final_span_name, {**attributes, **args_dict}) # type: ignore[protected-access] + return logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 120499f2f..559cbbcdf 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -7,7 +7,7 @@ import warnings from collections.abc import Iterable, Sequence from contextlib import AbstractContextManager -from contextvars import Token +from contextvars import ContextVar, Token from enum import Enum from functools import cached_property from time import time @@ -138,6 +138,7 @@ def __init__( self._sample_rate = sample_rate self._console_log = console_log self._otel_scope = otel_scope + self._current_span_var = ContextVar('logfire_current_span', default=NoopSpan()) @property def config(self) -> LogfireConfig: @@ -171,6 +172,9 @@ def _get_tracer(self, *, is_span_tracer: bool) -> Tracer: # pragma: no cover is_span_tracer=is_span_tracer, ) + def current_span(self) -> LogfireSpan: + return self._current_span_var.get() # type: ignore[return-value] + # If any changes are made to this method, they may need to be reflected in `_fast_span` as well. def _span( self, diff --git a/tests/test_logfire.py b/tests/test_logfire.py index a8647b5e3..3ba5cce03 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -38,7 +38,7 @@ LevelName, ) from logfire._internal.formatter import FormattingFailedWarning -from logfire._internal.main import NoopSpan +from logfire._internal.main import LogfireSpan, NoopSpan from logfire._internal.tracer import record_exception from logfire._internal.utils import SeededRandomIdGenerator, is_instrumentation_suppressed from logfire.integrations.logging import LogfireLoggingHandler @@ -1344,6 +1344,111 @@ def run(a: str) -> None: ) +def test_current_span_default_is_noop(): + span = logfire.DEFAULT_LOGFIRE_INSTANCE.current_span() + assert isinstance(span, NoopSpan) + + +def test_current_span_nested(exporter: TestExporter): + @logfire.instrument('outer') + def outer(): + s = logfire.current_span() + s.message = 'Starting outer operation' + inner() + assert logfire.current_span() == s + s = logfire.current_span() + s.message = 'Completing outer operation' + + @logfire.instrument('inner') + def inner(): + s = logfire.current_span() + s.message = 'Processing inner operation' + + outer() + assert isinstance(logfire.current_span(), NoopSpan) + + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + + assert len(spans) == 2 + assert spans[0]['name'] == 'inner' + assert spans[0]['attributes']['logfire.msg'] == 'Processing inner operation' + assert spans[1]['name'] == 'outer' + assert spans[1]['attributes']['logfire.msg'] == 'Completing outer operation' + + +@pytest.mark.anyio +async def test_current_span_async(exporter: TestExporter): + print('Testing current_span_async - functions should auto-detect current_span usage') + + @logfire.instrument('async outer') + async def outer(): + s = logfire.current_span() + assert isinstance(s, LogfireSpan) + s.message = 'Starting async outer operation' + await inner() + s = logfire.current_span() + assert isinstance(s, LogfireSpan) + s.message = 'Completing async outer operation' + + @logfire.instrument('async inner') + async def inner(): + s = logfire.current_span() + assert isinstance(s, LogfireSpan) + s.message = 'Processing async inner operation' + + await outer() + assert isinstance(logfire.current_span(), NoopSpan) + + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + + assert len(spans) == 2 + assert spans[0]['name'] == 'async inner' + assert spans[0]['attributes']['logfire.msg'] == 'Processing async inner operation' + assert spans[1]['name'] == 'async outer' + assert spans[1]['attributes']['logfire.msg'] == 'Completing async outer operation' + + +def test_fast_span_when_current_span_not_called(exporter: TestExporter): + """Test that when current_span() is not called, the exported span is a fast span""" + + @logfire.instrument + def fast_operation(): ... + + fast_operation() + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + + # pending_span would be a LogfireSpan, just span is FastLogfireSpan + assert spans[0]['attributes']['logfire.span_type'] == 'span' + + +def test_instrument_extract_args_true_with_current_span(exporter: TestExporter): + """Test instrument with extract_args=True and current_span() usage.""" + + @logfire.instrument(extract_args=True, record_return=True) + def foo(a: int, b: str) -> None: + logfire.current_span() + + foo(42, 'test') + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + assert len(spans) == 1 + assert spans[0]['attributes']['a'] == 42 + assert spans[0]['attributes']['b'] == 'test' + + +def test_instrument_extract_args_list_with_current_span(exporter: TestExporter): + """Test instrument with extract_args list and current_span() usage.""" + + @logfire.instrument(extract_args=['a', 'c']) + def foo(a: int, b: str, c: float) -> None: + logfire.current_span() + + foo(42, 'test', 3.14) + spans = exporter.exported_spans_as_dict(_strip_function_qualname=False) + assert len(spans) == 1 + assert spans[0]['attributes']['a'] == 42 + assert spans[0]['attributes']['c'] == 3.14 + + @dataclass class Foo: x: int diff --git a/tests/test_logfire_api.py b/tests/test_logfire_api.py index fd07900f8..7920db480 100644 --- a/tests/test_logfire_api.py +++ b/tests/test_logfire_api.py @@ -273,6 +273,10 @@ def func() -> None: ... pass logfire__all__.remove('attach_context') + assert hasattr(logfire_api, 'current_span') + logfire_api.current_span() + logfire__all__.remove('current_span') + # If it's not empty, it means that some of the __all__ members are not tested. assert logfire__all__ == set(), logfire__all__