@@ -893,6 +893,43 @@ async def _check_host_room_membership(
893893
894894 return True
895895
896+ @cached (iterable = True , max_entries = 10000 )
897+ async def get_current_hosts_in_room (self , room_id : str ) -> Set [str ]:
898+ """Get current hosts in room based on current state."""
899+
900+ # First we check if we already have `get_users_in_room` in the cache, as
901+ # we can just calculate result from that
902+ users = self .get_users_in_room .cache .get_immediate (
903+ (room_id ,), None , update_metrics = False
904+ )
905+ if users is not None :
906+ return {get_domain_from_id (u ) for u in users }
907+
908+ if isinstance (self .database_engine , Sqlite3Engine ):
909+ # If we're using SQLite then let's just always use
910+ # `get_users_in_room` rather than funky SQL.
911+ users = await self .get_users_in_room (room_id )
912+ return {get_domain_from_id (u ) for u in users }
913+
914+ # For PostgreSQL we can use a regex to pull out the domains from the
915+ # joined users in `current_state_events` via regex.
916+
917+ def get_current_hosts_in_room_txn (txn : LoggingTransaction ) -> Set [str ]:
918+ sql = """
919+ SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
920+ FROM current_state_events
921+ WHERE
922+ type = 'm.room.member'
923+ AND membership = 'join'
924+ AND room_id = ?
925+ """
926+ txn .execute (sql , (room_id ,))
927+ return {d for d , in txn }
928+
929+ return await self .db_pool .runInteraction (
930+ "get_current_hosts_in_room" , get_current_hosts_in_room_txn
931+ )
932+
896933 async def get_joined_hosts (
897934 self , room_id : str , state_entry : "_StateCacheEntry"
898935 ) -> FrozenSet [str ]:
0 commit comments