From f9a35037363f58fdcd2a46b5b35cc43fe2c7492a Mon Sep 17 00:00:00 2001 From: Daniel Hollas Date: Thu, 25 Sep 2025 17:22:58 +0100 Subject: [PATCH] async_ssh: Use async.Semaphore instead of manual locking --- src/aiida/transports/plugins/ssh_async.py | 99 +++++++++++------------ 1 file changed, 47 insertions(+), 52 deletions(-) diff --git a/src/aiida/transports/plugins/ssh_async.py b/src/aiida/transports/plugins/ssh_async.py index 37f531903c..da2470b49d 100644 --- a/src/aiida/transports/plugins/ssh_async.py +++ b/src/aiida/transports/plugins/ssh_async.py @@ -133,6 +133,7 @@ def __init__(self, *args, **kwargs): # a computer with core.ssh_async transport plugin should be configured before any instantiation. self.machine = kwargs.pop('host', kwargs.pop('machine')) self._max_io_allowed = kwargs.pop('max_io_allowed', self._DEFAULT_max_io_allowed) + self._semaphore = asyncio.Semaphore(self._max_io_allowed) self.script_before = kwargs.pop('script_before', 'None') if kwargs.get('backend') == 'openssh': @@ -145,20 +146,10 @@ def __init__(self, *args, **kwargs): self.async_backend = _AsyncSSH(self.machine, self.logger, self._bash_command_str) # type: ignore[assignment] - self._concurrent_io = 0 - @property def max_io_allowed(self): return self._max_io_allowed - async def _lock(self, sleep_time=0.5): - while self._concurrent_io >= self.max_io_allowed: - await asyncio.sleep(sleep_time) - self._concurrent_io += 1 - - async def _unlock(self): - self._concurrent_io -= 1 - async def open_async(self): """Open the transport. This plugin supports running scripts before and during the connection. @@ -316,14 +307,17 @@ async def getfile_async( if os.path.isfile(localpath) and not overwrite: raise OSError('Destination already exists: not overwriting it') - try: - await self._lock() - await self.async_backend.get( - remotepath=remotepath, localpath=localpath, dereference=dereference, preserve=preserve, recursive=False - ) - await self._unlock() - except OSError as exc: - raise OSError(f'Error while downloading file {remotepath}: {exc}') + async with self._semaphore: + try: + await self.async_backend.get( + remotepath=remotepath, + localpath=localpath, + dereference=dereference, + preserve=preserve, + recursive=False, + ) + except OSError as exc: + raise OSError(f'Error while downloading file {remotepath}: {exc}') async def gettree_async( self, @@ -383,19 +377,18 @@ async def gettree_async( content_list = await self.listdir_async(remotepath) for content_ in content_list: - try: - await self._lock() - parentpath = str(PurePath(remotepath) / content_) - await self.async_backend.get( - remotepath=parentpath, - localpath=localpath, - dereference=dereference, - preserve=preserve, - recursive=True, - ) - await self._unlock() - except OSError as exc: - raise OSError(f'Error while downloading file {parentpath}: {exc}') + parentpath = str(PurePath(remotepath) / content_) + async with self._semaphore: + try: + await self.async_backend.get( + remotepath=parentpath, + localpath=localpath, + dereference=dereference, + preserve=preserve, + recursive=True, + ) + except OSError as exc: + raise OSError(f'Error while downloading file {parentpath}: {exc}') async def put_async( self, @@ -528,14 +521,17 @@ async def putfile_async( if await self.isfile_async(remotepath) and not overwrite: raise OSError('Destination already exists: not overwriting it') - try: - await self._lock() - await self.async_backend.put( - localpath=localpath, remotepath=remotepath, dereference=dereference, preserve=preserve, recursive=False - ) - await self._unlock() - except OSError as exc: - raise OSError(f'Error while uploading file {localpath}: {exc}') + async with self._semaphore: + try: + await self.async_backend.put( + localpath=localpath, + remotepath=remotepath, + dereference=dereference, + preserve=preserve, + recursive=False, + ) + except OSError as exc: + raise OSError(f'Error while uploading file {localpath}: {exc}') async def puttree_async( self, @@ -598,19 +594,18 @@ async def puttree_async( # Or to put and rename the parent folder at the same time content_list = os.listdir(localpath) for content_ in content_list: - try: - await self._lock() - parentpath = str(PurePath(localpath) / content_) - await self.async_backend.put( - localpath=parentpath, - remotepath=remotepath, - dereference=dereference, - preserve=preserve, - recursive=True, - ) - await self._unlock() - except OSError as exc: - raise OSError(f'Error while uploading file {parentpath}: {exc}') + parentpath = str(PurePath(localpath) / content_) + async with self._semaphore: + try: + await self.async_backend.put( + localpath=parentpath, + remotepath=remotepath, + dereference=dereference, + preserve=preserve, + recursive=True, + ) + except OSError as exc: + raise OSError(f'Error while uploading file {parentpath}: {exc}') async def copy_async( self,