116 changes: 60 additions & 56 deletions s3_storage_provider.py
@@ -24,20 +24,15 @@
import botocore
from botocore.config import Config

from twisted.internet import defer, reactor, threads
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool

from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.rest.media.v1._base import Responder
from synapse.rest.media.v1.storage_provider import StorageProvider

# Synapse 1.13.0 moved current_context to a module-level function.
try:
from synapse.logging.context import current_context
except ImportError:
current_context = LoggingContext.current_context

logger = logging.getLogger("synapse.s3")


@@ -61,6 +56,7 @@ class S3StorageProviderBackend(StorageProvider):
"""

def __init__(self, hs, config):
self._module_api: ModuleApi = hs.get_module_api()
self.cache_directory = hs.config.media.media_store_path
self.bucket = config["bucket"]
self.prefix = config["prefix"]
@@ -124,37 +120,45 @@ def _get_s3_client(self):
self._s3_client = s3 = b3_session.client("s3", **self.api_kwargs)
return s3

def store_file(self, path, file_info):
async def store_file(self, path, file_info):
"""See StorageProvider.store_file"""

parent_logcontext = current_context()

def _store_file():
with LoggingContext(parent_context=parent_logcontext):
self._get_s3_client().upload_file(
Filename=os.path.join(self.cache_directory, path),
Bucket=self.bucket,
Key=self.prefix + path,
ExtraArgs=self.extra_args,
)

return make_deferred_yieldable(
threads.deferToThreadPool(reactor, self._s3_pool, _store_file)
return await self._module_api.defer_to_threadpool(
self._s3_pool,
self._get_s3_client().upload_file,
Filename=os.path.join(self.cache_directory, path),
Bucket=self.bucket,
Key=self.prefix + path,
ExtraArgs=self.extra_args,
)

def fetch(self, path, file_info):
async def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
logcontext = current_context()

d = defer.Deferred()

def _get_file():
s3_download_task(
self._get_s3_client(), self.bucket, self.prefix + path, self.extra_args, d, logcontext
# Don't await this directly, as it will only resolve once the streaming
# download from S3 has concluded. Before that happens, we want to pass
# execution back to Synapse so it can stream the file's chunks.
#
# We do, however, need to wrap it in `run_in_background` to ensure that the
# coroutine returned by `defer_to_threadpool` is used, and therefore
# actually run.
self._module_api.run_in_background(
self._module_api.defer_to_threadpool(
self._s3_pool,
s3_download_task,
self._get_s3_client(),
self.bucket,
self.prefix + path,
self.extra_args,
d,
)
)

self._s3_pool.callInThread(_get_file)
return make_deferred_yieldable(d)
# DO await `d`, as it will resolve once a connection to S3 has been
# opened. We only want to return to Synapse once we can start streaming
# chunks.
Comment on lines +158 to +160


Cross-linking the internal discussion where the cause of the `Exception: Timed out waiting to resume` error was figured out.
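For context: that exception comes from the pause/resume wait in _stream_to_producer further down this file. A hedged sketch of the shape of that wait (an assumption about the existing helper, not shown in this diff):

    # If the consumer never resumes the paused producer, the blocking wait on
    # the wakeup event times out and the streaming thread gives up:
    if not producer.wakeup_event.wait(timeout):  # timeout=90.0 below
        raise Exception("Timed out waiting to resume")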

return await make_deferred_yieldable(d)
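The two comments above describe a handshake: the worker thread resolves d (via reactor.callFromThread) as soon as the S3 connection is open and then keeps streaming, while fetch returns to Synapse at exactly that point. A standalone, runnable sketch of the pattern in plain Twisted (no Synapse helpers; all names here are illustrative, not from this module):

    import time

    from twisted.internet import defer, reactor, threads

    def download_task(d):
        time.sleep(0.1)  # pretend to open a connection to S3 (blocking)
        # Handshake: resolve `d` from the worker thread; the awaiting
        # coroutine resumes on the reactor thread while this thread carries on.
        reactor.callFromThread(d.callback, "responder")
        time.sleep(0.5)  # pretend to stream the body (blocking)

    async def fetch_like():
        d = defer.Deferred()
        # Fire-and-forget: don't await the worker's own Deferred, only the
        # handshake Deferred `d`.
        threads.deferToThread(download_task, d)
        responder = await d
        print("got", responder, "- streaming can start now")

    if __name__ == "__main__":
        defer.ensureDeferred(fetch_like()).addBoth(lambda _: reactor.stop())
        reactor.run()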

@staticmethod
def parse_config(config):
@@ -202,7 +206,7 @@ def parse_config(config):
return result


def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcontext):
def s3_download_task(s3_client, bucket, key, extra_args, deferred):
Member Author
The only real changes here are:

  1. Removing parent_logcontext.
  2. Removing the with LoggingContext ... block and de-indenting all of the code underneath it.


I think these changes make sense. s3_download_task will use whatever the caller's logcontext is.

And I think we maintain the logcontext when calling s3_download_task ✅ (at least with the suggested patterns).
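A hedged sketch of why no explicit plumbing is needed, assuming defer_to_threadpool behaves like Synapse's synapse.logging.context.defer_to_threadpool (the function below is illustrative, not part of this diff):

    from synapse.logging.context import defer_to_threadpool, make_deferred_yieldable

    async def fetch_in_caller_context(reactor, threadpool, blocking_fetch):
        # defer_to_threadpool runs `blocking_fetch` on the threadpool inside
        # the *caller's* logcontext, so any logger.info(...) it makes is
        # attributed to the request already in scope -- no parent_logcontext
        # argument, and no `with LoggingContext(...)` re-entry in the worker.
        return await make_deferred_yieldable(
            defer_to_threadpool(reactor, threadpool, blocking_fetch)
        )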

"""Attempts to download a file from S3.

Args:
@@ -212,35 +216,35 @@ def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcontext):
deferred (Deferred[_S3Responder|None]): If the file exists,
this is resolved with an _S3Responder instance; if it doesn't
exist, it resolves with None.
parent_logcontext (LoggingContext): the logcontext to report logs and metrics
against.

Returns:
A deferred which resolves to an _S3Responder if the file exists.
Otherwise the deferred fails.
"""
with LoggingContext(parent_context=parent_logcontext):
logger.info("Fetching %s from S3", key)

try:
if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
resp = s3_client.get_object(
Bucket=bucket,
Key=key,
SSECustomerKey=extra_args["SSECustomerKey"],
SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
)
else:
resp = s3_client.get_object(Bucket=bucket, Key=key)

except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
logger.info("Media %s not found in S3", key)
reactor.callFromThread(deferred.callback, None)
return
logger.info("Fetching %s from S3", key)

try:
if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
resp = s3_client.get_object(
Bucket=bucket,
Key=key,
SSECustomerKey=extra_args["SSECustomerKey"],
SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
)
else:
resp = s3_client.get_object(Bucket=bucket, Key=key)

reactor.callFromThread(deferred.errback, Failure())
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
logger.info("Media %s not found in S3", key)
reactor.callFromThread(deferred.callback, None)
return

producer = _S3Responder()
reactor.callFromThread(deferred.callback, producer)
_stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)
reactor.callFromThread(deferred.errback, Failure())
return

producer = _S3Responder()
reactor.callFromThread(deferred.callback, producer)
_stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)


def _stream_to_producer(reactor, producer, body, status=None, timeout=None):