4
4
import inspect
5
5
from collections .abc import Awaitable , Iterable
6
6
from contextlib import AbstractContextManager , contextmanager
7
+ from datetime import datetime , timezone
7
8
from functools import lru_cache
8
9
from typing import Any , Callable
9
10
from weakref import WeakKeyDictionary
17
18
from starlette .responses import Response
18
19
from starlette .websockets import WebSocket
19
20
20
- from ..main import Logfire , set_user_attributes_on_raw_span
21
+ from ..constants import ONE_SECOND_IN_NANOSECONDS
22
+ from ..main import Logfire , NoopSpan , set_user_attributes_on_raw_span
21
23
from ..stack_info import StackInfo , get_code_object_info
22
24
from ..utils import handle_internal_errors , maybe_capture_server_headers
23
25
@@ -61,6 +63,7 @@ def instrument_fastapi(
61
63
| None = None ,
62
64
excluded_urls : str | Iterable [str ] | None = None ,
63
65
record_send_receive : bool = False ,
66
+ extra_spans : bool = True ,
64
67
** opentelemetry_kwargs : Any ,
65
68
) -> AbstractContextManager [None ]:
66
69
"""Instrument a FastAPI app so that spans and logs are automatically created for each request.
@@ -96,6 +99,7 @@ def instrument_fastapi(
96
99
registry [_app ] = FastAPIInstrumentation (
97
100
logfire_instance ,
98
101
request_attributes_mapper or _default_request_attributes_mapper ,
102
+ extra_spans = extra_spans ,
99
103
)
100
104
101
105
@contextmanager
@@ -158,16 +162,39 @@ def __init__(
158
162
],
159
163
dict [str , Any ] | None ,
160
164
],
165
+ extra_spans : bool ,
161
166
):
162
167
self .logfire_instance = logfire_instance .with_settings (custom_scope_suffix = 'fastapi' )
168
+ self .timestamp_generator = self .logfire_instance .config .advanced .ns_timestamp_generator
163
169
self .request_attributes_mapper = request_attributes_mapper
170
+ self .extra_spans = extra_spans
171
+
172
+ @contextmanager
173
+ def pseudo_span (self , namespace : str , root_span : Span ):
174
+ """Record start and end timestamps in the root span, and possibly exceptions."""
175
+
176
+ def set_timestamp (attribute_name : str ):
177
+ dt = datetime .fromtimestamp (self .timestamp_generator () / ONE_SECOND_IN_NANOSECONDS , tz = timezone .utc )
178
+ value = dt .strftime ('%Y-%m-%dT%H:%M:%S.%fZ' )
179
+ root_span .set_attribute (f'fastapi.{ namespace } .{ attribute_name } ' , value )
180
+
181
+ set_timestamp ('start_timestamp' )
182
+ try :
183
+ try :
184
+ yield
185
+ finally :
186
+ # Record the end timestamp before recording exceptions.
187
+ set_timestamp ('end_timestamp' )
188
+ except Exception as exc :
189
+ root_span .record_exception (exc )
190
+ raise
164
191
165
192
async def solve_dependencies (self , request : Request | WebSocket , original : Awaitable [Any ]) -> Any :
166
193
root_span = request .scope .get (LOGFIRE_SPAN_SCOPE_KEY )
167
194
if not (root_span and root_span .is_recording ()):
168
195
return await original
169
196
170
- with self .logfire_instance .span ('FastAPI arguments' ) as span :
197
+ with self .logfire_instance .span ('FastAPI arguments' ) if self . extra_spans else NoopSpan () as span :
171
198
with handle_internal_errors :
172
199
if isinstance (request , Request ): # pragma: no branch
173
200
span .set_attribute ('http.method' , request .method )
@@ -180,7 +207,8 @@ async def solve_dependencies(self, request: Request | WebSocket, original: Await
180
207
set_user_attributes_on_raw_span (root_span , fastapi_route_attributes )
181
208
span .set_attributes (fastapi_route_attributes )
182
209
183
- result : Any = await original
210
+ with self .pseudo_span ('arguments' , root_span ):
211
+ result : Any = await original
184
212
185
213
with handle_internal_errors :
186
214
solved_values : dict [str , Any ]
@@ -228,7 +256,8 @@ def solved_with_new_values(new_values: dict[str, Any]) -> Any:
228
256
229
257
# request_attributes_mapper may have removed the errors, so we need .get() here.
230
258
if attributes .get ('errors' ):
231
- span .set_level ('error' )
259
+ # Errors should imply a 422 response. 4xx errors are warnings, not errors.
260
+ span .set_level ('warn' )
232
261
233
262
span .set_attributes (attributes )
234
263
for key in ('values' , 'errors' ):
@@ -246,20 +275,31 @@ async def run_endpoint_function(
246
275
values : dict [str , Any ],
247
276
** kwargs : Any ,
248
277
) -> Any :
249
- callback = inspect .unwrap (dependant .call )
250
- code = getattr (callback , '__code__' , None )
251
- stack_info : StackInfo = get_code_object_info (code ) if code else {}
252
- with self .logfire_instance .span (
253
- '{method} {http.route} ({code.function})' ,
254
- method = request .method ,
255
- # Using `http.route` prevents it from being scrubbed if it contains a word like 'secret'.
256
- # We don't use `http.method` because some dashboards do things like count spans with
257
- # both `http.method` and `http.route`.
258
- ** {'http.route' : request .scope ['route' ].path },
259
- ** stack_info ,
260
- _level = 'debug' ,
261
- ):
262
- return await original_run_endpoint_function (dependant = dependant , values = values , ** kwargs )
278
+ original = original_run_endpoint_function (dependant = dependant , values = values , ** kwargs )
279
+ root_span = request .scope .get (LOGFIRE_SPAN_SCOPE_KEY )
280
+ if not (root_span and root_span .is_recording ()): # pragma: no cover
281
+ # This should never happen because we only get to this function after solve_dependencies
282
+ # passes the same check, just being paranoid.
283
+ return await original
284
+
285
+ if self .extra_spans :
286
+ callback = inspect .unwrap (dependant .call )
287
+ code = getattr (callback , '__code__' , None )
288
+ stack_info : StackInfo = get_code_object_info (code ) if code else {}
289
+ extra_span = self .logfire_instance .span (
290
+ '{method} {http.route} ({code.function})' ,
291
+ method = request .method ,
292
+ # Using `http.route` prevents it from being scrubbed if it contains a word like 'secret'.
293
+ # We don't use `http.method` because some dashboards do things like count spans with
294
+ # both `http.method` and `http.route`.
295
+ ** {'http.route' : request .scope ['route' ].path },
296
+ ** stack_info ,
297
+ _level = 'debug' ,
298
+ )
299
+ else :
300
+ extra_span = NoopSpan ()
301
+ with extra_span , self .pseudo_span ('endpoint_function' , root_span ):
302
+ return await original
263
303
264
304
265
305
def _default_request_attributes_mapper (
0 commit comments