Skip to content

Commit 8e39500

Browse files
authored
Fix typing errors involving handle_internal_errors (#885)
1 parent a423415 commit 8e39500

File tree

10 files changed

+90
-63
lines changed

10 files changed

+90
-63
lines changed

logfire/_internal/config.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -317,22 +317,22 @@ def configure( # noqa: D417
317317
"""
318318
from .. import DEFAULT_LOGFIRE_INSTANCE, Logfire
319319

320-
processors = deprecated_kwargs.pop('processors', None) # type: ignore
320+
processors = deprecated_kwargs.pop('processors', None)
321321
if processors is not None: # pragma: no cover
322322
raise ValueError(
323323
'The `processors` argument has been replaced by `additional_span_processors`. '
324324
'Set `send_to_logfire=False` to disable the default processor.'
325325
)
326326

327-
metric_readers = deprecated_kwargs.pop('metric_readers', None) # type: ignore
327+
metric_readers = deprecated_kwargs.pop('metric_readers', None)
328328
if metric_readers is not None: # pragma: no cover
329329
raise ValueError(
330330
'The `metric_readers` argument has been replaced by '
331331
'`metrics=logfire.MetricsOptions(additional_readers=[...])`. '
332332
'Set `send_to_logfire=False` to disable the default metric reader.'
333333
)
334334

335-
collect_system_metrics = deprecated_kwargs.pop('collect_system_metrics', None) # type: ignore
335+
collect_system_metrics = deprecated_kwargs.pop('collect_system_metrics', None)
336336
if collect_system_metrics is False:
337337
raise ValueError(
338338
'The `collect_system_metrics` argument has been removed. System metrics are no longer collected by default.'
@@ -343,8 +343,8 @@ def configure( # noqa: D417
343343
'The `collect_system_metrics` argument has been removed. Use `logfire.instrument_system_metrics()` instead.'
344344
)
345345

346-
scrubbing_callback = deprecated_kwargs.pop('scrubbing_callback', None) # type: ignore
347-
scrubbing_patterns = deprecated_kwargs.pop('scrubbing_patterns', None) # type: ignore
346+
scrubbing_callback = deprecated_kwargs.pop('scrubbing_callback', None)
347+
scrubbing_patterns = deprecated_kwargs.pop('scrubbing_patterns', None)
348348
if scrubbing_callback or scrubbing_patterns:
349349
if scrubbing is not None:
350350
raise ValueError(
@@ -357,7 +357,7 @@ def configure( # noqa: D417
357357
)
358358
scrubbing = ScrubbingOptions(callback=scrubbing_callback, extra_patterns=scrubbing_patterns) # type: ignore
359359

360-
project_name = deprecated_kwargs.pop('project_name', None) # type: ignore
360+
project_name = deprecated_kwargs.pop('project_name', None)
361361
if project_name is not None:
362362
warnings.warn(
363363
'The `project_name` argument is deprecated and not needed.',
@@ -377,15 +377,15 @@ def configure( # noqa: D417
377377
'Use `sampling=logfire.SamplingOptions(head=...)` instead.',
378378
)
379379

380-
show_summary = deprecated_kwargs.pop('show_summary', None) # type: ignore
380+
show_summary = deprecated_kwargs.pop('show_summary', None)
381381
if show_summary is not None: # pragma: no cover
382382
warnings.warn(
383383
'The `show_summary` argument is deprecated. '
384384
'Use `console=False` or `console=logfire.ConsoleOptions(show_project_link=False)` instead.',
385385
)
386386

387387
for key in ('base_url', 'id_generator', 'ns_timestamp_generator'):
388-
value: Any = deprecated_kwargs.pop(key, None) # type: ignore
388+
value: Any = deprecated_kwargs.pop(key, None)
389389
if value is None:
390390
continue
391391
if advanced is not None:
@@ -397,7 +397,7 @@ def configure( # noqa: D417
397397
stacklevel=2,
398398
)
399399

400-
additional_metric_readers: Any = deprecated_kwargs.pop('additional_metric_readers', None) # type: ignore
400+
additional_metric_readers: Any = deprecated_kwargs.pop('additional_metric_readers', None)
401401
if additional_metric_readers:
402402
if metrics is not None:
403403
raise ValueError(
@@ -410,7 +410,7 @@ def configure( # noqa: D417
410410
)
411411
metrics = MetricsOptions(additional_readers=additional_metric_readers)
412412

413-
pydantic_plugin: Any = deprecated_kwargs.pop('pydantic_plugin', None) # type: ignore
413+
pydantic_plugin: Any = deprecated_kwargs.pop('pydantic_plugin', None)
414414
if pydantic_plugin is not None:
415415
warnings.warn(
416416
'The `pydantic_plugin` argument is deprecated. Use `logfire.instrument_pydantic()` instead.',
@@ -994,7 +994,7 @@ def check_token():
994994
if hasattr(os, 'register_at_fork'): # pragma: no branch
995995

996996
def fix_pid(): # pragma: no cover
997-
with handle_internal_errors():
997+
with handle_internal_errors:
998998
new_resource = resource.merge(Resource({ResourceAttributes.PROCESS_PID: os.getpid()}))
999999
tracer_provider._resource = new_resource # type: ignore
10001000
meter_provider._resource = new_resource # type: ignore

logfire/_internal/integrations/fastapi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ async def solve_dependencies(self, request: Request | WebSocket, original: Await
167167
return await original
168168

169169
with self.logfire_instance.span('FastAPI arguments') as span:
170-
with handle_internal_errors():
170+
with handle_internal_errors:
171171
if isinstance(request, Request): # pragma: no branch
172172
span.set_attribute(SpanAttributes.HTTP_METHOD, request.method)
173173
route: APIRoute | APIWebSocketRoute | None = request.scope.get('route')
@@ -181,7 +181,7 @@ async def solve_dependencies(self, request: Request | WebSocket, original: Await
181181

182182
result: Any = await original
183183

184-
with handle_internal_errors():
184+
with handle_internal_errors:
185185
solved_values: dict[str, Any]
186186
solved_errors: list[Any]
187187

logfire/_internal/integrations/httpx.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def make_request_hook(hook: RequestHook | None, capture_headers: bool, capture_b
322322
return None
323323

324324
def new_hook(span: Span, request: RequestInfo) -> None:
325-
with handle_internal_errors():
325+
with handle_internal_errors:
326326
request = capture_request(span, request, capture_headers, capture_body)
327327
run_hook(hook, span, request)
328328

@@ -338,7 +338,7 @@ def make_async_request_hook(
338338
return None
339339

340340
async def new_hook(span: Span, request: RequestInfo) -> None:
341-
with handle_internal_errors():
341+
with handle_internal_errors:
342342
request = capture_request(span, request, should_capture_headers, should_capture_body)
343343
await run_async_hook(hook, span, request)
344344

@@ -355,7 +355,7 @@ def make_response_hook(
355355
return None
356356

357357
def new_hook(span: Span, request: RequestInfo, response: ResponseInfo) -> None:
358-
with handle_internal_errors():
358+
with handle_internal_errors:
359359
request, response = capture_response(
360360
span,
361361
request,
@@ -380,7 +380,7 @@ def make_async_response_hook(
380380
return None
381381

382382
async def new_hook(span: Span, request: RequestInfo, response: ResponseInfo) -> None:
383-
with handle_internal_errors():
383+
with handle_internal_errors:
384384
request, response = capture_response(
385385
span,
386386
request,

logfire/_internal/integrations/psycopg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def instrument_psycopg(
4141
instrument_psycopg(logfire_instance, package, **kwargs)
4242
return
4343
elif conn_or_module in PACKAGE_NAMES:
44+
assert isinstance(conn_or_module, str)
4445
_instrument_psycopg(logfire_instance, name=conn_or_module, **kwargs)
4546
return
4647
elif isinstance(conn_or_module, ModuleType):

logfire/_internal/main.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def log(
662662
Set to `True` to use the currently handled exception.
663663
console_log: Whether to log to the console, defaults to `True`.
664664
"""
665-
with handle_internal_errors():
665+
with handle_internal_errors:
666666
stack_info = get_user_stack_info()
667667

668668
attributes = attributes or {}
@@ -2107,7 +2107,7 @@ def __init__(self, span: trace_api.Span) -> None:
21072107
def __enter__(self) -> FastLogfireSpan:
21082108
return self
21092109

2110-
@handle_internal_errors()
2110+
@handle_internal_errors
21112111
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
21122112
context_api.detach(self._token)
21132113
self._span.__exit__(exc_type, exc_value, traceback)
@@ -2139,7 +2139,7 @@ def __getattr__(self, name: str) -> Any:
21392139
return getattr(self._span, name)
21402140

21412141
def __enter__(self) -> LogfireSpan:
2142-
with handle_internal_errors():
2142+
with handle_internal_errors:
21432143
if self._span is None: # pragma: no branch
21442144
self._span = self._tracer.start_span(
21452145
name=self._span_name,
@@ -2152,7 +2152,7 @@ def __enter__(self) -> LogfireSpan:
21522152

21532153
return self
21542154

2155-
@handle_internal_errors()
2155+
@handle_internal_errors
21562156
def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
21572157
if self._token is None: # pragma: no cover
21582158
return
@@ -2161,7 +2161,7 @@ def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseExceptio
21612161
context_api.detach(self._token)
21622162
self._token = None
21632163
if self._span.is_recording():
2164-
with handle_internal_errors():
2164+
with handle_internal_errors:
21652165
if self._added_attributes:
21662166
self._span.set_attribute(
21672167
ATTRIBUTES_JSON_SCHEMA_KEY, attributes_json_schema(self._json_schema_properties)
@@ -2177,7 +2177,7 @@ def tags(self) -> tuple[str, ...]:
21772177
return self._get_attribute(ATTRIBUTES_TAGS_KEY, ())
21782178

21792179
@tags.setter
2180-
@handle_internal_errors()
2180+
@handle_internal_errors
21812181
def tags(self, new_tags: Sequence[str]) -> None:
21822182
"""Set or add tags to the span."""
21832183
if isinstance(new_tags, str):
@@ -2192,7 +2192,7 @@ def message(self) -> str:
21922192
def message(self, message: str):
21932193
self._set_attribute(ATTRIBUTES_MESSAGE_KEY, message)
21942194

2195-
@handle_internal_errors()
2195+
@handle_internal_errors
21962196
def set_attribute(self, key: str, value: Any) -> None:
21972197
"""Sets an attribute on the span.
21982198
@@ -2247,7 +2247,7 @@ def record_exception(
22472247
def is_recording(self) -> bool:
22482248
return self._span is not None and self._span.is_recording()
22492249

2250-
@handle_internal_errors()
2250+
@handle_internal_errors
22512251
def set_level(self, level: LevelName | int):
22522252
"""Set the log level of this span."""
22532253
attributes = log_level_attributes(level)

logfire/_internal/tracer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def get_sample_rate_from_attributes(attributes: otel_types.Attributes) -> float
345345
return cast('float | None', attributes.get(ATTRIBUTES_SAMPLE_RATE_KEY))
346346

347347

348-
@handle_internal_errors()
348+
@handle_internal_errors
349349
def record_exception(
350350
span: trace_api.Span,
351351
exception: BaseException,

logfire/_internal/utils.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import functools
34
import inspect
45
import json
56
import logging
@@ -12,7 +13,19 @@
1213
from pathlib import Path
1314
from time import time
1415
from types import TracebackType
15-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Sequence, Tuple, TypedDict, TypeVar, Union
16+
from typing import (
17+
TYPE_CHECKING,
18+
Any,
19+
Callable,
20+
Dict,
21+
List,
22+
Mapping,
23+
Sequence,
24+
Tuple,
25+
TypedDict,
26+
TypeVar,
27+
Union,
28+
)
1629

1730
from opentelemetry import context, trace as trace_api
1831
from opentelemetry.sdk.resources import Resource
@@ -27,13 +40,17 @@
2740
from logfire._internal.ulid import ulid
2841

2942
if TYPE_CHECKING:
43+
from typing import ParamSpec
44+
3045
from packaging.version import Version
3146

3247
SysExcInfo = Union[tuple[type[BaseException], BaseException, TracebackType | None], tuple[None, None, None]]
3348
"""
3449
The return type of sys.exc_info(): exc_type, exc_val, exc_tb.
3550
"""
3651

52+
P = ParamSpec('P')
53+
3754
T = TypeVar('T')
3855

3956
JsonValue = Union[int, float, str, bool, None, List['JsonValue'], Tuple['JsonValue', ...], 'JsonDict']
@@ -280,20 +297,9 @@ def _internal_error_exc_info() -> SysExcInfo:
280297
original_exc_info: SysExcInfo = sys.exc_info()
281298
exc_type, exc_val, original_tb = original_exc_info
282299
try:
283-
# First remove redundant frames already in the traceback about where the error was raised.
284300
tb = original_tb
285-
if tb and tb.tb_frame and tb.tb_frame.f_code is _HANDLE_INTERNAL_ERRORS_CODE:
286-
# Skip the 'yield' line in _handle_internal_errors
287-
tb = tb.tb_next
288-
289-
if (
290-
tb
291-
and tb.tb_frame
292-
and tb.tb_frame.f_code.co_filename == contextmanager.__code__.co_filename
293-
and tb.tb_frame.f_code.co_name == 'inner'
294-
):
295-
# Skip the 'inner' function frame when handle_internal_errors is used as a decorator.
296-
# It looks like `return func(*args, **kwds)`
301+
if tb and tb.tb_frame and tb.tb_frame.f_code is _HANDLE_INTERNAL_ERRORS_WRAPPER_CODE:
302+
# Skip the redundant `with self:` line in the traceback about where the error was raised.
297303
tb = tb.tb_next
298304

299305
# Now add useful outer frames that give context, but skipping frames that are just about handling the error.
@@ -307,18 +313,16 @@ def _internal_error_exc_info() -> SysExcInfo:
307313
frame = frame.f_back
308314
assert frame
309315

310-
if frame.f_code is _HANDLE_INTERNAL_ERRORS_CODE:
311-
# Skip the line in _handle_internal_errors that calls log_internal_error
312-
frame = frame.f_back
313-
# Skip the frame defining the _handle_internal_errors context manager
314-
assert frame and frame.f_code.co_name == '__exit__'
316+
if frame.f_code is _HANDLE_INTERNAL_ERRORS_EXIT_CODE:
317+
# Skip the `log_internal_error()` call in `__exit__`.
315318
frame = frame.f_back
316319
assert frame
317-
# Skip the frame calling the context manager, on the `with` line.
318-
frame = frame.f_back
319-
else:
320-
# `log_internal_error()` was called directly, so just skip that frame. No context manager stuff.
321-
frame = frame.f_back
320+
321+
# Now skip the line that is either:
322+
# - A direct call to `log_internal_error`
323+
# - `with self:` in HandleInternalErrors.__call__.wrapper
324+
# - `with handle_internal_errors:`
325+
frame = frame.f_back
322326

323327
# Now add all remaining frames from internal logfire code.
324328
while frame and not is_user_code(frame.f_code):
@@ -340,15 +344,29 @@ def _internal_error_exc_info() -> SysExcInfo:
340344
return original_exc_info
341345

342346

343-
@contextmanager
344-
def handle_internal_errors():
345-
try:
346-
yield
347-
except Exception:
348-
log_internal_error()
347+
class HandleInternalErrors:
348+
def __enter__(self):
349+
pass
350+
351+
def __exit__(self, exc_type: type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> bool | None:
352+
if isinstance(exc_val, Exception):
353+
log_internal_error()
354+
return True
355+
356+
def __call__(self, func: Callable[P, T]) -> Callable[P, T]:
357+
@functools.wraps(func)
358+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
359+
with self:
360+
return func(*args, **kwargs)
361+
362+
return wrapper
363+
349364

365+
handle_internal_errors = HandleInternalErrors()
350366

351-
_HANDLE_INTERNAL_ERRORS_CODE = inspect.unwrap(handle_internal_errors).__code__
367+
_HANDLE_INTERNAL_ERRORS_WRAPPER_CODE = handle_internal_errors(log_internal_error).__code__
368+
assert _HANDLE_INTERNAL_ERRORS_WRAPPER_CODE.co_name == 'wrapper'
369+
_HANDLE_INTERNAL_ERRORS_EXIT_CODE = HandleInternalErrors.__exit__.__code__
352370

353371

354372
def maybe_capture_server_headers(capture: bool):

tests/import_used_for_tests/internal_error_handling/internal_logfire_code_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ def inner2():
1111
inner1()
1212

1313

14-
@handle_internal_errors()
14+
@handle_internal_errors
1515
def using_decorator():
1616
inner2()
1717

1818

1919
def using_context_manager():
20-
with handle_internal_errors():
20+
with handle_internal_errors:
2121
inner2()
2222

2323

0 commit comments

Comments
 (0)