|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +import enum |
14 | 15 | import logging |
15 | 16 | from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple |
16 | 17 |
|
|
20 | 21 | from synapse.api.errors import SynapseError |
21 | 22 | from synapse.events import EventBase, relation_from_event |
22 | 23 | from synapse.logging.opentracing import trace |
23 | | -from synapse.storage.databases.main.relations import _RelatedEvent |
| 24 | +from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent |
24 | 25 | from synapse.streams.config import PaginationConfig |
25 | 26 | from synapse.types import JsonDict, Requester, StreamToken, UserID |
26 | 27 | from synapse.visibility import filter_events_for_client |
|
32 | 33 | logger = logging.getLogger(__name__) |
33 | 34 |
|
34 | 35 |
|
| 36 | +class ThreadsListInclude(str, enum.Enum): |
| 37 | + """Valid values for the 'include' flag of /threads.""" |
| 38 | + |
| 39 | + all = "all" |
| 40 | + participated = "participated" |
| 41 | + |
| 42 | + |
35 | 43 | @attr.s(slots=True, frozen=True, auto_attribs=True) |
36 | 44 | class _ThreadAggregation: |
37 | 45 | # The latest event in the thread. |
@@ -482,3 +490,79 @@ async def get_bundled_aggregations( |
482 | 490 | results.setdefault(event_id, BundledAggregations()).replace = edit |
483 | 491 |
|
484 | 492 | return results |
| 493 | + |
| 494 | + async def get_threads( |
| 495 | + self, |
| 496 | + requester: Requester, |
| 497 | + room_id: str, |
| 498 | + include: ThreadsListInclude, |
| 499 | + limit: int = 5, |
| 500 | + from_token: Optional[ThreadsNextBatch] = None, |
| 501 | + ) -> JsonDict: |
| 502 | + """Get related events of a event, ordered by topological ordering. |
| 503 | +
|
| 504 | + Args: |
| 505 | + requester: The user requesting the relations. |
| 506 | + room_id: The room the event belongs to. |
| 507 | + include: One of "all" or "participated" to indicate which threads should |
| 508 | + be returned. |
| 509 | + limit: Only fetch the most recent `limit` events. |
| 510 | + from_token: Fetch rows from the given token, or from the start if None. |
| 511 | +
|
| 512 | + Returns: |
| 513 | + The pagination chunk. |
| 514 | + """ |
| 515 | + |
| 516 | + user_id = requester.user.to_string() |
| 517 | + |
| 518 | + # TODO Properly handle a user leaving a room. |
| 519 | + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( |
| 520 | + room_id, requester, allow_departed_users=True |
| 521 | + ) |
| 522 | + |
| 523 | + # Note that ignored users are not passed into get_relations_for_event |
| 524 | + # below. Ignored users are handled in filter_events_for_client (and by |
| 525 | + # not passing them in here we should get a better cache hit rate). |
| 526 | + thread_roots, next_batch = await self._main_store.get_threads( |
| 527 | + room_id=room_id, limit=limit, from_token=from_token |
| 528 | + ) |
| 529 | + |
| 530 | + events = await self._main_store.get_events_as_list(thread_roots) |
| 531 | + |
| 532 | + if include == ThreadsListInclude.participated: |
| 533 | + # Pre-seed thread participation with whether the requester sent the event. |
| 534 | + participated = {event.event_id: event.sender == user_id for event in events} |
| 535 | + # For events the requester did not send, check the database for whether |
| 536 | + # the requester sent a threaded reply. |
| 537 | + participated.update( |
| 538 | + await self._main_store.get_threads_participated( |
| 539 | + [eid for eid, p in participated.items() if not p], |
| 540 | + user_id, |
| 541 | + ) |
| 542 | + ) |
| 543 | + |
| 544 | + # Limit the returned threads to those the user has participated in. |
| 545 | + events = [event for event in events if participated[event.event_id]] |
| 546 | + |
| 547 | + events = await filter_events_for_client( |
| 548 | + self._storage_controllers, |
| 549 | + user_id, |
| 550 | + events, |
| 551 | + is_peeking=(member_event_id is None), |
| 552 | + ) |
| 553 | + |
| 554 | + aggregations = await self.get_bundled_aggregations( |
| 555 | + events, requester.user.to_string() |
| 556 | + ) |
| 557 | + |
| 558 | + now = self._clock.time_msec() |
| 559 | + serialized_events = self._event_serializer.serialize_events( |
| 560 | + events, now, bundle_aggregations=aggregations |
| 561 | + ) |
| 562 | + |
| 563 | + return_value: JsonDict = {"chunk": serialized_events} |
| 564 | + |
| 565 | + if next_batch: |
| 566 | + return_value["next_batch"] = str(next_batch) |
| 567 | + |
| 568 | + return return_value |
0 commit comments