Skip to content
85 changes: 35 additions & 50 deletions s3_storage_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,10 @@
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool

from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.module_api import run_in_background
from synapse.rest.media.v1._base import Responder
from synapse.rest.media.v1.storage_provider import StorageProvider

# Synapse 1.13.0 moved current_context to a module-level function.
try:
from synapse.logging.context import current_context
except ImportError:
current_context = LoggingContext.current_context

logger = logging.getLogger("synapse.s3")


Expand Down Expand Up @@ -121,34 +115,29 @@ def _get_s3_client(self):
def store_file(self, path, file_info):
"""See StorageProvider.store_file"""

parent_logcontext = current_context()

def _store_file():
with LoggingContext(parent_context=parent_logcontext):
self._get_s3_client().upload_file(
Filename=os.path.join(self.cache_directory, path),
Bucket=self.bucket,
Key=self.prefix + path,
ExtraArgs=self.extra_args,
)

return make_deferred_yieldable(
threads.deferToThreadPool(reactor, self._s3_pool, _store_file)
self._get_s3_client().upload_file(
Filename=os.path.join(self.cache_directory, path),
Bucket=self.bucket,
Key=self.prefix + path,
ExtraArgs=self.extra_args,
)

return run_in_background(
threads.deferToThreadPool, reactor, self._s3_pool, _store_file
)

def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
logcontext = current_context()

d = defer.Deferred()

def _get_file():
s3_download_task(
self._get_s3_client(), self.bucket, self.prefix + path, self.extra_args, d, logcontext
self._get_s3_client(), self.bucket, self.prefix + path, self.extra_args, d
)

self._s3_pool.callInThread(_get_file)
return make_deferred_yieldable(d)
run_in_background(self._s3_pool.callInThread, _get_file)
return d

@staticmethod
def parse_config(config):
Expand Down Expand Up @@ -196,7 +185,7 @@ def parse_config(config):
return result


def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcontext):
def s3_download_task(s3_client, bucket, key, extra_args, deferred):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only real changes here are:

  1. Removing the parent_logcontext parameter.
  2. Removing with LoggingContext ... and de-indenting all of the code underneath it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these changes make sense. s3_download_task will use whatever the caller's logcontext is.

And I think we maintain the logcontext when calling s3_download_task ✅ (at least with the suggested patterns).

"""Attempts to download a file from S3.

Args:
Expand All @@ -206,35 +195,31 @@ def s3_download_task(s3_client, bucket, key, extra_args, deferred, parent_logcon
deferred (Deferred[_S3Responder|None]): If file exists
resolved with an _S3Responder instance, if it doesn't
exist then resolves with None.
parent_logcontext (LoggingContext): the logcontext to report logs and metrics
against.
"""
with LoggingContext(parent_context=parent_logcontext):
logger.info("Fetching %s from S3", key)

try:
if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
resp = s3_client.get_object(
Bucket=bucket,
Key=key,
SSECustomerKey=extra_args["SSECustomerKey"],
SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
)
else:
resp = s3_client.get_object(Bucket=bucket, Key=key)

except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
logger.info("Media %s not found in S3", key)
reactor.callFromThread(deferred.callback, None)
return
logger.info("Fetching %s from S3", key)

reactor.callFromThread(deferred.errback, Failure())
try:
if "SSECustomerKey" in extra_args and "SSECustomerAlgorithm" in extra_args:
resp = s3_client.get_object(
Bucket=bucket,
Key=key,
SSECustomerKey=extra_args["SSECustomerKey"],
SSECustomerAlgorithm=extra_args["SSECustomerAlgorithm"],
)
else:
resp = s3_client.get_object(Bucket=bucket, Key=key)

except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] in ("404", "NoSuchKey",):
logger.info("Media %s not found in S3", key)
return

producer = _S3Responder()
reactor.callFromThread(deferred.callback, producer)
_stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)
reactor.callFromThread(deferred.errback, Failure())
return

producer = _S3Responder()
reactor.callFromThread(deferred.callback, producer)
_stream_to_producer(reactor, producer, resp["Body"], timeout=90.0)


def _stream_to_producer(reactor, producer, body, status=None, timeout=None):
Expand Down
Loading