Skip to content

Commit fcaa74e

Browse files
committed
Some logic cleanup
1 parent eeaf0c6 commit fcaa74e

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

dreadnode/agent/tools/base.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import functools
23
import typing as t
34

@@ -183,6 +184,7 @@ class Toolset(Model):
183184
# Context manager magic
184185
_entry_ref_count: int = PrivateAttr(default=0)
185186
_context_handle: object = PrivateAttr(default=None)
187+
_entry_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
186188

187189
@property
188190
def name(self) -> str:
@@ -200,11 +202,28 @@ def __init_subclass__(cls, **kwargs: t.Any) -> None:
200202

201203
original_aenter = cls.__dict__.get("__aenter__")
202204
original_enter = cls.__dict__.get("__enter__")
205+
original_aexit = cls.__dict__.get("__aexit__")
206+
original_exit = cls.__dict__.get("__exit__")
203207

204-
if callable(original_aenter) or callable(original_enter):
205-
206-
@functools.wraps(original_aenter or original_enter) # type: ignore[arg-type]
207-
async def aenter_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any:
208+
has_enter = callable(original_aenter) or callable(original_enter)
209+
has_exit = callable(original_aexit) or callable(original_exit)
210+
211+
if has_enter and not has_exit:
212+
raise TypeError(
213+
f"{cls.__name__} defining __aenter__ or __enter__ must also define __aexit__ or __exit__"
214+
)
215+
if has_exit and not has_enter:
216+
raise TypeError(
217+
f"{cls.__name__} defining __aexit__ or __exit__ must also define __aenter__ or __enter__"
218+
)
219+
if original_aenter and original_enter:
220+
raise TypeError(f"{cls.__name__} cannot define both __aenter__ and __enter__")
221+
if original_aexit and original_exit:
222+
raise TypeError(f"{cls.__name__} cannot define both __aexit__ and __exit__")
223+
224+
@functools.wraps(original_aenter or original_enter) # type: ignore[arg-type]
225+
async def aenter_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any:
226+
async with self._entry_lock:
208227
if self._entry_ref_count == 0:
209228
handle = None
210229
if original_aenter:
@@ -215,15 +234,11 @@ async def aenter_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.An
215234
self._entry_ref_count += 1
216235
return self._context_handle
217236

218-
cls.__aenter__ = aenter_wrapper # type: ignore[attr-defined]
219-
220-
original_aexit = cls.__dict__.get("__aexit__")
221-
original_exit = cls.__dict__.get("__exit__")
222-
223-
if callable(original_aexit) or callable(original_exit):
237+
cls.__aenter__ = aenter_wrapper # type: ignore[attr-defined]
224238

225-
@functools.wraps(original_aexit or original_exit) # type: ignore[arg-type]
226-
async def aexit_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any:
239+
@functools.wraps(original_aexit or original_exit) # type: ignore[arg-type]
240+
async def aexit_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any:
241+
async with self._entry_lock:
227242
self._entry_ref_count -= 1
228243
if self._entry_ref_count == 0:
229244
if original_aexit:
@@ -232,7 +247,7 @@ async def aexit_wrapper(self: "Toolset", *args: t.Any, **kwargs: t.Any) -> t.Any
232247
original_exit(self, *args, **kwargs)
233248
self._context_handle = None
234249

235-
cls.__aexit__ = aexit_wrapper # type: ignore[attr-defined]
250+
cls.__aexit__ = aexit_wrapper # type: ignore[attr-defined]
236251

237252
def get_tools(self, *, variant: str | None = None) -> list[AnyTool]:
238253
variant = variant or self.variant

0 commit comments

Comments
 (0)