diff --git a/s3_storage_provider.py b/s3_storage_provider.py
index 2478212..128b810 100644
--- a/s3_storage_provider.py
+++ b/s3_storage_provider.py
@@ -24,20 +24,15 @@
 import botocore
 from botocore.config import Config
 
-from twisted.internet import defer, reactor, threads
+from twisted.internet import defer, reactor
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
 
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable
+from synapse.module_api import ModuleApi
 from synapse.rest.media.v1._base import Responder
 from synapse.rest.media.v1.storage_provider import StorageProvider
 
-# Synapse 1.13.0 moved current_context to a module-level function.
-try:
-    from synapse.logging.context import current_context
-except ImportError:
-    current_context = LoggingContext.current_context
-
 logger = logging.getLogger("synapse.s3")
 
 
@@ -61,6 +56,7 @@ class S3StorageProviderBackend(StorageProvider):
     """
 
     def __init__(self, hs, config):
+        self._module_api: ModuleApi = hs.get_module_api()
         self.cache_directory = hs.config.media.media_store_path
         self.bucket = config["bucket"]
         self.prefix = config["prefix"]
@@ -124,37 +120,45 @@ def _get_s3_client(self):
         self._s3_client = s3 = b3_session.client("s3", **self.api_kwargs)
         return s3
 
-    def store_file(self, path, file_info):
+    async def store_file(self, path, file_info):
         """See StorageProvider.store_file"""
 
-        parent_logcontext = current_context()
-
-        def _store_file():
-            with LoggingContext(parent_context=parent_logcontext):
-                self._get_s3_client().upload_file(
-                    Filename=os.path.join(self.cache_directory, path),
-                    Bucket=self.bucket,
-                    Key=self.prefix + path,
-                    ExtraArgs=self.extra_args,
-                )
-
-        return make_deferred_yieldable(
-            threads.deferToThreadPool(reactor, self._s3_pool, _store_file)
+        return await self._module_api.defer_to_threadpool(
+            self._s3_pool,
+            self._get_s3_client().upload_file,
+            Filename=os.path.join(self.cache_directory, path),
+            Bucket=self.bucket,
+            Key=self.prefix + path,
+            ExtraArgs=self.extra_args,
         )
 
-    def fetch(self, path, file_info):
+    async def fetch(self, path, file_info):
         """See StorageProvider.fetch"""
-        logcontext = current_context()
-
         d = defer.Deferred()
 
-        def _get_file():
-            s3_download_task(
-                self._get_s3_client(), self.bucket, self.prefix + path, self.extra_args, d, logcontext
+        # Don't await this directly, as it will resolve only once the streaming
+        # download from S3 is concluded. Before that happens, we want to pass
+        # execution back to Synapse to stream the file's chunks.
+        #
+        # We do, however, need to wrap in `run_in_background` to ensure that the
+        # coroutine returned by `defer_to_threadpool` is used, and therefore
+        # actually run.
+        self._module_api.run_in_background(
+            self._module_api.defer_to_threadpool(
+                self._s3_pool,
+                s3_download_task,
+                self._get_s3_client(),
+                self.bucket,
+                self.prefix + path,
+                self.extra_args,
+                d,
             )
+        )
 
-        self._s3_pool.callInThread(_get_file)
-        return make_deferred_yieldable(d)
+        # DO await on `d`, as it will resolve once a connection to S3 has been
+        # opened. We only want to return to Synapse once we can start streaming
+        # chunks.
+        return await make_deferred_yieldable(d)
 
     @staticmethod
     def parse_config(config):
@@ -202,7 +206,7 @@ def parse_config(config):
     return result
 
 
-def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcontext):
+def s3_download_task(s3_client, bucket, key, extra_args, deferred):
     """Attempts to download a file from S3.
 
     Args:
@@ -212,35 +216,36 @@ def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcon
         deferred (Deferred[_S3Responder|None]): If file exists
             resolved with an _S3Responder instance, if it doesn't
             exist then resolves with None.
-        parent_logcontext (LoggingContext): the logcontext to report logs and metrics
-            against.
+
+    Returns:
+        A deferred which resolves to an _S3Responder if the file exists,
+        to None if it does not, and fails on any other error.
     """
-    with LoggingContext(parent_context=parent_logcontext):
-        logger.info("Fetching %s from S3", key)
-
-        try:
-            if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
-                resp = s3_client.get_object(
-                    Bucket=bucket,
-                    Key=key,
-                    SSECustomerKey=extra_args["SSECustomerKey"],
-                    SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
-                )
-            else:
-                resp = s3_client.get_object(Bucket=bucket, Key=key)
-
-        except botocore.exceptions.ClientError as e:
-            if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
-                logger.info("Media %s not found in S3", key)
-                reactor.callFromThread(deferred.callback, None)
-                return
+    logger.info("Fetching %s from S3", key)
+
+    try:
+        if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
+            resp = s3_client.get_object(
+                Bucket=bucket,
+                Key=key,
+                SSECustomerKey=extra_args["SSECustomerKey"],
+                SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
+            )
+        else:
+            resp = s3_client.get_object(Bucket=bucket, Key=key)
 
-            reactor.callFromThread(deferred.errback, Failure())
+    except botocore.exceptions.ClientError as e:
+        if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
+            logger.info("Media %s not found in S3", key)
+            reactor.callFromThread(deferred.callback, None)
             return
 
-        producer = _S3Responder()
-        reactor.callFromThread(deferred.callback, producer)
-        _stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)
+        reactor.callFromThread(deferred.errback, Failure())
+        return
+
+    producer = _S3Responder()
+    reactor.callFromThread(deferred.callback, producer)
+    _stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)
 
 
 def _stream_to_producer(reactor, producer, body, status=None, timeout=None):
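
Reviewer note (not part of the patch): the hand-off pattern the new `fetch` relies on — start a thread, then await a Deferred that the thread resolves via `reactor.callFromThread` as soon as the S3 connection is open, while the thread carries on streaming — is subtle. Below is a minimal, self-contained Twisted sketch of the same pattern; `worker` and `main` are invented names, and plain `threads.deferToThread` stands in for the module API's threadpool call.

    from twisted.internet import defer, threads
    from twisted.internet.task import react


    def worker(reactor, d):
        # Runs in a threadpool thread. Deferreds are not thread-safe, so the
        # result is handed back to the reactor thread.
        reactor.callFromThread(d.callback, "connection open, ready to stream")
        # ... the thread keeps working (streaming chunks) after the caller
        # has already resumed ...


    async def main(reactor):
        d = defer.Deferred()
        # Start the worker without awaiting its completion; in the patch this
        # is run_in_background(defer_to_threadpool(...)).
        threads.deferToThread(worker, reactor, d)
        # Resume as soon as the worker signals readiness, not when it
        # finishes; in the patch: `return await make_deferred_yieldable(d)`.
        print(await d)


    react(lambda reactor: defer.ensureDeferred(main(reactor)))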