diff --git a/logfire/_internal/instrument.py b/logfire/_internal/instrument.py index 3d2b3fd98..acc06ffff 100644 --- a/logfire/_internal/instrument.py +++ b/logfire/_internal/instrument.py @@ -5,11 +5,12 @@ import inspect import warnings from collections.abc import Iterable, Sequence -from contextlib import AbstractContextManager, asynccontextmanager, contextmanager +from contextlib import AbstractContextManager, asynccontextmanager, contextmanager, nullcontext from typing import TYPE_CHECKING, Any, Callable, TypeVar +from opentelemetry import trace from opentelemetry.util import types as otel_types -from typing_extensions import LiteralString, ParamSpec +from typing_extensions import Concatenate, LiteralString, ParamSpec from .constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY, ATTRIBUTES_TAGS_KEY from .stack_info import get_filepath_attribute @@ -50,7 +51,10 @@ def instrument( extract_args: bool | Iterable[str], record_return: bool, allow_generator: bool, + new_context: bool, ) -> Callable[[Callable[P, R]], Callable[P, R]]: + from logfire.propagate import attach_context + from .main import set_user_attributes_on_raw_span def decorator(func: Callable[P, R]) -> Callable[P, R]: @@ -60,6 +64,13 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]: stacklevel=2, ) + if new_context: + context_manager = attach_context({} if new_context is True else new_context()) + link_to_current = True + else: + link_to_current = False + context_manager = nullcontext() + attributes = get_attributes(func, msg_template, tags) open_span = get_open_span(logfire, attributes, span_name, extract_args, func) @@ -68,44 +79,52 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]: warnings.warn(GENERATOR_WARNING_MESSAGE, stacklevel=2) def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore - with open_span(*func_args, **func_kwargs): - yield from func(*func_args, **func_kwargs) + prev_context = trace.get_current_span().get_span_context() if link_to_current else None + with context_manager: + with open_span(prev_context, *func_args, **func_kwargs): + yield from func(*func_args, **func_kwargs) elif inspect.isasyncgenfunction(func): if not allow_generator: warnings.warn(GENERATOR_WARNING_MESSAGE, stacklevel=2) async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore - with open_span(*func_args, **func_kwargs): - # `yield from` is invalid syntax in an async function. - # This loop is not quite equivalent, because `yield from` also handles things like - # sending values to the subgenerator. - # Fixing this would at least mean porting https://peps.python.org/pep-0380/#formal-semantics - # which is quite messy, and it's not clear if that would be correct based on - # https://discuss.python.org/t/yield-from-in-async-functions/47050. - # So instead we have an extra warning in the docs about this. - async for x in func(*func_args, **func_kwargs): - yield x + prev_context = trace.get_current_span().get_span_context() if link_to_current else None + with context_manager: + with open_span(prev_context, *func_args, **func_kwargs): + # `yield from` is invalid syntax in an async function. + # This loop is not quite equivalent, because `yield from` also handles things like + # sending values to the subgenerator. + # Fixing this would at least mean porting https://peps.python.org/pep-0380/#formal-semantics + # which is quite messy, and it's not clear if that would be correct based on + # https://discuss.python.org/t/yield-from-in-async-functions/47050. + # So instead we have an extra warning in the docs about this. + async for x in func(*func_args, **func_kwargs): + yield x elif inspect.iscoroutinefunction(func): async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ignore - with open_span(*func_args, **func_kwargs) as span: - result = await func(*func_args, **func_kwargs) - if record_return: - # open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types. - # This isn't great because it has to parse the JSON schema. - # Not sure if making get_open_span return a LogfireSpan when record_return is True - # would be faster overall or if it would be worth the added complexity. - set_user_attributes_on_raw_span(span._span, {'return': result}) - return result + prev_context = trace.get_current_span().get_span_context() if link_to_current else None + with context_manager: + with open_span(prev_context, *func_args, **func_kwargs) as span: + result = await func(*func_args, **func_kwargs) + if record_return: + # open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types. + # This isn't great because it has to parse the JSON schema. + # Not sure if making get_open_span return a LogfireSpan when record_return is True + # would be faster overall or if it would be worth the added complexity. + set_user_attributes_on_raw_span(span._span, {'return': result}) + return result else: # Same as the above, but without the async/await def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: - with open_span(*func_args, **func_kwargs) as span: - result = func(*func_args, **func_kwargs) - if record_return: - set_user_attributes_on_raw_span(span._span, {'return': result}) - return result + prev_context = trace.get_current_span().get_span_context() if link_to_current else None + with context_manager: + with open_span(prev_context, *func_args, **func_kwargs) as span: + result = func(*func_args, **func_kwargs) + if record_return: + set_user_attributes_on_raw_span(span._span, {'return': result}) + return result wrapper = functools.wraps(func)(wrapper) # type: ignore return wrapper @@ -119,28 +138,32 @@ def get_open_span( span_name: str | None, extract_args: bool | Iterable[str], func: Callable[P, R], -) -> Callable[P, AbstractContextManager[Any]]: +) -> Callable[Concatenate[trace.SpanContext | None, P], AbstractContextManager[Any]]: final_span_name: str = span_name or attributes[ATTRIBUTES_MESSAGE_TEMPLATE_KEY] # type: ignore # This is the fast case for when there are no arguments to extract - def open_span(*_: P.args, **__: P.kwargs): # type: ignore - return logfire._fast_span(final_span_name, attributes) # type: ignore + def open_span(span_context: trace.SpanContext | None, *_: P.args, **__: P.kwargs): # type: ignore + span = logfire._fast_span(final_span_name, attributes) # type: ignore + if span_context is not None: + span._span.add_link(span_context) # pyright: ignore[reportPrivateUsage] + return span if extract_args is True: sig = inspect.signature(func) if sig.parameters: # only extract args if there are any - def open_span(*func_args: P.args, **func_kwargs: P.kwargs): + def open_span(span_context: trace.SpanContext | None, *func_args: P.args, **func_kwargs: P.kwargs): bound = sig.bind(*func_args, **func_kwargs) bound.apply_defaults() args_dict = bound.arguments - return logfire._instrument_span_with_args( # type: ignore + span = logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) + if span_context is not None: + span._span.add_link(span_context) # pyright: ignore[reportPrivateUsage] + return span - return open_span - - if extract_args: # i.e. extract_args should be an iterable of argument names + elif extract_args: # i.e. extract_args should be an iterable of argument names sig = inspect.signature(func) if isinstance(extract_args, str): @@ -157,7 +180,7 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): if extract_args_final: # check that there are still arguments to extract - def open_span(*func_args: P.args, **func_kwargs: P.kwargs): + def open_span(span_context: trace.SpanContext | None, *func_args: P.args, **func_kwargs: P.kwargs): bound = sig.bind(*func_args, **func_kwargs) bound.apply_defaults() args_dict = bound.arguments @@ -165,9 +188,12 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs): # This line is the only difference from the extract_args=True case args_dict = {k: args_dict[k] for k in extract_args_final} - return logfire._instrument_span_with_args( # type: ignore + span = logfire._instrument_span_with_args( # type: ignore final_span_name, attributes, args_dict ) + if span_context is not None: + span._span.add_link(span_context) # pyright: ignore[reportPrivateUsage] + return span return open_span diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 296eace68..049d0f080 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -580,6 +580,7 @@ def instrument( extract_args: bool | Iterable[str] = True, record_return: bool = False, allow_generator: bool = False, + new_context: bool = False, ) -> Callable[[Callable[P, R]], Callable[P, R]]: """Decorator for instrumenting a function as a span. @@ -603,6 +604,7 @@ def my_function(a: int): Ignored for generators. allow_generator: Set to `True` to prevent a warning when instrumenting a generator function. Read https://logfire.pydantic.dev/docs/guides/advanced/generators/#using-logfireinstrument first. + new_context: Set to `True` to clear context before starting instrumentation, and link back to the previous span. """ @overload @@ -629,6 +631,7 @@ def instrument( # type: ignore[reportInconsistentOverload] extract_args: bool | Iterable[str] = True, record_return: bool = False, allow_generator: bool = False, + new_context: bool = False, ) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]: """Decorator for instrumenting a function as a span. @@ -652,11 +655,12 @@ def my_function(a: int): Ignored for generators. allow_generator: Set to `True` to prevent a warning when instrumenting a generator function. Read https://logfire.pydantic.dev/docs/guides/advanced/generators/#using-logfireinstrument first. + new_context: Set to `True` to clear context before starting instrumentation, and link back to the previous span. """ if callable(msg_template): return self.instrument()(msg_template) return instrument( - self, tuple(self._tags), msg_template, span_name, extract_args, record_return, allow_generator + self, tuple(self._tags), msg_template, span_name, extract_args, record_return, allow_generator, new_context ) def log( diff --git a/tests/test_logfire.py b/tests/test_logfire.py index dd53ef44d..a4166a2a4 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -748,6 +748,437 @@ def hello_world(a: int) -> str: ) +def test_instrument_with_parent(exporter: TestExporter) -> None: + tagged = logfire.with_tags('test_instrument') + + @tagged.instrument('hello-world {a=}', record_return=True) + def hello_world(a: int) -> str: + return f'hello {a}' + + @tagged.instrument('parent', record_return=True) + def parent() -> str: + return hello_world(5) + + assert parent() == 'hello 5' + + assert exporter.exported_spans_as_dict(_include_pending_spans=True, _strip_function_qualname=False) == snapshot( + [ + { + 'name': 'parent', + 'context': {'trace_id': 1, 'span_id': 2, 'is_remote': False}, + 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'start_time': 1000000000, + 'end_time': 1000000000, + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.lineno': 123, + 'code.function': 'test_instrument_with_parent..parent', + 'logfire.msg': 'parent', + 'logfire.msg_template': 'parent', + 'logfire.span_type': 'pending_span', + 'logfire.pending_parent_id': '0000000000000000', + 'logfire.tags': ('test_instrument',), + }, + }, + { + 'name': 'hello-world {a=}', + 'context': {'trace_id': 1, 'span_id': 4, 'is_remote': False}, + 'parent': {'trace_id': 1, 'span_id': 3, 'is_remote': False}, + 'start_time': 2000000000, + 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.lineno': 123, + 'code.function': 'test_instrument_with_parent..hello_world', + 'a': 5, + 'logfire.msg_template': 'hello-world {a=}', + 'logfire.msg': 'hello-world a=5', + 'logfire.json_schema': '{"type":"object","properties":{"a":{}}}', + 'logfire.span_type': 'pending_span', + 'logfire.pending_parent_id': '0000000000000001', + 'logfire.tags': ('test_instrument',), + }, + }, + { + 'attributes': { + 'a': 5, + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_instrument_with_parent..hello_world', + 'code.lineno': 123, + 'logfire.json_schema': '{"type":"object","properties":{"a":{},"return":{}}}', + 'logfire.msg': 'hello-world a=5', + 'logfire.msg_template': 'hello-world {a=}', + 'logfire.span_type': 'span', + 'logfire.tags': ('test_instrument',), + 'return': 'hello 5', + }, + 'context': { + 'is_remote': False, + 'span_id': 3, + 'trace_id': 1, + }, + 'end_time': 3000000000, + 'name': 'hello-world {a=}', + 'parent': { + 'is_remote': False, + 'span_id': 1, + 'trace_id': 1, + }, + 'start_time': 2000000000, + }, + { + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_instrument_with_parent..parent', + 'code.lineno': 123, + 'logfire.json_schema': '{"type":"object","properties":{"return":{}}}', + 'logfire.msg': 'parent', + 'logfire.msg_template': 'parent', + 'logfire.span_type': 'span', + 'logfire.tags': ('test_instrument',), + 'return': 'hello 5', + }, + 'context': { + 'is_remote': False, + 'span_id': 1, + 'trace_id': 1, + }, + 'end_time': 4000000000, + 'name': 'parent', + 'parent': None, + 'start_time': 1000000000, + }, + ] + ) + + +def test_instrument_new_context(exporter: TestExporter) -> None: + tagged = logfire.with_tags('test_instrument') + + @tagged.instrument('hello-world {a=}', record_return=True, new_context=True) + def hello_world(a: int) -> str: + return f'hello {a}' + + @tagged.instrument('parent', record_return=True) + def parent() -> str: + return hello_world(5) + + assert parent() == 'hello 5' + + assert exporter.exported_spans_as_dict(_include_pending_spans=True, _strip_function_qualname=False) == snapshot( + [ + { + 'name': 'parent', + 'context': {'trace_id': 1, 'span_id': 2, 'is_remote': False}, + 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'start_time': 1000000000, + 'end_time': 1000000000, + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.lineno': 123, + 'code.function': 'test_instrument_new_context..parent', + 'logfire.msg': 'parent', + 'logfire.msg_template': 'parent', + 'logfire.span_type': 'pending_span', + 'logfire.pending_parent_id': '0000000000000000', + 'logfire.tags': ('test_instrument',), + }, + }, + { + 'name': 'hello-world {a=}', + 'context': {'trace_id': 2, 'span_id': 4, 'is_remote': False}, + 'parent': {'trace_id': 2, 'span_id': 3, 'is_remote': False}, + 'start_time': 2000000000, + 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.lineno': 123, + 'code.function': 'test_instrument_new_context..hello_world', + 'a': 5, + 'logfire.msg_template': 'hello-world {a=}', + 'logfire.msg': 'hello-world a=5', + 'logfire.json_schema': '{"type":"object","properties":{"a":{}}}', + 'logfire.span_type': 'pending_span', + 'logfire.pending_parent_id': '0000000000000000', + 'logfire.tags': ('test_instrument',), + }, + }, + { + 'attributes': { + 'a': 5, + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_instrument_new_context..hello_world', + 'code.lineno': 123, + 'logfire.json_schema': '{"type":"object","properties":{"a":{},"return":{}}}', + 'logfire.msg': 'hello-world a=5', + 'logfire.msg_template': 'hello-world {a=}', + 'logfire.span_type': 'span', + 'logfire.tags': ('test_instrument',), + 'return': 'hello 5', + }, + 'context': { + 'is_remote': False, + 'span_id': 3, + 'trace_id': 2, + }, + 'end_time': 3000000000, + 'links': [ + { + 'attributes': {}, + 'context': { + 'is_remote': False, + 'span_id': 1, + 'trace_id': 1, + }, + }, + ], + 'name': 'hello-world {a=}', + 'parent': None, + 'start_time': 2000000000, + }, + { + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_instrument_new_context..parent', + 'code.lineno': 123, + 'logfire.json_schema': '{"type":"object","properties":{"return":{}}}', + 'logfire.msg': 'parent', + 'logfire.msg_template': 'parent', + 'logfire.span_type': 'span', + 'logfire.tags': ('test_instrument',), + 'return': 'hello 5', + }, + 'context': { + 'is_remote': False, + 'span_id': 1, + 'trace_id': 1, + }, + 'end_time': 4000000000, + 'name': 'parent', + 'parent': None, + 'start_time': 1000000000, + }, + ] + ) + + +def test_instrument_new_context_no_extract(exporter: TestExporter) -> None: + tagged = logfire.with_tags('test_instrument') + + @tagged.instrument('hello-world', record_return=True, new_context=True, extract_args=False) + def hello_world(a: int) -> str: + return f'hello {a}' + + @tagged.instrument('parent', record_return=True) + def parent() -> str: + return hello_world(5) + + assert parent() == 'hello 5' + + assert exporter.exported_spans_as_dict(_include_pending_spans=True, _strip_function_qualname=False) == snapshot( + [ + { + 'name': 'parent', + 'context': {'trace_id': 1, 'span_id': 2, 'is_remote': False}, + 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'start_time': 1000000000, + 'end_time': 1000000000, + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.lineno': 123, + 'code.function': 'test_instrument_new_context_no_extract..parent', + 'logfire.msg': 'parent', + 'logfire.msg_template': 'parent', + 'logfire.span_type': 'pending_span', + 'logfire.pending_parent_id': '0000000000000000', + 'logfire.tags': ('test_instrument',), + }, + }, + { + 'name': 'hello-world', + 'context': {'trace_id': 2, 'span_id': 4, 'is_remote': False}, + 'parent': {'trace_id': 2, 'span_id': 3, 'is_remote': False}, + 'start_time': 2000000000, + 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.lineno': 123, + 'code.function': 'test_instrument_new_context_no_extract..hello_world', + 'logfire.msg_template': 'hello-world', + 'logfire.msg': 'hello-world', + 'logfire.span_type': 'pending_span', + 'logfire.pending_parent_id': '0000000000000000', + 'logfire.tags': ('test_instrument',), + }, + }, + { + 'context': { + 'is_remote': False, + 'span_id': 3, + 'trace_id': 2, + }, + 'end_time': 3000000000, + 'links': [ + { + 'attributes': {}, + 'context': { + 'is_remote': False, + 'span_id': 1, + 'trace_id': 1, + }, + }, + ], + 'name': 'hello-world', + 'parent': None, + 'attributes': { + 'code.function': 'test_instrument_new_context_no_extract..hello_world', + 'logfire.msg_template': 'hello-world', + 'code.lineno': 123, + 'code.filepath': 'test_logfire.py', + 'logfire.tags': ('test_instrument',), + 'logfire.span_type': 'span', + 'logfire.msg': 'hello-world', + 'return': 'hello 5', + 'logfire.json_schema': '{"type":"object","properties":{"return":{}}}', + }, + 'start_time': 2000000000, + }, + { + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_instrument_new_context_no_extract..parent', + 'code.lineno': 123, + 'logfire.json_schema': '{"type":"object","properties":{"return":{}}}', + 'logfire.msg': 'parent', + 'logfire.msg_template': 'parent', + 'logfire.span_type': 'span', + 'logfire.tags': ('test_instrument',), + 'return': 'hello 5', + }, + 'context': { + 'is_remote': False, + 'span_id': 1, + 'trace_id': 1, + }, + 'end_time': 4000000000, + 'name': 'parent', + 'parent': None, + 'start_time': 1000000000, + }, + ] + ) + + +def test_instrument_new_context_some_args(exporter: TestExporter) -> None: + tagged = logfire.with_tags('test_instrument') + + @tagged.instrument('hello-world {a=}', record_return=True, new_context=True, extract_args=['a']) + def hello_world(a: int, b: int) -> str: + return f'hello {a} {b}' + + @tagged.instrument('parent', record_return=True) + def parent() -> str: + return hello_world(5, 10) + + assert parent() == 'hello 5 10' + + assert exporter.exported_spans_as_dict(_include_pending_spans=True, _strip_function_qualname=False) == snapshot( + [ + { + 'name': 'parent', + 'context': {'trace_id': 1, 'span_id': 2, 'is_remote': False}, + 'parent': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'start_time': 1000000000, + 'end_time': 1000000000, + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.lineno': 123, + 'code.function': 'test_instrument_new_context_some_args..parent', + 'logfire.msg': 'parent', + 'logfire.msg_template': 'parent', + 'logfire.span_type': 'pending_span', + 'logfire.pending_parent_id': '0000000000000000', + 'logfire.tags': ('test_instrument',), + }, + }, + { + 'name': 'hello-world {a=}', + 'context': {'trace_id': 2, 'span_id': 4, 'is_remote': False}, + 'parent': {'trace_id': 2, 'span_id': 3, 'is_remote': False}, + 'start_time': 2000000000, + 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.lineno': 123, + 'code.function': 'test_instrument_new_context_some_args..hello_world', + 'a': 5, + 'logfire.msg_template': 'hello-world {a=}', + 'logfire.msg': 'hello-world a=5', + 'logfire.json_schema': '{"type":"object","properties":{"a":{}}}', + 'logfire.span_type': 'pending_span', + 'logfire.pending_parent_id': '0000000000000000', + 'logfire.tags': ('test_instrument',), + }, + }, + { + 'attributes': { + 'a': 5, + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_instrument_new_context_some_args..hello_world', + 'code.lineno': 123, + 'logfire.json_schema': '{"type":"object","properties":{"a":{},"return":{}}}', + 'logfire.msg': 'hello-world a=5', + 'logfire.msg_template': 'hello-world {a=}', + 'logfire.span_type': 'span', + 'logfire.tags': ('test_instrument',), + 'return': 'hello 5 10', + }, + 'context': { + 'is_remote': False, + 'span_id': 3, + 'trace_id': 2, + }, + 'end_time': 3000000000, + 'links': [ + { + 'attributes': {}, + 'context': { + 'is_remote': False, + 'span_id': 1, + 'trace_id': 1, + }, + }, + ], + 'name': 'hello-world {a=}', + 'parent': None, + 'start_time': 2000000000, + }, + { + 'attributes': { + 'code.filepath': 'test_logfire.py', + 'code.function': 'test_instrument_new_context_some_args..parent', + 'code.lineno': 123, + 'logfire.json_schema': '{"type":"object","properties":{"return":{}}}', + 'logfire.msg': 'parent', + 'logfire.msg_template': 'parent', + 'logfire.span_type': 'span', + 'logfire.tags': ('test_instrument',), + 'return': 'hello 5 10', + }, + 'context': { + 'is_remote': False, + 'span_id': 1, + 'trace_id': 1, + }, + 'end_time': 4000000000, + 'name': 'parent', + 'parent': None, + 'start_time': 1000000000, + }, + ] + ) + + def test_instrument_other_callable(exporter: TestExporter): class Instrumented: def __call__(self, a: int) -> str: