@@ -63,19 +63,27 @@ def decorator(func: Callable[P, R]) -> Callable[P, R]:
63
63
attributes = get_attributes (func , msg_template , tags )
64
64
open_span = get_open_span (logfire , attributes , span_name , extract_args , func )
65
65
66
+ # Check if function has logfire_span parameter
67
+ sig = inspect .signature (func )
68
+ has_logfire_span_param = 'logfire_span' in sig .parameters
69
+
66
70
if inspect .isgeneratorfunction (func ):
67
71
if not allow_generator :
68
72
warnings .warn (GENERATOR_WARNING_MESSAGE , stacklevel = 2 )
69
73
70
74
def wrapper (* func_args : P .args , ** func_kwargs : P .kwargs ): # type: ignore
71
- with open_span (* func_args , ** func_kwargs ):
75
+ with open_span (* func_args , ** func_kwargs ) as span :
76
+ if has_logfire_span_param :
77
+ func_kwargs ['logfire_span' ] = span
72
78
yield from func (* func_args , ** func_kwargs )
73
79
elif inspect .isasyncgenfunction (func ):
74
80
if not allow_generator :
75
81
warnings .warn (GENERATOR_WARNING_MESSAGE , stacklevel = 2 )
76
82
77
83
async def wrapper (* func_args : P .args , ** func_kwargs : P .kwargs ): # type: ignore
78
- with open_span (* func_args , ** func_kwargs ):
84
+ with open_span (* func_args , ** func_kwargs ) as span :
85
+ if has_logfire_span_param :
86
+ func_kwargs ['logfire_span' ] = span
79
87
# `yield from` is invalid syntax in an async function.
80
88
# This loop is not quite equivalent, because `yield from` also handles things like
81
89
# sending values to the subgenerator.
@@ -90,6 +98,8 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs): # type: ignore
90
98
91
99
async def wrapper (* func_args : P .args , ** func_kwargs : P .kwargs ) -> R : # type: ignore
92
100
with open_span (* func_args , ** func_kwargs ) as span :
101
+ if has_logfire_span_param :
102
+ func_kwargs ['logfire_span' ] = span
93
103
result = await func (* func_args , ** func_kwargs )
94
104
if record_return :
95
105
# open_span returns a FastLogfireSpan, so we can't use span.set_attribute for complex types.
@@ -102,6 +112,8 @@ async def wrapper(*func_args: P.args, **func_kwargs: P.kwargs) -> R: # type: ig
102
112
# Same as the above, but without the async/await
103
113
def wrapper (* func_args : P .args , ** func_kwargs : P .kwargs ) -> R :
104
114
with open_span (* func_args , ** func_kwargs ) as span :
115
+ if has_logfire_span_param :
116
+ func_kwargs ['logfire_span' ] = span
105
117
result = func (* func_args , ** func_kwargs )
106
118
if record_return :
107
119
set_user_attributes_on_raw_span (span ._span , {'return' : result })
@@ -122,27 +134,32 @@ def get_open_span(
122
134
) -> Callable [P , AbstractContextManager [Any ]]:
123
135
final_span_name : str = span_name or attributes [ATTRIBUTES_MESSAGE_TEMPLATE_KEY ] # type: ignore
124
136
137
+ # Check if function has logfire_span parameter
138
+ sig = inspect .signature (func )
139
+ has_logfire_span_param = 'logfire_span' in sig .parameters
140
+
125
141
# This is the fast case for when there are no arguments to extract
126
142
def open_span (* _ : P .args , ** __ : P .kwargs ): # type: ignore
143
+ if has_logfire_span_param :
144
+ return logfire ._span (final_span_name , attributes ) # type: ignore
127
145
return logfire ._fast_span (final_span_name , attributes ) # type: ignore
128
146
129
147
if extract_args is True :
130
- sig = inspect .signature (func )
131
148
if sig .parameters : # only extract args if there are any
132
149
133
150
def open_span (* func_args : P .args , ** func_kwargs : P .kwargs ):
134
151
bound = sig .bind (* func_args , ** func_kwargs )
135
152
bound .apply_defaults ()
136
153
args_dict = bound .arguments
154
+ if has_logfire_span_param :
155
+ return logfire ._span (final_span_name , {** attributes , ** args_dict }) # type: ignore
137
156
return logfire ._instrument_span_with_args ( # type: ignore
138
157
final_span_name , attributes , args_dict
139
158
)
140
159
141
160
return open_span
142
161
143
162
if extract_args : # i.e. extract_args should be an iterable of argument names
144
- sig = inspect .signature (func )
145
-
146
163
if isinstance (extract_args , str ):
147
164
extract_args = [extract_args ]
148
165
@@ -165,6 +182,8 @@ def open_span(*func_args: P.args, **func_kwargs: P.kwargs):
165
182
# This line is the only difference from the extract_args=True case
166
183
args_dict = {k : args_dict [k ] for k in extract_args_final }
167
184
185
+ if has_logfire_span_param :
186
+ return logfire ._span (final_span_name , {** attributes , ** args_dict }) # type: ignore
168
187
return logfire ._instrument_span_with_args ( # type: ignore
169
188
final_span_name , attributes , args_dict
170
189
)
0 commit comments