369 changes: 247 additions & 122 deletions synapse/handlers/search.py
@@ -14,7 +14,7 @@

import itertools
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple

from unpaddedbase64 import decode_base64, encode_base64

@@ -284,131 +284,42 @@ async def _search(
}
}

rank_map = {} # event_id -> rank of event
allowed_events = []
# Holds result of grouping by room, if applicable
room_groups: Dict[str, JsonDict] = {}
# Holds result of grouping by sender, if applicable
sender_group: Dict[str, JsonDict] = {}

# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None

highlights = set()

count = None
sender_group: Optional[Dict[str, JsonDict]]

if order_by == "rank":
search_result = await self.store.search_msgs(room_ids, search_term, keys)

count = search_result["count"]

if search_result["highlights"]:
highlights.update(search_result["highlights"])

results = search_result["results"]

rank_map.update({r["event"].event_id: r["rank"] for r in results})

filtered_events = await search_filter.filter([r["event"] for r in results])

events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
(
count,
rank_map,
allowed_events,
room_groups,
highlights,
sender_group,
) = await self._search_by_rank(
user, room_ids, search_term, keys, search_filter
)

events.sort(key=lambda e: -rank_map[e.event_id])
allowed_events = events[: search_filter.limit]

for e in allowed_events:
rm = room_groups.setdefault(
e.room_id, {"results": [], "order": rank_map[e.event_id]}
)
rm["results"].append(e.event_id)

s = sender_group.setdefault(
e.sender, {"results": [], "order": rank_map[e.event_id]}
)
s["results"].append(e.event_id)
# Unused return values for rank search.
global_next_batch = None

elif order_by == "recent":
room_events: List[EventBase] = []
i = 0

pagination_token = batch_token

# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit and i < 5:
i += 1
search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
search_filter.limit * 2,
pagination_token=pagination_token,
)

if search_result["highlights"]:
highlights.update(search_result["highlights"])

count = search_result["count"]

results = search_result["results"]

results_map = {r["event"].event_id: r for r in results}

rank_map.update({r["event"].event_id: r["rank"] for r in results})

filtered_events = await search_filter.filter(
[r["event"] for r in results]
)

events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)

room_events.extend(events)
room_events = room_events[: search_filter.limit]

if len(results) < search_filter.limit * 2:
pagination_token = None
break
else:
pagination_token = results[-1]["pagination_token"]

for event in room_events:
group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id)

if room_events and len(room_events) >= search_filter.limit:
last_event_id = room_events[-1].event_id
pagination_token = results_map[last_event_id]["pagination_token"]

# We want to respect the given batch group and group keys so
# that if people blindly use the top level `next_batch` token
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
global_next_batch = encode_base64(
(
"%s\n%s\n%s"
% (batch_group, batch_group_key, pagination_token)
).encode("ascii")
)
else:
global_next_batch = encode_base64(
("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
)

for room_id, group in room_groups.items():
group["next_batch"] = encode_base64(
("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
"ascii"
)
)

allowed_events.extend(room_events)
(
count,
rank_map,
allowed_events,
room_groups,
highlights,
global_next_batch,
) = await self._search_by_recent(
user,
room_ids,
search_term,
keys,
search_filter,
batch_group,
batch_group_key,
batch_token,
)
# Unused return values for recent search.
sender_group = None

else:
# We should never get here due to the guard earlier.
@@ -538,7 +449,7 @@ async def _search(
}
)

rooms_cat_res = {
rooms_cat_res: JsonDict = {
"results": results,
"count": count,
"highlights": list(highlights),
@@ -563,3 +474,217 @@ async def _search(
rooms_cat_res["next_batch"] = global_next_batch

return {"search_categories": {"room_events": rooms_cat_res}}

async def _search_by_rank(
self,
user: UserID,
room_ids: Collection[str],
search_term: str,
keys: Iterable[str],
search_filter: Filter,
) -> Tuple[
int,
Dict[str, int],
List[EventBase],
Dict[str, JsonDict],
Set[str],
Dict[str, JsonDict],

Member: It'd be nice if we could make the return types an attr.s class?

Member Author: Good idea, that should be easy enough! 👍

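For illustration, a minimal sketch of the kind of attr.s class being suggested (the class and field names are hypothetical; this PR does not add such a class):

```python
# Hypothetical sketch of the reviewer's suggestion; not part of this PR.
from typing import Dict, List, Set

import attr

from synapse.events import EventBase
from synapse.types import JsonDict


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _RankSearchResult:
    """Bundles the values _search_by_rank currently returns as a bare tuple."""

    count: int  # total number of matching results
    rank_map: Dict[str, int]  # event_id -> rank of that event
    allowed_events: List[EventBase]  # events the user is allowed to see
    room_groups: Dict[str, JsonDict]  # room_id -> grouped results
    highlights: Set[str]  # terms to highlight
    sender_group: Dict[str, JsonDict]  # sender_id -> grouped results
```

The call site in `_search` would then read named attributes such as `result.count` instead of unpacking a six-element tuple.
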
]:
"""
Performs a full text search for a user ordering by rank.

Args:
user: The user performing the search.
room_ids: List of room ids to search in
search_term: Search term to search for
keys: List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
search_filter: The event filter to use.

Returns:
A tuple of:
The count of results.
A mapping of event ID to the rank of that event.
A list of the resulting events.
A map of room ID to results.
A set of event IDs to highlight.
A map of sender ID to results.
"""
rank_map = {} # event_id -> rank of event
# Holds result of grouping by room, if applicable
room_groups: Dict[str, JsonDict] = {}
# Holds result of grouping by sender, if applicable
sender_group: Dict[str, JsonDict] = {}

highlights = set()

search_result = await self.store.search_msgs(room_ids, search_term, keys)

count = search_result["count"]

if search_result["highlights"]:
highlights.update(search_result["highlights"])

results = search_result["results"]

results_map = {r["event"].event_id: r for r in results}

rank_map.update({r["event"].event_id: r["rank"] for r in results})

filtered_events = await search_filter.filter([r["event"] for r in results])

events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)

events.sort(key=lambda e: -rank_map[e.event_id])
allowed_events = events[: search_filter.limit]

for e in allowed_events:
rm = room_groups.setdefault(
e.room_id, {"results": [], "order": rank_map[e.event_id]}
)
rm["results"].append(e.event_id)

s = sender_group.setdefault(
e.sender, {"results": [], "order": rank_map[e.event_id]}
)
s["results"].append(e.event_id)

return count, rank_map, allowed_events, room_groups, highlights, sender_group

async def _search_by_recent(
self,
user: UserID,
room_ids: Collection[str],
search_term: str,
keys: Iterable[str],
search_filter: Filter,
batch_group: Optional[str],
batch_group_key: Optional[str],
batch_token: Optional[str],
) -> Tuple[
int,
Dict[str, int],
List[EventBase],
Dict[str, JsonDict],
Set[str],
Optional[str],
]:
"""
Performs a full text search for a user ordering by recent.

Args:
user: The user performing the search.
room_ids: List of room ids to search in
search_term: Search term to search for
keys: List of keys to search in, currently supports
"content.body", "content.name", "content.topic"
search_filter: The event filter to use.
batch_group: Pagination information.
batch_group_key: Pagination information.
batch_token: Pagination information.

Returns:
A tuple of:
The count of results.
A mapping of event ID to the rank of that event.
A list of the resulting events.
A map of room ID to results.
A set of event IDs to highlight.
Optionally, a pagination token.
"""
rank_map = {} # event_id -> rank of event
allowed_events: List[EventBase] = []
# Holds result of grouping by room, if applicable
room_groups: Dict[str, JsonDict] = {}

# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None

highlights = set()

room_events: List[EventBase] = []
i = 0

pagination_token = batch_token

# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit and i < 5:
i += 1
search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
search_filter.limit * 2,
pagination_token=pagination_token,
)

if search_result["highlights"]:
highlights.update(search_result["highlights"])

count = search_result["count"]

results = search_result["results"]

results_map = {r["event"].event_id: r for r in results}

rank_map.update({r["event"].event_id: r["rank"] for r in results})

filtered_events = await search_filter.filter([r["event"] for r in results])

events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)

room_events.extend(events)
room_events = room_events[: search_filter.limit]

if len(results) < search_filter.limit * 2:
pagination_token = None
break
else:
pagination_token = results[-1]["pagination_token"]
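
# Illustrative note, not part of this PR: each pass over-fetches
# (search_filter.limit * 2) raw results because search_filter.filter()
# and filter_events_for_client() may drop some of them; getting back
# fewer than that many means the store is exhausted, hence the break.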

for event in room_events:
group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id)

if room_events and len(room_events) >= search_filter.limit:
last_event_id = room_events[-1].event_id
pagination_token = results_map[last_event_id]["pagination_token"]

# We want to respect the given batch group and group keys so
# that if people blindly use the top level `next_batch` token
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
global_next_batch = encode_base64(
(
"%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)
).encode("ascii")
)
else:
global_next_batch = encode_base64(
("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
)

for room_id, group in room_groups.items():
group["next_batch"] = encode_base64(
("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
"ascii"
)
)

allowed_events.extend(room_events)

return (
count,
rank_map,
allowed_events,
room_groups,
highlights,
global_next_batch,
)
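
As background on the token format used above: the `next_batch` tokens are nothing more than three newline-separated fields, unpadded-base64 encoded. A standalone sketch of the round trip, assuming only the `unpaddedbase64` package (the helper names and sample values are illustrative, not Synapse APIs):

```python
# Sketch of the next_batch token format; helper names are illustrative.
from typing import Tuple

from unpaddedbase64 import decode_base64, encode_base64


def encode_batch_token(group: str, group_key: str, pagination_token: str) -> str:
    # Three newline-separated fields, encoded as unpadded base64.
    raw = "%s\n%s\n%s" % (group, group_key, pagination_token)
    return encode_base64(raw.encode("ascii"))


def decode_batch_token(token: str) -> Tuple[str, str, str]:
    group, group_key, pagination_token = (
        decode_base64(token).decode("ascii").split("\n")
    )
    return group, group_key, pagination_token


# Round trip: the client hands the opaque token back and the server
# recovers which group (if any) to continue paginating from.
# The group key and pagination token below are made-up sample values.
token = encode_batch_token("room_id", "!abc:example.org", "t123-456")
assert decode_batch_token(token) == ("room_id", "!abc:example.org", "t123-456")
```

Keeping the token opaque is what lets the comment in the diff promise that blindly reusing the top-level `next_batch` keeps returning results from the same group rather than restarting a global search.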