Skip to content

Commit 919c362

Browse files
Remove destinations from sending if not whitelisted (element-hq#18484)
Co-authored-by: Andrew Morgan <[email protected]>
1 parent 82189cb commit 919c362

File tree

3 files changed

+71
-8
lines changed

3 files changed

+71
-8
lines changed

changelog.d/18484.bugfix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Remove destinations from sending if not whitelisted.

synapse/config/federation.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,5 +94,21 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None:
9494
2**62,
9595
)
9696

97+
def is_domain_allowed_according_to_federation_whitelist(self, domain: str) -> bool:
98+
"""
99+
Returns whether a domain is allowed according to the federation whitelist. If a
100+
federation whitelist is not set, all domains are allowed.
101+
102+
Args:
103+
domain: The domain to test.
104+
105+
Returns:
106+
True if the domain is allowed or if a whitelist is not set, False otherwise.
107+
"""
108+
if self.federation_domain_whitelist is None:
109+
return True
110+
111+
return domain in self.federation_domain_whitelist
112+
97113

98114
_METRICS_FOR_DOMAINS_SCHEMA = {"type": "array", "items": {"type": "string"}}

synapse/federation/sender/__init__.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,8 @@ async def _handle(self) -> None:
342342
destination, _ = self.queue.popitem(last=False)
343343

344344
queue = self.sender._get_per_destination_queue(destination)
345+
if queue is None:
346+
continue
345347

346348
if not queue._new_data_to_send:
347349
# The per destination queue has already been woken up.
@@ -436,12 +438,23 @@ def __init__(self, hs: "HomeServer"):
436438
self._wake_destinations_needing_catchup,
437439
)
438440

439-
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
441+
def _get_per_destination_queue(
442+
self, destination: str
443+
) -> Optional[PerDestinationQueue]:
440444
"""Get or create a PerDestinationQueue for the given destination
441445
442446
Args:
443447
destination: server_name of remote server
448+
449+
Returns:
450+
None if the destination is not allowed by the federation whitelist.
451+
Otherwise a PerDestinationQueue for this destination.
444452
"""
453+
if not self.hs.config.federation.is_domain_allowed_according_to_federation_whitelist(
454+
destination
455+
):
456+
return None
457+
445458
queue = self._per_destination_queues.get(destination)
446459
if not queue:
447460
queue = PerDestinationQueue(self.hs, self._transaction_manager, destination)
@@ -718,6 +731,16 @@ async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
718731
# track the fact that we have a PDU for these destinations,
719732
# to allow us to perform catch-up later on if the remote is unreachable
720733
# for a while.
734+
# Filter out any destinations not present in the federation_domain_whitelist, if
735+
# the whitelist exists. These destinations should not be sent to so let's not
736+
# waste time or space keeping track of events destined for them.
737+
destinations = [
738+
d
739+
for d in destinations
740+
if self.hs.config.federation.is_domain_allowed_according_to_federation_whitelist(
741+
d
742+
)
743+
]
721744
await self.store.store_destination_rooms_entries(
722745
destinations,
723746
pdu.room_id,
@@ -732,7 +755,12 @@ async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
732755
)
733756

734757
for destination in destinations:
735-
self._get_per_destination_queue(destination).send_pdu(pdu)
758+
queue = self._get_per_destination_queue(destination)
759+
# We expect `queue` to not be None as we already filtered out
760+
# non-whitelisted destinations above.
761+
assert queue is not None
762+
763+
queue.send_pdu(pdu)
736764

737765
async def send_read_receipt(self, receipt: ReadReceipt) -> None:
738766
"""Send a RR to any other servers in the room
@@ -841,12 +869,16 @@ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
841869
for domain in immediate_domains:
842870
# Add to destination queue and wake the destination up
843871
queue = self._get_per_destination_queue(domain)
872+
if queue is None:
873+
continue
844874
queue.queue_read_receipt(receipt)
845875
queue.attempt_new_transaction()
846876

847877
for domain in delay_domains:
848878
# Add to destination queue...
849879
queue = self._get_per_destination_queue(domain)
880+
if queue is None:
881+
continue
850882
queue.queue_read_receipt(receipt)
851883

852884
# ... and schedule the destination to be woken up.
@@ -882,9 +914,10 @@ async def send_presence_to_destinations(
882914
if self.is_mine_server_name(destination):
883915
continue
884916

885-
self._get_per_destination_queue(destination).send_presence(
886-
states, start_loop=False
887-
)
917+
queue = self._get_per_destination_queue(destination)
918+
if queue is None:
919+
continue
920+
queue.send_presence(states, start_loop=False)
888921

889922
self._destination_wakeup_queue.add_to_queue(destination)
890923

@@ -934,6 +967,8 @@ def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None:
934967
return
935968

936969
queue = self._get_per_destination_queue(edu.destination)
970+
if queue is None:
971+
return
937972
if key:
938973
queue.send_keyed_edu(edu, key)
939974
else:
@@ -958,9 +993,15 @@ async def send_device_messages(
958993

959994
for destination in destinations:
960995
if immediate:
961-
self._get_per_destination_queue(destination).attempt_new_transaction()
996+
queue = self._get_per_destination_queue(destination)
997+
if queue is None:
998+
continue
999+
queue.attempt_new_transaction()
9621000
else:
963-
self._get_per_destination_queue(destination).mark_new_data()
1001+
queue = self._get_per_destination_queue(destination)
1002+
if queue is None:
1003+
continue
1004+
queue.mark_new_data()
9641005
self._destination_wakeup_queue.add_to_queue(destination)
9651006

9661007
def wake_destination(self, destination: str) -> None:
@@ -979,7 +1020,9 @@ def wake_destination(self, destination: str) -> None:
9791020
):
9801021
return
9811022

982-
self._get_per_destination_queue(destination).attempt_new_transaction()
1023+
queue = self._get_per_destination_queue(destination)
1024+
if queue is not None:
1025+
queue.attempt_new_transaction()
9831026

9841027
@staticmethod
9851028
def get_current_token() -> int:
@@ -1024,6 +1067,9 @@ async def _wake_destinations_needing_catchup(self) -> None:
10241067
d
10251068
for d in destinations_to_wake
10261069
if self._federation_shard_config.should_handle(self._instance_name, d)
1070+
and self.hs.config.federation.is_domain_allowed_according_to_federation_whitelist(
1071+
d
1072+
)
10271073
]
10281074

10291075
for destination in destinations_to_wake:

0 commit comments

Comments
 (0)