@@ -177,6 +177,13 @@ def __init__(self, hs: "HomeServer"):
177177 else :
178178 self .url_previewer = None
179179
180+ # We get the media upload limits and sort them in descending order of
181+ # time period, so that we can apply some optimizations.
182+ self .media_upload_limits = hs .config .media .media_upload_limits
183+ self .media_upload_limits .sort (
184+ key = lambda limit : limit .time_period_ms , reverse = True
185+ )
186+
180187 def _start_update_recently_accessed (self ) -> Deferred :
181188 return run_as_background_process (
182189 "update_recently_accessed_media" , self ._update_recently_accessed
@@ -285,80 +292,37 @@ async def verify_can_upload(self, media_id: str, auth_user: UserID) -> None:
285292 raise NotFoundError ("Media ID has expired" )
286293
287294 @trace
288- async def update_content (
289- self ,
290- media_id : str ,
291- media_type : str ,
292- upload_name : Optional [str ],
293- content : IO ,
294- content_length : int ,
295- auth_user : UserID ,
296- ) -> None :
297- """Update the content of the given media ID.
298-
299- Args:
300- media_id: The media ID to replace.
301- media_type: The content type of the file.
302- upload_name: The name of the file, if provided.
303- content: A file like object that is the content to store
304- content_length: The length of the content
305- auth_user: The user_id of the uploader
306- """
307- file_info = FileInfo (server_name = None , file_id = media_id )
308- sha256reader = SHA256TransparentIOReader (content )
309- # This implements all of IO as it has a passthrough
310- fname = await self .media_storage .store_file (sha256reader .wrap (), file_info )
311- sha256 = sha256reader .hexdigest ()
312- should_quarantine = await self .store .get_is_hash_quarantined (sha256 )
313- logger .info ("Stored local media in file %r" , fname )
314-
315- if should_quarantine :
316- logger .warning (
317- "Media has been automatically quarantined as it matched existing quarantined media"
318- )
319-
320- await self .store .update_local_media (
321- media_id = media_id ,
322- media_type = media_type ,
323- upload_name = upload_name ,
324- media_length = content_length ,
325- user_id = auth_user ,
326- sha256 = sha256 ,
327- quarantined_by = "system" if should_quarantine else None ,
328- )
329-
330- try :
331- await self ._generate_thumbnails (None , media_id , media_id , media_type )
332- except Exception as e :
333- logger .info ("Failed to generate thumbnails: %s" , e )
334-
335- @trace
336- async def create_content (
295+ async def create_or_update_content (
337296 self ,
338297 media_type : str ,
339298 upload_name : Optional [str ],
340299 content : IO ,
341300 content_length : int ,
342301 auth_user : UserID ,
302+ media_id : Optional [str ] = None ,
343303 ) -> MXCUri :
344- """Store uploaded content for a local user and return the mxc URL
304+ """Create or update the content of the given media ID.
345305
346306 Args:
347307 media_type: The content type of the file.
348308 upload_name: The name of the file, if provided.
349309 content: A file like object that is the content to store
350310 content_length: The length of the content
351311 auth_user: The user_id of the uploader
312+ media_id: The media ID to update if provided, otherwise creates
313+ new media ID.
352314
353315 Returns:
354316 The mxc url of the stored content
355317 """
356318
357- media_id = random_string (24 )
319+ is_new_media = media_id is None
320+ if media_id is None :
321+ media_id = random_string (24 )
358322
359323 file_info = FileInfo (server_name = None , file_id = media_id )
360- # This implements all of IO as it has a passthrough
361324 sha256reader = SHA256TransparentIOReader (content )
325+ # This implements all of IO as it has a passthrough
362326 fname = await self .media_storage .store_file (sha256reader .wrap (), file_info )
363327 sha256 = sha256reader .hexdigest ()
364328 should_quarantine = await self .store .get_is_hash_quarantined (sha256 )
@@ -370,16 +334,56 @@ async def create_content(
370334 "Media has been automatically quarantined as it matched existing quarantined media"
371335 )
372336
373- await self .store .store_local_media (
374- media_id = media_id ,
375- media_type = media_type ,
376- time_now_ms = self .clock .time_msec (),
377- upload_name = upload_name ,
378- media_length = content_length ,
379- user_id = auth_user ,
380- sha256 = sha256 ,
381- quarantined_by = "system" if should_quarantine else None ,
382- )
337+ # Check that the user has not exceeded any of the media upload limits.
338+
339+ # This is the total size of media uploaded by the user in the last
340+ # `time_period_ms` milliseconds, or None if we haven't checked yet.
341+ uploaded_media_size : Optional [int ] = None
342+
343+ # Note: the media upload limits are sorted so larger time periods are
344+ # first.
345+ for limit in self .media_upload_limits :
346+ # We only need to check the amount of media uploaded by the user in
347+ # this latest (smaller) time period if the amount of media uploaded
348+ # in a previous (larger) time period is above the limit.
349+ #
350+ # This optimization means that in the common case where the user
351+ # hasn't uploaded much media, we only need to query the database
352+ # once.
353+ if (
354+ uploaded_media_size is None
355+ or uploaded_media_size + content_length > limit .max_bytes
356+ ):
357+ uploaded_media_size = await self .store .get_media_uploaded_size_for_user (
358+ user_id = auth_user .to_string (), time_period_ms = limit .time_period_ms
359+ )
360+
361+ if uploaded_media_size + content_length > limit .max_bytes :
362+ raise SynapseError (
363+ 400 , "Media upload limit exceeded" , Codes .RESOURCE_LIMIT_EXCEEDED
364+ )
365+
366+ if is_new_media :
367+ await self .store .store_local_media (
368+ media_id = media_id ,
369+ media_type = media_type ,
370+ time_now_ms = self .clock .time_msec (),
371+ upload_name = upload_name ,
372+ media_length = content_length ,
373+ user_id = auth_user ,
374+ sha256 = sha256 ,
375+ quarantined_by = "system" if should_quarantine else None ,
376+ )
377+ else :
378+ await self .store .update_local_media (
379+ media_id = media_id ,
380+ media_type = media_type ,
381+ upload_name = upload_name ,
382+ media_length = content_length ,
383+ user_id = auth_user ,
384+ sha256 = sha256 ,
385+ quarantined_by = "system" if should_quarantine else None ,
386+ )
383387
384388 try :
385389 await self ._generate_thumbnails (None , media_id , media_id , media_type )
0 commit comments