11
11
from opentelemetry .util import types as otel_types
12
12
from typing_extensions import LiteralString , ParamSpec
13
13
14
+ from .ast_utils import has_current_span_call
14
15
from .constants import ATTRIBUTES_MESSAGE_TEMPLATE_KEY , ATTRIBUTES_TAGS_KEY
15
16
from .stack_info import get_filepath_attribute
16
17
from .utils import safe_repr , uniquify_sequence
@@ -61,7 +62,8 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]:
61
62
)
62
63
63
64
attributes = get_attributes (func , msg_template , tags )
64
- open_span = get_open_span (logfire , attributes , span_name , extract_args , func )
65
+ uses_current_span = has_current_span_call (func )
66
+ open_span = get_open_span (logfire , attributes , span_name , extract_args , uses_current_span , func )
65
67
66
68
if inspect .isgeneratorfunction (func ):
67
69
if not allow_generator :
@@ -90,21 +92,31 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore
90
92
91
93
async def wrapper (* func_args : P .args , ** func_kwargs : P .kwargs ) -> R : # type: ignore
92
94
with open_span (* func_args , ** func_kwargs ) as span :
95
+ token = None
96
+ if uses_current_span :
97
+ token = logfire ._current_span_var .set (span ) # type: ignore[protected-access]
93
98
result = await func (* func_args , ** func_kwargs )
94
99
if record_return :
95
100
# open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types.
96
101
# This isn't great because it has to parse the JSON schema.
97
102
# Not sure if making get_open_span return a LogfireSpan when record_return is True
98
103
# would be faster overall or if it would be worth the added complexity.
99
104
set_user_attributes_on_raw_span (span ._span , {'return' : result })
105
+ if token :
106
+ logfire ._current_span_var .reset (token ) # type: ignore[protected-access]
100
107
return result
101
108
else :
102
109
# Same as the above, but without the async/await
103
110
def wrapper (* func_args : P .args , ** func_kwargs : P .kwargs ) -> R :
104
111
with open_span (* func_args , ** func_kwargs ) as span :
112
+ token = None
113
+ if uses_current_span :
114
+ token = logfire ._current_span_var .set (span ) # type: ignore[protected-access]
105
115
result = func (* func_args , ** func_kwargs )
106
116
if record_return :
107
117
set_user_attributes_on_raw_span (span ._span , {'return' : result })
118
+ if token :
119
+ logfire ._current_span_var .reset (token ) # type: ignore[protected-access]
108
120
return result
109
121
110
122
wrapper = functools .wraps (func )(wrapper ) # type: ignore
@@ -118,12 +130,15 @@ def get_open_span(
118
130
attributes : dict [str , otel_types .AttributeValue ],
119
131
span_name : str | None ,
120
132
extract_args : bool | Iterable [str ],
133
+ uses_current_span : bool ,
121
134
func : Callable [P , R ],
122
135
) -> Callable [P , AbstractContextManager [Any ]]:
123
136
final_span_name : str = span_name or attributes [ATTRIBUTES_MESSAGE_TEMPLATE_KEY ] # type: ignore
124
137
125
138
# This is the fast case for when there are no arguments to extract
126
139
def open_span (* _ : P .args , ** __ : P .kwargs ): # type: ignore
140
+ if uses_current_span :
141
+ return logfire ._span (final_span_name , attributes ) # type: ignore[protected-access]
127
142
return logfire ._fast_span (final_span_name , attributes ) # type: ignore
128
143
129
144
if extract_args is True :
@@ -134,6 +149,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
134
149
bound = sig .bind (* func_args , ** func_kwargs )
135
150
bound .apply_defaults ()
136
151
args_dict = bound .arguments
152
+ if uses_current_span :
153
+ return logfire ._span (final_span_name , {** attributes , ** args_dict }) # type: ignore[protected-access]
154
+
137
155
return logfire ._instrument_span_with_args ( # type: ignore
138
156
final_span_name , attributes , args_dict
139
157
)
@@ -165,6 +183,9 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
165
183
# This line is the only difference from the extract_args=True case
166
184
args_dict = {k : args_dict [k ] for k in extract_args_final }
167
185
186
+ if uses_current_span :
187
+ return logfire ._span (final_span_name , {** attributes , ** args_dict }) # type: ignore[protected-access]
188
+
168
189
return logfire ._instrument_span_with_args ( # type: ignore
169
190
final_span_name , attributes , args_dict
170
191
)
0 commit comments