1+ import asyncio
12import functools
23import typing as t
34
@@ -183,6 +184,7 @@ class Toolset(Model):
183184 # Context manager magic
184185 _entry_ref_count : int = PrivateAttr (default = 0 )
185186 _context_handle : object = PrivateAttr (default = None )
187+ _entry_lock : asyncio .Lock = PrivateAttr (default_factory = asyncio .Lock )
186188
187189 @property
188190 def name (self ) -> str :
@@ -200,11 +202,28 @@ def __init_subclass__(cls, **kwargs: t.Any) -> None:
200202
201203 original_aenter = cls .__dict__ .get ("__aenter__" )
202204 original_enter = cls .__dict__ .get ("__enter__" )
205+ original_aexit = cls .__dict__ .get ("__aexit__" )
206+ original_exit = cls .__dict__ .get ("__exit__" )
203207
204- if callable (original_aenter ) or callable (original_enter ):
205-
206- @functools .wraps (original_aenter or original_enter ) # type: ignore[arg-type]
207- async def aenter_wrapper (self : "Toolset" , * args : t .Any , ** kwargs : t .Any ) -> t .Any :
208+ has_enter = callable (original_aenter ) or callable (original_enter )
209+ has_exit = callable (original_aexit ) or callable (original_exit )
210+
211+ if has_enter and not has_exit :
212+ raise TypeError (
213+ f"{ cls .__name__ } defining __aenter__ or __enter__ must also define __aexit__ or __exit__"
214+ )
215+ if has_exit and not has_enter :
216+ raise TypeError (
217+ f"{ cls .__name__ } defining __aexit__ or __exit__ must also define __aenter__ or __enter__"
218+ )
219+ if original_aenter and original_enter :
220+ raise TypeError (f"{ cls .__name__ } cannot define both __aenter__ and __enter__" )
221+ if original_aexit and original_exit :
222+ raise TypeError (f"{ cls .__name__ } cannot define both __aexit__ and __exit__" )
223+
224+ @functools .wraps (original_aenter or original_enter ) # type: ignore[arg-type]
225+ async def aenter_wrapper (self : "Toolset" , * args : t .Any , ** kwargs : t .Any ) -> t .Any :
226+ async with self ._entry_lock :
208227 if self ._entry_ref_count == 0 :
209228 handle = None
210229 if original_aenter :
@@ -215,15 +234,11 @@ async def aenter_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.An
215234 self ._entry_ref_count += 1
216235 return self ._context_handle
217236
218- cls .__aenter__ = aenter_wrapper # type: ignore[attr-defined]
219-
220- original_aexit = cls .__dict__ .get ("__aexit__" )
221- original_exit = cls .__dict__ .get ("__exit__" )
222-
223- if callable (original_aexit ) or callable (original_exit ):
237+ cls .__aenter__ = aenter_wrapper # type: ignore[attr-defined]
224238
225- @functools .wraps (original_aexit or original_exit ) # type: ignore[arg-type]
226- async def aexit_wrapper (self : "Toolset" , * args : t .Any , ** kwargs : t .Any ) -> t .Any :
239+ @functools .wraps (original_aexit or original_exit ) # type: ignore[arg-type]
240+ async def aexit_wrapper (self : "Toolset" , * args : t .Any , ** kwargs : t .Any ) -> t .Any :
241+ async with self ._entry_lock :
227242 self ._entry_ref_count -= 1
228243 if self ._entry_ref_count == 0 :
229244 if original_aexit :
@@ -232,7 +247,7 @@ async def aexit_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any
232247 original_exit (self , * args , ** kwargs )
233248 self ._context_handle = None
234249
235- cls .__aexit__ = aexit_wrapper # type: ignore[attr-defined]
250+ cls .__aexit__ = aexit_wrapper # type: ignore[attr-defined]
236251
237252 def get_tools (self , * , variant : str | None = None ) -> list [AnyTool ]:
238253 variant = variant or self .variant
0 commit comments