33import ast
44import uuid
55from collections import deque
6+ from contextlib import AbstractContextManager
67from dataclasses import dataclass
78from functools import partial
8- from typing import TYPE_CHECKING , Any , Callable , ContextManager , TypeVar
9+ from typing import TYPE_CHECKING , Any , Callable , TypeVar
910
1011import logfire
1112
@@ -35,7 +36,7 @@ def compile_source(
3536 Otherwise, it's initially the `partial` above.
3637 """
3738 logfire_name = f'logfire_{ uuid .uuid4 ().hex } '
38- context_factories : list [Callable [[], ContextManager [Any ]]] = []
39+ context_factories : list [Callable [[], AbstractContextManager [Any ]]] = []
3940 tree = rewrite_ast (tree , filename , logfire_name , module_name , logfire_instance , context_factories , min_duration )
4041 assert isinstance (tree , ast .Module ) # for type checking
4142 # dont_inherit=True is necessary to prevent the module from inheriting the __future__ import from this module.
@@ -54,7 +55,7 @@ def rewrite_ast(
5455 logfire_name : str ,
5556 module_name : str ,
5657 logfire_instance : Logfire ,
57- context_factories : list [Callable [[], ContextManager [Any ]]],
58+ context_factories : list [Callable [[], AbstractContextManager [Any ]]],
5859 min_duration : int ,
5960) -> ast .AST :
6061 logfire_args = LogfireArgs (logfire_instance ._tags , logfire_instance ._sample_rate ) # type: ignore
@@ -69,7 +70,7 @@ class AutoTraceTransformer(BaseTransformer):
6970 """Trace all encountered functions except those explicitly marked with `@no_auto_trace`."""
7071
7172 logfire_instance : Logfire
72- context_factories : list [Callable [[], ContextManager [Any ]]]
73+ context_factories : list [Callable [[], AbstractContextManager [Any ]]]
7374 min_duration : int
7475
7576 def check_no_auto_trace (self , node : ast .FunctionDef | ast .AsyncFunctionDef | ast .ClassDef ) -> bool :
0 commit comments