diff --git a/adlfs/spec.py b/adlfs/spec.py
index f87819bc..ce755214 100644
--- a/adlfs/spec.py
+++ b/adlfs/spec.py
@@ -4,12 +4,14 @@
 from __future__ import absolute_import, division, print_function
 
 import asyncio
+import contextlib
 from glob import has_magic
 import io
 import logging
 import os
 import warnings
 import weakref
+from typing import Optional
 
 from azure.core.exceptions import (
     ClientAuthenticationError,
@@ -39,6 +41,7 @@
     get_blob_metadata,
     close_service_client,
     close_container_client,
+    _nullcontext,
 )
 from datetime import datetime, timedelta
 
@@ -354,6 +357,12 @@ class AzureBlobFileSystem(AsyncFileSystem):
     default_cache_type: string ('bytes')
         If given, the default cache_type value used for "open()".  Set to none if no
         caching is desired.  Docs in fsspec
+    max_concurrency : int, optional
+        The maximum number of BlobClient connections that can exist simultaneously for
+        this filesystem instance. By default there is no limit, and a separate BlobClient
+        is created for every blob touched; that can drive up memory usage and clog the
+        asyncio event loop when many instances are created and destroyed in quick
+        succession. Setting a limit helps with many small, independent blob operations.
 
     Pass on to fsspec:
 
@@ -412,6 +421,7 @@ def __init__(
         asynchronous: bool = False,
         default_fill_cache: bool = True,
         default_cache_type: str = "bytes",
+        max_concurrency: Optional[int] = None,
         **kwargs,
     ):
         super_kwargs = {
@@ -440,6 +450,15 @@ def __init__(
         self.blocksize = blocksize
         self.default_fill_cache = default_fill_cache
         self.default_cache_type = default_cache_type
+        self.max_concurrency = max_concurrency
+
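+        # When no limit is set, a no-op async context manager stands in for
+        # the semaphore so that _get_blob_client can always `async with` it.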
+        if self.max_concurrency is None:
+            self._blob_client_semaphore = _nullcontext()
+        else:
+            self._blob_client_semaphore = asyncio.Semaphore(max_concurrency)
+
         if (
             self.credential is None
             and self.account_key is None
@@ -452,6 +471,7 @@
             ) = self._get_credential_from_service_principal()
         else:
             self.sync_credential = None
+
         self.do_connect()
         weakref.finalize(self, sync, self.loop, close_service_client, self)
 
@@ -491,6 +511,15 @@ def _strip_protocol(cls, path: str):
         logger.debug(f"_strip_protocol({path}) = {ops}")
         return ops["path"]
 
+    @contextlib.asynccontextmanager
+    async def _get_blob_client(self, container_name, path):
+        """
+        Get a blob client, respecting `self.max_concurrency` if set.
+        """
+        async with self._blob_client_semaphore:
+            async with self.service_client.get_blob_client(container_name, path) as bc:
+                yield bc
+
     def _get_credential_from_service_principal(self):
         """
         Create a Credential for authentication. This can include a TokenCredential
@@ -1332,9 +1361,7 @@ async def _isfile(self, path):
             return False
         else:
             try:
-                async with self.service_client.get_blob_client(
-                    container_name, path
-                ) as bc:
+                async with self._get_blob_client(container_name, path) as bc:
                     props = await bc.get_blob_properties()
                     if props["metadata"]["is_directory"] == "false":
                         return True
@@ -1393,7 +1420,7 @@ async def _exists(self, path):
             # Empty paths exist by definition
             return True
 
-        async with self.service_client.get_blob_client(container_name, path) as bc:
+        async with self._get_blob_client(container_name, path) as bc:
             if await bc.exists():
                 return True
 
@@ -1411,9 +1438,7 @@ async def _pipe_file(self, path, value, overwrite=True, **kwargs):
         """Set the bytes of given file"""
         container_name, path = self.split_path(path)
-        async with self.service_client.get_blob_client(
-            container=container_name, blob=path
-        ) as bc:
+        async with self._get_blob_client(container_name, path) as bc:
             result = await bc.upload_blob(
                 data=value, overwrite=overwrite, metadata={"is_directory": "false"}
             )
 
@@ -1430,9 +1455,7 @@ async def _cat_file(self, path, start=None, end=None, **kwargs):
         else:
             length = None
         container_name, path = self.split_path(path)
-        async with self.service_client.get_blob_client(
-            container=container_name, blob=path
-        ) as bc:
+        async with self._get_blob_client(container_name, path) as bc:
             try:
                 stream = await bc.download_blob(offset=start, length=length)
             except ResourceNotFoundError as e:
@@ -1494,7 +1517,7 @@ async def _url(self, path, expires=3600, **kwargs):
             expiry=datetime.utcnow() + timedelta(seconds=expires),
         )
 
-        async with self.service_client.get_blob_client(container_name, blob) as bc:
+        async with self._get_blob_client(container_name, blob) as bc:
             url = f"{bc.url}?{sas_token}"
 
         return url
@@ -1569,9 +1592,7 @@ async def _put_file(
         else:
             try:
                 with open(lpath, "rb") as f1:
-                    async with self.service_client.get_blob_client(
-                        container_name, path
-                    ) as bc:
+                    async with self._get_blob_client(container_name, path) as bc:
                         await bc.upload_blob(
                             f1,
                             overwrite=overwrite,
@@ -1596,14 +1617,14 @@ async def _cp_file(self, path1, path2, **kwargs):
         container1, path1 = self.split_path(path1, delimiter="/")
         container2, path2 = self.split_path(path2, delimiter="/")
-        cc1 = self.service_client.get_container_client(container1)
-        blobclient1 = cc1.get_blob_client(blob=path1)
-        if container1 == container2:
-            blobclient2 = cc1.get_blob_client(blob=path2)
-        else:
-            cc2 = self.service_client.get_container_client(container2)
-            blobclient2 = cc2.get_blob_client(blob=path2)
-        await blobclient2.start_copy_from_url(blobclient1.url)
+        # TODO: this could cause a deadlock. Can we protect the user?
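+        # (With max_concurrency=1 the nested acquisition below deadlocks
+        # immediately: the inner _get_blob_client waits for the permit this
+        # same coroutine already holds. Larger limits can still deadlock when
+        # several concurrent _cp_file calls each hold one permit.)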
+        async with self._get_blob_client(container1, path1) as blobclient1:
+            async with self._get_blob_client(container2, path2) as blobclient2:
+                await blobclient2.start_copy_from_url(blobclient1.url)
         self.invalidate_cache(container1)
         self.invalidate_cache(container2)
 
 
@@ -1623,7 +1644,7 @@ async def _get_file(
         """ Copy single file remote to local """
         container_name, path = self.split_path(rpath, delimiter=delimiter)
         try:
-            async with self.service_client.get_blob_client(
+            async with self._get_blob_client(
                 container_name, path.rstrip(delimiter)
             ) as bc:
                 with open(lpath, "wb") as my_blob:
@@ -1645,7 +1666,7 @@ def getxattr(self, path, attr):
     async def _setxattrs(self, rpath, **kwargs):
         container_name, path = self.split_path(rpath)
         try:
-            async with self.service_client.get_blob_client(container_name, path) as bc:
+            async with self._get_blob_client(container_name, path) as bc:
                 await bc.set_blob_metadata(metadata=kwargs)
                 self.invalidate_cache(self._parent(rpath))
         except Exception as e:
diff --git a/adlfs/tests/test_spec.py b/adlfs/tests/test_spec.py
index 99bf8760..c605afa9 100644
--- a/adlfs/tests/test_spec.py
+++ b/adlfs/tests/test_spec.py
@@ -1,5 +1,7 @@
+import asyncio
 import os
 import tempfile
+from unittest import mock
 import datetime
 import dask.dataframe as dd
 from fsspec.implementations.local import LocalFileSystem
@@ -1348,3 +1350,17 @@ def test_find_with_prefix(storage):
     assert test_1s == [test_bucket_name + "/prefixes/test_1"] + [
         test_bucket_name + f"/prefixes/test_{cursor}" for cursor in range(10, 20)
     ]
+
+
+def test_max_concurrency(storage):
+    fs = AzureBlobFileSystem(
+        account_name=storage.account_name, connection_string=CONN_STR, max_concurrency=2
+    )
+
+    assert isinstance(fs._blob_client_semaphore, asyncio.Semaphore)
+
+    fs._blob_client_semaphore = mock.MagicMock(fs._blob_client_semaphore)
+    paths = {f"/data/{i}": b"value" for i in range(10)}
+    fs.pipe(paths)
+
+    assert fs._blob_client_semaphore.__aenter__.call_count == 10
diff --git a/adlfs/utils.py b/adlfs/utils.py
index ae42eba0..81b7c3dc 100644
--- a/adlfs/utils.py
+++ b/adlfs/utils.py
@@ -1,3 +1,7 @@
+import contextlib
+import sys
+
+
 async def filter_blobs(blobs, target_path, delimiter="/"):
     """
     Filters out blobs that do not come from target_path
@@ -43,3 +47,14 @@ async def close_container_client(file_obj):
     AzureBlobFile objects
     """
     await file_obj.container_client.close()
+
+
+if sys.version_info < (3, 10):
+    # Python 3.10 added support for using contextlib.nullcontext as an
+    # async context manager.
+    @contextlib.asynccontextmanager
+    async def _nullcontext(*args):
+        yield
+
+else:
+    _nullcontext = contextlib.nullcontext
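
A minimal usage sketch of the new parameter (the account name and connection
string below are placeholders, not part of the patch):

    from adlfs import AzureBlobFileSystem

    fs = AzureBlobFileSystem(
        account_name="myaccount",   # placeholder
        connection_string="...",    # placeholder credentials
        max_concurrency=8,          # at most 8 live BlobClients at a time
    )
    # Every operation routed through _get_blob_client (pipe, cat, put, url,
    # setxattrs, ...) acquires the semaphore first, so these 1000 small
    # writes never hold more than 8 BlobClient instances open at once.
    fs.pipe({f"container/blob-{i}": b"value" for i in range(1000)})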