Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,20 @@ def __init__(self, hs: "HomeServer"):
)
storage_providers.append(provider)

# If we have a local media directory, add it as a storage provider
if self.primary_base_path:
from synapse.media.storage_provider import FileStorageProviderBackend, StorageProviderWrapper
backend = FileStorageProviderBackend(hs, self.primary_base_path)
local_wrapper = StorageProviderWrapper(
backend,
store_local=True,
store_remote=False,
store_synchronous=True,
)
storage_providers.insert(0, local_wrapper)

self.media_storage: MediaStorage = MediaStorage(
self.hs, self.primary_base_path, self.filepaths, storage_providers
self.hs, self.filepaths, storage_providers
)

self.clock.looping_call(
Expand Down
196 changes: 121 additions & 75 deletions synapse/media/media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import logging
import os
import shutil
import tempfile
from contextlib import closing
from io import BytesIO
from types import TracebackType
Expand Down Expand Up @@ -149,27 +150,31 @@ def __getattr__(self, attr_name: str) -> Any:


class MediaStorage:
"""Responsible for storing/fetching files from local sources.
"""Responsible for storing/fetching files from storage providers.

Args:
hs
local_media_directory: Base path where we store media on disk
filepaths
storage_providers: List of StorageProvider that are used to fetch and store files.
"""

def __init__(
self,
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
storage_providers: Sequence["StorageProvider"],
):
self.hs = hs
self.reactor = hs.get_reactor()
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
self.storage_providers = list(storage_providers)
self.local_provider = None
self.local_media_directory = None
for provider in self.storage_providers:
if isinstance(provider.backend, FileStorageProviderBackend):
self.local_provider = provider
self.local_media_directory = provider.backend.base_directory
break
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self.clock = hs.get_clock()

Expand Down Expand Up @@ -221,46 +226,75 @@ async def store_into_file(
"""

path = self._file_info_to_path(file_info)
fname = os.path.join(self.local_media_directory, path)

dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)

try:
with start_active_span("writing to main media repo"):
with open(fname, "wb") as f:
yield f, fname
if self.local_provider:
fname = os.path.join(self.local_media_directory, path)
dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)

with start_active_span("writing to other storage providers"):
spam_check = (
await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
try:
with start_active_span("writing to main media repo"):
with open(fname, "wb") as f:
yield f, fname

with start_active_span("spam checking and writing to other storage providers"):
spam_check = (
await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
)
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
# We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])

for provider in self.storage_providers:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)

except Exception as e:
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
# We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])

for provider in self.storage_providers:
if provider is not self.local_provider:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)

except Exception as e:
try:
os.remove(fname)
except Exception:
pass

raise e from None
else:
# No local provider, write to temp file
with tempfile.NamedTemporaryFile(delete=False) as f:
fname = f.name
yield f, fname

try:
os.remove(fname)
except Exception:
pass
with start_active_span("spam checking and writing to storage providers"):
spam_check = (
await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker")
raise SpamMediaException(errcode=spam_check[0])

for provider in self.storage_providers:
with start_active_span(str(provider)):
await provider.store_file(path, file_info)

except Exception as e:
try:
os.remove(fname)
except Exception:
pass

raise e from None
raise e from None

async def fetch_media(self, file_info: FileInfo) -> Responder | None:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
"""Attempts to fetch media described by file_info from the configured storage providers.

Args:
file_info: Metadata about the media file
Expand All @@ -282,13 +316,6 @@ async def fetch_media(self, file_info: FileInfo) -> Responder | None:
)
)

for path in paths:
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
logger.debug("responding with local file %s", local_path)
return FileResponder(self.hs, open(local_path, "rb"))
logger.debug("local file %s did not exist", local_path)

for provider in self.storage_providers:
for path in paths:
res: Any = await provider.fetch(path, file_info)
Expand All @@ -311,39 +338,58 @@ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
Full path to local file
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
return local_path

# Fallback for paths without method names
# Should be removed in the future
if file_info.thumbnail and file_info.server_name:
legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
width=file_info.thumbnail.width,
height=file_info.thumbnail.height,
content_type=file_info.thumbnail.type,
)
legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
if os.path.exists(legacy_local_path):
return legacy_local_path
if self.local_provider:
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
return local_path

dirname = os.path.dirname(local_path)
os.makedirs(dirname, exist_ok=True)
# Fallback for paths without method names
# Should be removed in the future
if file_info.thumbnail and file_info.server_name:
legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
width=file_info.thumbnail.width,
height=file_info.thumbnail.height,
content_type=file_info.thumbnail.type,
)
legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
if os.path.exists(legacy_local_path):
return legacy_local_path

for provider in self.storage_providers:
res: Any = await provider.fetch(path, file_info)
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
return local_path
dirname = os.path.dirname(local_path)
os.makedirs(dirname, exist_ok=True)

raise NotFoundError()
for provider in self.storage_providers:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could make use of the fetch_media method here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not really. They have a different purpose. The ensure_media_is_in_local_cache is used to ensure a local copy so a thumbnail can be generated. While the fetch_media returns a Responder (to stream the file).

if provider is self.local_provider:
continue
res: Any = await provider.fetch(path, file_info)
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
return local_path

raise NotFoundError()
else:
# No local provider, download to temp
for provider in self.storage_providers:
res: Any = await provider.fetch(path, file_info)
if res:
temp_dir = tempfile.gettempdir()
temp_path = os.path.join(temp_dir, os.path.basename(path))
with res:
consumer = BackgroundFileConsumer(
open(temp_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
return temp_path

raise NotFoundError()

@trace
def _file_info_to_path(self, file_info: FileInfo) -> str:
Expand Down
7 changes: 5 additions & 2 deletions tests/media/test_media_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,14 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:

hs.config.media.media_store_path = self.primary_base_path

storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
storage_providers = [
FileStorageProviderBackend(hs, self.primary_base_path),
FileStorageProviderBackend(hs, self.secondary_base_path),
]

self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage(
hs, self.primary_base_path, self.filepaths, storage_providers
hs, self.filepaths, storage_providers
)

def test_ensure_media_is_in_local_cache(self) -> None:
Expand Down