Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 47 additions & 52 deletions src/aiida/transports/plugins/ssh_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(self, *args, **kwargs):
# a computer with core.ssh_async transport plugin should be configured before any instantiation.
self.machine = kwargs.pop('host', kwargs.pop('machine'))
self._max_io_allowed = kwargs.pop('max_io_allowed', self._DEFAULT_max_io_allowed)
self._semaphore = asyncio.Semaphore(self._max_io_allowed)
self.script_before = kwargs.pop('script_before', 'None')

if kwargs.get('backend') == 'openssh':
Expand All @@ -145,20 +146,10 @@ def __init__(self, *args, **kwargs):

self.async_backend = _AsyncSSH(self.machine, self.logger, self._bash_command_str) # type: ignore[assignment]

self._concurrent_io = 0

# Read-only accessor for the concurrent I/O limit stored in `_max_io_allowed`.
@property
def max_io_allowed(self):
return self._max_io_allowed

# Take one I/O slot, polling until the in-flight counter drops below `max_io_allowed`.
# `sleep_time` is the delay, in seconds, between successive checks of the counter.
async def _lock(self, sleep_time=0.5):
# Keep yielding to the event loop while the counter is at (or above) the limit.
while self._concurrent_io >= self.max_io_allowed:
await asyncio.sleep(sleep_time)
# A slot is free: count this operation as in flight.
self._concurrent_io += 1

# Give back one I/O slot by decrementing the in-flight counter.
# NOTE(review): the call sites visible here invoke this inside the `try` body, after the
# transfer call, so it is skipped when the transfer raises OSError and the counter stays raised.
async def _unlock(self):
self._concurrent_io -= 1

async def open_async(self):
"""Open the transport.
This plugin supports running scripts before and during the connection.
Expand Down Expand Up @@ -316,14 +307,17 @@ async def getfile_async(
if os.path.isfile(localpath) and not overwrite:
raise OSError('Destination already exists: not overwriting it')

try:
await self._lock()
await self.async_backend.get(
remotepath=remotepath, localpath=localpath, dereference=dereference, preserve=preserve, recursive=False
)
await self._unlock()
except OSError as exc:
raise OSError(f'Error while downloading file {remotepath}: {exc}')
async with self._semaphore:
try:
await self.async_backend.get(
remotepath=remotepath,
localpath=localpath,
dereference=dereference,
preserve=preserve,
recursive=False,
)
except OSError as exc:
raise OSError(f'Error while downloading file {remotepath}: {exc}')

async def gettree_async(
self,
Expand Down Expand Up @@ -383,19 +377,18 @@ async def gettree_async(

content_list = await self.listdir_async(remotepath)
for content_ in content_list:
try:
await self._lock()
parentpath = str(PurePath(remotepath) / content_)
await self.async_backend.get(
remotepath=parentpath,
localpath=localpath,
dereference=dereference,
preserve=preserve,
recursive=True,
)
await self._unlock()
except OSError as exc:
raise OSError(f'Error while downloading file {parentpath}: {exc}')
parentpath = str(PurePath(remotepath) / content_)
async with self._semaphore:
try:
await self.async_backend.get(
remotepath=parentpath,
localpath=localpath,
dereference=dereference,
preserve=preserve,
recursive=True,
)
except OSError as exc:
raise OSError(f'Error while downloading file {parentpath}: {exc}')

async def put_async(
self,
Expand Down Expand Up @@ -528,14 +521,17 @@ async def putfile_async(
if await self.isfile_async(remotepath) and not overwrite:
raise OSError('Destination already exists: not overwriting it')

try:
await self._lock()
await self.async_backend.put(
localpath=localpath, remotepath=remotepath, dereference=dereference, preserve=preserve, recursive=False
)
await self._unlock()
except OSError as exc:
raise OSError(f'Error while uploading file {localpath}: {exc}')
async with self._semaphore:
try:
await self.async_backend.put(
localpath=localpath,
remotepath=remotepath,
dereference=dereference,
preserve=preserve,
recursive=False,
)
except OSError as exc:
raise OSError(f'Error while uploading file {localpath}: {exc}')
Comment on lines +524 to +534
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just looking at the code, the strictly equivalent change would be to put `async with self._semaphore` inside the `try` block, right before the `put()` call.
But I think this is a correct change.

Pinging @khsrali for review since you implemented it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I overlooked what @danielhollas already wrote: "and fixes a bug since the number of file IO connections was not decreased when an OSError exception was thrown."


async def puttree_async(
self,
Expand Down Expand Up @@ -598,19 +594,18 @@ async def puttree_async(
# Or to put and rename the parent folder at the same time
content_list = os.listdir(localpath)
for content_ in content_list:
try:
await self._lock()
parentpath = str(PurePath(localpath) / content_)
await self.async_backend.put(
localpath=parentpath,
remotepath=remotepath,
dereference=dereference,
preserve=preserve,
recursive=True,
)
await self._unlock()
except OSError as exc:
raise OSError(f'Error while uploading file {parentpath}: {exc}')
parentpath = str(PurePath(localpath) / content_)
async with self._semaphore:
try:
await self.async_backend.put(
localpath=parentpath,
remotepath=remotepath,
dereference=dereference,
preserve=preserve,
recursive=True,
)
except OSError as exc:
raise OSError(f'Error while uploading file {parentpath}: {exc}')

async def copy_async(
self,
Expand Down