diff --git a/changelog.d/19204.feature b/changelog.d/19204.feature new file mode 100644 index 00000000000..e768b7ebb43 --- /dev/null +++ b/changelog.d/19204.feature @@ -0,0 +1 @@ +Made the local media directory optional by treating it as a storage provider. This allows off-site media storage without local cache, where media is stored directly to remote providers only, with temporary files used for thumbnail generation when needed. Contributed by Patrice Brend'amour @dr.allgood. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 7509e4d715e..7ecacd71dbe 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2092,6 +2092,15 @@ Example configuration: enable_media_repo: false ``` --- +### `enable_local_media_storage` + +*(boolean)* Enable the local on-disk media storage provider. When disabled, media is stored only in configured media_storage_providers and temporary files are used for processing. Defaults to `true`. + +Example configuration: +```yaml +enable_local_media_storage: false +``` +--- ### `media_store_path` *(string)* Directory where uploaded images and attachments are stored. Defaults to `"media_store"`. diff --git a/schema/synapse-config.schema.yaml b/schema/synapse-config.schema.yaml index bf9346995da..e6d86e2c3a9 100644 --- a/schema/synapse-config.schema.yaml +++ b/schema/synapse-config.schema.yaml @@ -2338,6 +2338,15 @@ properties: default: true examples: - false + enable_local_media_storage: + type: boolean + description: >- + Enable the local on-disk media storage provider. When disabled, media is + stored only in configured media_storage_providers and temporary files are + used for processing. + default: true + examples: + - false media_store_path: type: string description: Directory where uploaded images and attachments are stored. diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 221130b0cd1..c87442aace0 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -174,6 +174,11 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: config.get("media_store_path", "media_store") ) + # Whether to enable the local media storage provider. When disabled, + # media will only be stored in configured storage providers and temp + # files will be used for processing. + self.enable_local_media_storage = config.get("enable_local_media_storage", True) + backup_media_store_path = config.get("backup_media_store_path") synchronous_backup_media_store = config.get( diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 29c5e66ec49..f588b52f647 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -64,7 +64,10 @@ SHA256TransparentIOReader, SHA256TransparentIOWriter, ) -from synapse.media.storage_provider import StorageProviderWrapper +from synapse.media.storage_provider import ( + FileStorageProviderBackend, + StorageProviderWrapper, +) from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia @@ -142,8 +145,13 @@ def __init__(self, hs: "HomeServer"): ) storage_providers.append(provider) + # If local media storage is enabled, create the local provider + local_provider: FileStorageProviderBackend | None = None + if hs.config.media.enable_local_media_storage and self.primary_base_path: + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) + self.media_storage: MediaStorage = MediaStorage( - self.hs, self.primary_base_path, self.filepaths, storage_providers + self.hs, self.filepaths, storage_providers, local_provider ) self.clock.looping_call( @@ -1101,32 +1109,31 @@ async def generate_local_exact_thumbnail( t_type: str, url_cache: bool, ) -> tuple[str, FileInfo] | None: - input_path = await self.media_storage.ensure_media_is_in_local_cache( + async with self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", - media_id, - t_method, - t_type, - e, - ) - return None + ) as input_path: + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", + media_id, + t_method, + t_type, + e, + ) + return None - with thumbnailer: - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ -1177,33 +1184,32 @@ async def generate_remote_exact_thumbnail( t_method: str, t_type: str, ) -> str | None: - input_path = await self.media_storage.ensure_media_is_in_local_cache( + async with self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", - media_id, - server_name, - t_method, - t_type, - e, - ) - return None + ) as input_path: + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", + media_id, + server_name, + t_method, + t_type, + e, + ) + return None - with thumbnailer: - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ -1273,151 +1279,157 @@ async def _generate_thumbnails( if not requirements: return None - input_path = await self.media_storage.ensure_media_is_in_local_cache( + async with self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) - ) - - try: - thumbnailer = Thumbnailer(input_path) - except ThumbnailError as e: - logger.warning( - "Unable to generate thumbnails for remote media %s from %s of type %s: %s", - media_id, - server_name, - media_type, - e, - ) - return None - - with thumbnailer: - m_width = thumbnailer.width - m_height = thumbnailer.height - - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, - m_height, - self.max_image_pixels, + ) as input_path: + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate thumbnails for remote media %s from %s of type %s: %s", + media_id, + server_name, + media_type, + e, ) return None - if thumbnailer.transpose_method is not None: - m_width, m_height = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.transpose - ) + with thumbnailer: + m_width = thumbnailer.width + m_height = thumbnailer.height - # We deduplicate the thumbnail sizes by ignoring the cropped versions if - # they have the same dimensions of a scaled one. - thumbnails: dict[tuple[int, int, str], str] = {} - for requirement in requirements: - if requirement.method == "crop": - thumbnails.setdefault( - (requirement.width, requirement.height, requirement.media_type), - requirement.method, - ) - elif requirement.method == "scale": - t_width, t_height = thumbnailer.aspect( - requirement.width, requirement.height - ) - t_width = min(m_width, t_width) - t_height = min(m_height, t_height) - thumbnails[(t_width, t_height, requirement.media_type)] = ( - requirement.method + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, ) + return None - # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in thumbnails.items(): - # Generate the thumbnail - if t_method == "crop": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - thumbnailer.crop, - t_width, - t_height, - t_type, - ) - elif t_method == "scale": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - thumbnailer.scale, - t_width, - t_height, - t_type, + if thumbnailer.transpose_method is not None: + m_width, m_height = await defer_to_thread( + self.hs.get_reactor(), thumbnailer.transpose ) - else: - logger.error("Unrecognized method: %r", t_method) - continue - if not t_byte_source: - continue + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. + thumbnails: dict[tuple[int, int, str], str] = {} + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + ( + requirement.width, + requirement.height, + requirement.media_type, + ), + requirement.method, + ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[(t_width, t_height, requirement.media_type)] = ( + requirement.method + ) - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - url_cache=url_cache, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - length=t_byte_source.tell(), - ), - ) + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.items(): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.crop, + t_width, + t_height, + t_type, + ) + elif t_method == "scale": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.scale, + t_width, + t_height, + t_type, + ) + else: + logger.error("Unrecognized method: %r", t_method) + continue - async with self.media_storage.store_into_file(file_info) as (f, fname): - try: - await self.media_storage.write_to_file(t_byte_source, f) - finally: - t_byte_source.close() - - # We flush and close the file to ensure that the bytes have - # been written before getting the size. - f.flush() - f.close() - - t_len = os.path.getsize(fname) - - # Write to database - if server_name: - # Multiple remote media download requests can race (when - # using multiple media repos), so this may throw a violation - # constraint exception. If it does we'll delete the newly - # generated thumbnail from disk (as we're in the ctx - # manager). - # - # However: we've already called `finish()` so we may have - # also written to the storage providers. This is preferable - # to the alternative where we call `finish()` *after* this, - # where we could end up having an entry in the DB but fail - # to write the files to the storage providers. + if not t_byte_source: + continue + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + length=t_byte_source.tell(), + ), + ) + + async with self.media_storage.store_into_file(file_info) as ( + f, + fname, + ): try: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - except Exception as e: - thumbnail_exists = ( - await self.store.get_remote_media_thumbnail( + await self.media_storage.write_to_file(t_byte_source, f) + finally: + t_byte_source.close() + + # We flush and close the file to ensure that the bytes have + # been written before getting the size. + f.flush() + f.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. + try: + await self.store.store_remote_media_thumbnail( server_name, media_id, + file_id, t_width, t_height, t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = ( + await self.store.get_remote_media_thumbnail( + server_name, + media_id, + t_width, + t_height, + t_type, + ) ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len ) - if not thumbnail_exists: - raise e - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len - ) return {"width": m_width, "height": m_height} diff --git a/synapse/media/media_storage.py b/synapse/media/media_storage.py index e83869bf4d6..ae651f0e087 100644 --- a/synapse/media/media_storage.py +++ b/synapse/media/media_storage.py @@ -24,6 +24,7 @@ import logging import os import shutil +import tempfile from contextlib import closing from io import BytesIO from types import TracebackType @@ -49,13 +50,13 @@ from synapse.api.errors import NotFoundError from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.opentracing import start_active_span, trace, trace_with_opname -from synapse.media._base import ThreadedFileSender +from synapse.media.storage_provider import FileStorageProviderBackend from synapse.util.clock import Clock from synapse.util.duration import Duration from synapse.util.file_consumer import BackgroundFileConsumer from ..types import JsonDict -from ._base import FileInfo, Responder +from ._base import FileInfo, Responder, ThreadedFileSender from .filepath import MediaFilePaths if TYPE_CHECKING: @@ -150,27 +151,30 @@ def __getattr__(self, attr_name: str) -> Any: class MediaStorage: - """Responsible for storing/fetching files from local sources. + """Responsible for storing/fetching files from storage providers. Args: hs - local_media_directory: Base path where we store media on disk filepaths storage_providers: List of StorageProvider that are used to fetch and store files. + local_provider: Optional local file storage provider for caching media on disk. """ def __init__( self, hs: "HomeServer", - local_media_directory: str, filepaths: MediaFilePaths, storage_providers: Sequence["StorageProvider"], + local_provider: "FileStorageProviderBackend | None" = None, ): self.hs = hs self.reactor = hs.get_reactor() - self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers + self.local_provider = local_provider + self.local_media_directory: str | None = None + if local_provider is not None: + self.local_media_directory = local_provider.base_directory self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker self.clock = hs.get_clock() @@ -205,11 +209,11 @@ async def store_into_file( """Async Context manager used to get a file like object to write into, as described by file_info. - Actually yields a 2-tuple (file, fname,), where file is a file - like object that can be written to and fname is the absolute path of file - on disk. + Actually yields a 2-tuple (file, media_filepath,), where file is a file + like object that can be written to and media_filepath is the absolute path + of file on disk. - fname can be used to read the contents from after upload, e.g. to + media_filepath can be used to read the contents from after upload, e.g. to generate thumbnails. Args: @@ -217,25 +221,33 @@ async def store_into_file( Example: - async with media_storage.store_into_file(info) as (f, fname,): + async with media_storage.store_into_file(info) as (f, media_filepath,): # .. write into f ... """ path = self._file_info_to_path(file_info) - fname = os.path.join(self.local_media_directory, path) + is_temp_file = False - dirname = os.path.dirname(fname) - os.makedirs(dirname, exist_ok=True) + if self.local_provider: + media_filepath = os.path.join(self.local_media_directory, path) # type: ignore[arg-type] + os.makedirs(os.path.dirname(media_filepath), exist_ok=True) - try: with start_active_span("writing to main media repo"): - with open(fname, "wb") as f: - yield f, fname + with open(media_filepath, "wb") as f: + yield f, media_filepath + else: + # No local provider, write to temp file + is_temp_file = True + with tempfile.NamedTemporaryFile(delete=False) as f: + media_filepath = f.name + yield cast(BinaryIO, f), media_filepath - with start_active_span("writing to other storage providers"): + # Spam check and store to other providers (runs for both local and temp file cases) + try: + with start_active_span("spam checking and writing to storage providers"): spam_check = ( await self._spam_checker_module_callbacks.check_media_file_for_spam( - ReadableFileWrapper(self.clock, fname), file_info + ReadableFileWrapper(self.clock, media_filepath), file_info ) ) if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: @@ -251,17 +263,23 @@ async def store_into_file( with start_active_span(str(provider)): await provider.store_file(path, file_info) + # If using a temp file, delete it after uploading to storage providers + if is_temp_file: + try: + os.remove(media_filepath) + except Exception: + pass + except Exception as e: try: - os.remove(fname) + os.remove(media_filepath) except Exception: pass raise e from None async def fetch_media(self, file_info: FileInfo) -> Responder | None: - """Attempts to fetch media described by file_info from the local cache - and configured storage providers. + """Attempts to fetch media described by file_info from the configured storage providers. Args: file_info: Metadata about the media file @@ -269,6 +287,18 @@ async def fetch_media(self, file_info: FileInfo) -> Responder | None: Returns: Returns a Responder if the file was found, otherwise None. """ + # URL cache files are stored locally and should not go through storage providers + if file_info.url_cache: + path = self._file_info_to_path(file_info) + if self.local_provider: + local_path = os.path.join(self.local_media_directory, path) # type: ignore[arg-type] + if os.path.isfile(local_path): + # Import here to avoid circular import + from .media_storage import FileResponder + + return FileResponder(self.hs, open(local_path, "rb")) + return None + paths = [self._file_info_to_path(file_info)] # fallback for remote thumbnails with no method in the filename @@ -283,16 +313,18 @@ async def fetch_media(self, file_info: FileInfo) -> Responder | None: ) ) - for path in paths: - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - logger.debug("responding with local file %s", local_path) - return FileResponder(self.hs, open(local_path, "rb")) - logger.debug("local file %s did not exist", local_path) + # Check local provider first, then other storage providers + if self.local_provider: + for path in paths: + res: Any = await self.local_provider.fetch(path, file_info) + if res: + logger.debug("Streaming %s from %s", path, self.local_provider) + return res + logger.debug("%s not found on %s", path, self.local_provider) for provider in self.storage_providers: for path in paths: - res: Any = await provider.fetch(path, file_info) + res = await provider.fetch(path, file_info) if res: logger.debug("Streaming %s from %s", path, provider) return res @@ -301,50 +333,93 @@ async def fetch_media(self, file_info: FileInfo) -> Responder | None: return None @trace - async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str: - """Ensures that the given file is in the local cache. Attempts to - download it from storage providers if it isn't. + @contextlib.asynccontextmanager + async def ensure_media_is_in_local_cache( + self, file_info: FileInfo + ) -> AsyncIterator[str]: + """Async context manager that ensures the given file is in the local cache. + Attempts to download it from storage providers if it isn't. + + When no local provider is configured, the file is downloaded to a temporary + location and automatically cleaned up when the context manager exits. Args: file_info - Returns: + Yields: Full path to local file + + Example: + async with media_storage.ensure_media_is_in_local_cache(file_info) as path: + # use path to read the file """ path = self._file_info_to_path(file_info) - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - return local_path - - # Fallback for paths without method names - # Should be removed in the future - if file_info.thumbnail and file_info.server_name: - legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( - server_name=file_info.server_name, - file_id=file_info.file_id, - width=file_info.thumbnail.width, - height=file_info.thumbnail.height, - content_type=file_info.thumbnail.type, - ) - legacy_local_path = os.path.join(self.local_media_directory, legacy_path) - if os.path.exists(legacy_local_path): - return legacy_local_path - - dirname = os.path.dirname(local_path) - os.makedirs(dirname, exist_ok=True) - - for provider in self.storage_providers: - res: Any = await provider.fetch(path, file_info) - if res: - with res: - consumer = BackgroundFileConsumer( - open(local_path, "wb"), self.reactor - ) - await res.write_to_consumer(consumer) - await consumer.wait() - return local_path + if self.local_provider: + local_path = os.path.join(self.local_media_directory, path) # type: ignore[arg-type] + if os.path.exists(local_path): + yield local_path + return - raise NotFoundError() + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail.width, + height=file_info.thumbnail.height, + content_type=file_info.thumbnail.type, + ) + legacy_local_path = os.path.join( + self.local_media_directory, # type: ignore[arg-type] + legacy_path, + ) + if os.path.exists(legacy_local_path): + yield legacy_local_path + return + + os.makedirs(os.path.dirname(local_path), exist_ok=True) + + for provider in self.storage_providers: + remote_res: Any = await provider.fetch(path, file_info) + if remote_res: + with remote_res: + consumer = BackgroundFileConsumer( + open(local_path, "wb"), self.reactor + ) + await remote_res.write_to_consumer(consumer) + await consumer.wait() + yield local_path + return + + raise NotFoundError() + else: + # No local provider, download to temp file and clean up after use + for provider in self.storage_providers: + res: Any = await provider.fetch(path, file_info) + if res: + temp_path = None + try: + with tempfile.NamedTemporaryFile( + delete=False, suffix=os.path.splitext(path)[1] + ) as tmp: + temp_path = tmp.name + with res: + consumer = BackgroundFileConsumer( + open(temp_path, "wb"), self.reactor + ) + await res.write_to_consumer(consumer) + await consumer.wait() + yield temp_path + finally: + if temp_path: + try: + os.remove(temp_path) + except Exception: + pass + return + + raise NotFoundError() @trace def _file_info_to_path(self, file_info: FileInfo) -> str: diff --git a/synapse/media/storage_provider.py b/synapse/media/storage_provider.py index a87ffa08926..c5faa25b964 100644 --- a/synapse/media/storage_provider.py +++ b/synapse/media/storage_provider.py @@ -31,7 +31,6 @@ from synapse.util.async_helpers import maybe_awaitable from ._base import FileInfo, Responder -from .media_storage import FileResponder logger = logging.getLogger(__name__) @@ -178,6 +177,9 @@ async def fetch(self, path: str, file_info: FileInfo) -> Responder | None: backup_fname = os.path.join(self.base_directory, path) if os.path.isfile(backup_fname): + # Import here to avoid circular import + from .media_storage import FileResponder + return FileResponder(self.hs, open(backup_fname, "rb")) return None diff --git a/synapse/media/thumbnailer.py b/synapse/media/thumbnailer.py index fd65131c63a..aae19a06013 100644 --- a/synapse/media/thumbnailer.py +++ b/synapse/media/thumbnailer.py @@ -636,9 +636,10 @@ async def _select_and_respond_with_thumbnail( # First let's check that we do actually have the original image # still. This will throw a 404 if we don't. # TODO: We should refetch the thumbnails for remote media. - await self.media_storage.ensure_media_is_in_local_cache( + async with self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) - ) + ): + pass if server_name: await self.media_repo.generate_remote_exact_thumbnail( diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index 1e849fa605c..9c4e813537e 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -49,6 +49,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: hs.config.media.media_store_path = self.primary_base_path + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) storage_providers = [ StorageProviderWrapper( FileStorageProviderBackend(hs, self.secondary_base_path), @@ -60,7 +61,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.filepaths = MediaFilePaths(self.primary_base_path) self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers + hs, self.filepaths, storage_providers, local_provider ) self.media_repo = hs.get_media_repository() @@ -187,6 +188,115 @@ def test_federation_etag(self) -> None: self.assertNotIn("body", channel.result) +class FederationMediaTest(unittest.FederatingHomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") + self.addCleanup(shutil.rmtree, self.test_dir) + self.primary_base_path = os.path.join(self.test_dir, "primary") + self.secondary_base_path = os.path.join(self.test_dir, "secondary") + + hs.config.media.media_store_path = self.primary_base_path + + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) + storage_providers = [ + StorageProviderWrapper( + FileStorageProviderBackend(hs, self.secondary_base_path), + store_local=True, + store_remote=False, + store_synchronous=True, + ) + ] + + self.filepaths = MediaFilePaths(self.primary_base_path) + self.media_storage = MediaStorage( + hs, self.filepaths, storage_providers, local_provider + ) + self.media_repo = hs.get_media_repository() + + def test_thumbnail_download_scaled(self) -> None: + content = io.BytesIO(small_png.data) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_thumbnail", + content, + 67, + UserID.from_string("@user_id:whatever.org"), + ) + ) + # test with an image file + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=scale", + ) + self.pump() + self.assertEqual(200, channel.code) + + content_type = channel.headers.getRawHeaders("content-type") + assert content_type is not None + assert "multipart/mixed" in content_type[0] + assert "boundary" in content_type[0] + + # extract boundary + boundary = content_type[0].split("boundary=")[1] + # split on boundary and check that json field and expected value exist + body = channel.result.get("body") + assert body is not None + stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) + found_json = any( + b"\r\nContent-Type: application/json\r\n\r\n{}" in field + for field in stripped_bytes + ) + self.assertTrue(found_json) + + # check that the png file exists and matches the expected scaled bytes + found_file = any(small_png.expected_scaled in field for field in stripped_bytes) + self.assertTrue(found_file) + + def test_thumbnail_download_cropped(self) -> None: + content = io.BytesIO(small_png.data) + content_uri = self.get_success( + self.media_repo.create_or_update_content( + "image/png", + "test_png_thumbnail", + content, + 67, + UserID.from_string("@user_id:whatever.org"), + ) + ) + # test with an image file + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=crop", + ) + self.pump() + self.assertEqual(200, channel.code) + + content_type = channel.headers.getRawHeaders("content-type") + assert content_type is not None + assert "multipart/mixed" in content_type[0] + assert "boundary" in content_type[0] + + # extract boundary + boundary = content_type[0].split("boundary=")[1] + # split on boundary and check that json field and expected value exist + body = channel.result.get("body") + assert body is not None + stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8")) + found_json = any( + b"\r\nContent-Type: application/json\r\n\r\n{}" in field + for field in stripped_bytes + ) + self.assertTrue(found_json) + + # check that the png file exists and matches the expected cropped bytes + found_file = any( + small_png.expected_cropped in field for field in stripped_bytes + ) + self.assertTrue(found_file) + + class FederationThumbnailTest(unittest.FederatingHomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) @@ -197,6 +307,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: hs.config.media.media_store_path = self.primary_base_path + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) storage_providers = [ StorageProviderWrapper( FileStorageProviderBackend(hs, self.secondary_base_path), @@ -208,7 +319,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.filepaths = MediaFilePaths(self.primary_base_path) self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers + hs, self.filepaths, storage_providers, local_provider ) self.media_repo = hs.get_media_repository() diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index e56354e0b3f..631718a3666 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -48,7 +48,10 @@ from synapse.media._base import FileInfo, ThumbnailInfo from synapse.media.filepath import MediaFilePaths from synapse.media.media_storage import MediaStorage, ReadableFileWrapper -from synapse.media.storage_provider import FileStorageProviderBackend +from synapse.media.storage_provider import ( + FileStorageProviderBackend, + StorageProviderWrapper, +) from synapse.media.thumbnailer import ThumbnailProvider from synapse.module_api import ModuleApi from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers @@ -77,11 +80,19 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: hs.config.media.media_store_path = self.primary_base_path - storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] + local_provider = FileStorageProviderBackend(hs, self.primary_base_path) + storage_providers = [ + StorageProviderWrapper( + FileStorageProviderBackend(hs, self.secondary_base_path), + store_local=True, + store_remote=False, + store_synchronous=True, + ), + ] self.filepaths = MediaFilePaths(self.primary_base_path) self.media_storage = MediaStorage( - hs, self.primary_base_path, self.filepaths, storage_providers + hs, self.filepaths, storage_providers, local_provider ) def test_ensure_media_is_in_local_cache(self) -> None: @@ -102,29 +113,31 @@ def test_ensure_media_is_in_local_cache(self) -> None: # to the local cache. file_info = FileInfo(None, media_id) - # This uses a real blocking threadpool so we have to wait for it to be - # actually done :/ - x = defer.ensureDeferred( - self.media_storage.ensure_media_is_in_local_cache(file_info) - ) + async def test_ensure_media() -> None: + async with self.media_storage.ensure_media_is_in_local_cache( + file_info + ) as local_path: + self.assertTrue(os.path.exists(local_path)) - # Hotloop until the threadpool does its job... - self.wait_on_thread(x) + # Asserts the file is under the expected local cache directory + self.assertEqual( + os.path.commonprefix([self.primary_base_path, local_path]), + self.primary_base_path, + ) - local_path = self.get_success(x) + with open(local_path) as f: + body = f.read() - self.assertTrue(os.path.exists(local_path)) + self.assertEqual(test_body, body) - # Asserts the file is under the expected local cache directory - self.assertEqual( - os.path.commonprefix([self.primary_base_path, local_path]), - self.primary_base_path, - ) + # This uses a real blocking threadpool so we have to wait for it to be + # actually done :/ + x = defer.ensureDeferred(test_ensure_media()) - with open(local_path) as f: - body = f.read() + # Hotloop until the threadpool does its job... + self.wait_on_thread(x) - self.assertEqual(test_body, body) + self.get_success(x) @attr.s(auto_attribs=True, slots=True, frozen=True)