@@ -835,9 +835,9 @@ async def get_mutual_rooms_between_users(
835835
836836 return shared_room_ids or frozenset ()
837837
838- async def get_joined_users_from_state (
838+ async def get_joined_user_ids_from_state (
839839 self , room_id : str , state : StateMap [str ], state_entry : "_StateCacheEntry"
840- ) -> Dict [str , ProfileInfo ]:
840+ ) -> Set [str ]:
841841 state_group : Union [object , int ] = state_entry .state_group
842842 if not state_group :
843843 # If state_group is None it means it has yet to be assigned a
@@ -848,25 +848,25 @@ async def get_joined_users_from_state(
848848
849849 assert state_group is not None
850850 with Measure (self ._clock , "get_joined_users_from_state" ):
851- return await self ._get_joined_users_from_context (
851+ return await self ._get_joined_user_ids_from_context (
852852 room_id , state_group , state , context = state_entry
853853 )
854854
855855 @cached (num_args = 2 , iterable = True , max_entries = 100000 )
856- async def _get_joined_users_from_context (
856+ async def _get_joined_user_ids_from_context (
857857 self ,
858858 room_id : str ,
859859 state_group : Union [object , int ],
860860 current_state_ids : StateMap [str ],
861861 event : Optional [EventBase ] = None ,
862862 context : Optional ["_StateCacheEntry" ] = None ,
863- ) -> Dict [str , ProfileInfo ]:
863+ ) -> Set [str ]:
864864 # We don't use `state_group`, it's there so that we can cache based
865865 # on it. However, it's important that it's never None, since two current_states
866866 # with a state_group of None are likely to be different.
867867 assert state_group is not None
868868
869- users_in_room = {}
869+ users_in_room = set ()
870870 member_event_ids = [
871871 e_id
872872 for key , e_id in current_state_ids .items ()
@@ -879,19 +879,19 @@ async def _get_joined_users_from_context(
879879 # If we do then we can reuse that result and simply update it with
880880 # any membership changes in `delta_ids`
881881 if context .prev_group and context .delta_ids :
882- prev_res = self ._get_joined_users_from_context .cache .get_immediate (
882+ prev_res = self ._get_joined_user_ids_from_context .cache .get_immediate (
883883 (room_id , context .prev_group ), None
884884 )
885- if prev_res and isinstance (prev_res , dict ):
886- users_in_room = dict ( prev_res )
885+ if prev_res and isinstance (prev_res , set ):
886+ users_in_room = prev_res
887887 member_event_ids = [
888888 e_id
889889 for key , e_id in context .delta_ids .items ()
890890 if key [0 ] == EventTypes .Member
891891 ]
892892 for etype , state_key in context .delta_ids :
893893 if etype == EventTypes .Member :
894- users_in_room .pop (state_key , None )
894+ users_in_room .discard (state_key )
895895
896896 # We check if we have any of the member event ids in the event cache
897897 # before we ask the DB
@@ -908,42 +908,41 @@ async def _get_joined_users_from_context(
908908 ev_entry = event_map .get (event_id )
909909 if ev_entry and not ev_entry .event .rejected_reason :
910910 if ev_entry .event .membership == Membership .JOIN :
911- users_in_room [ev_entry .event .state_key ] = ProfileInfo (
912- display_name = ev_entry .event .content .get ("displayname" , None ),
913- avatar_url = ev_entry .event .content .get ("avatar_url" , None ),
914- )
911+ users_in_room .add (ev_entry .event .state_key )
915912 else :
916913 missing_member_event_ids .append (event_id )
917914
918915 if missing_member_event_ids :
919- event_to_memberships = await self ._get_joined_profiles_from_event_ids (
916+ event_to_memberships = await self ._get_user_ids_from_membership_event_ids (
920917 missing_member_event_ids
921918 )
922- users_in_room .update (row for row in event_to_memberships .values () if row )
919+ users_in_room .update (event_to_memberships .values ())
923920
924921 if event is not None and event .type == EventTypes .Member :
925922 if event .membership == Membership .JOIN :
926923 if event .event_id in member_event_ids :
927- users_in_room [event .state_key ] = ProfileInfo (
928- display_name = event .content .get ("displayname" , None ),
929- avatar_url = event .content .get ("avatar_url" , None ),
930- )
924+ users_in_room .add (event .state_key )
931925
932926 return users_in_room
933927
934- @cached (max_entries = 10000 )
935- def _get_joined_profile_from_event_id (
928+ @cached (
929+ max_entries = 10000 ,
930+ # This name matches the old function that has been replaced - the cache name
931+ # is kept here to maintain backwards compatibility.
932+ name = "_get_joined_profile_from_event_id" ,
933+ )
934+ def _get_user_id_from_membership_event_id (
936935 self , event_id : str
937936 ) -> Optional [Tuple [str , ProfileInfo ]]:
938937 raise NotImplementedError ()
939938
940939 @cachedList (
941- cached_method_name = "_get_joined_profile_from_event_id " ,
940+ cached_method_name = "_get_user_id_from_membership_event_id " ,
942941 list_name = "event_ids" ,
943942 )
944- async def _get_joined_profiles_from_event_ids (
943+ async def _get_user_ids_from_membership_event_ids (
945944 self , event_ids : Iterable [str ]
946- ) -> Dict [str , Optional [ Tuple [ str , ProfileInfo ]] ]:
945+ ) -> Dict [str , str ]:
947946 """For given set of member event_ids check if they point to a join
948947 event and if so return the associated user and profile info.
949948
@@ -958,21 +957,13 @@ async def _get_joined_profiles_from_event_ids(
958957 table = "room_memberships" ,
959958 column = "event_id" ,
960959 iterable = event_ids ,
961- retcols = ("user_id" , "display_name" , "avatar_url" , " event_id" ),
960+ retcols = ("user_id" , "event_id" ),
962961 keyvalues = {"membership" : Membership .JOIN },
963962 batch_size = 1000 ,
964- desc = "_get_joined_profiles_from_event_ids " ,
963+ desc = "_get_user_ids_from_membership_event_ids " ,
965964 )
966965
967- return {
968- row ["event_id" ]: (
969- row ["user_id" ],
970- ProfileInfo (
971- avatar_url = row ["avatar_url" ], display_name = row ["display_name" ]
972- ),
973- )
974- for row in rows
975- }
966+ return {row ["event_id" ]: row ["user_id" ] for row in rows }
976967
977968 @cached (max_entries = 10000 )
978969 async def is_host_joined (self , room_id : str , host : str ) -> bool :
@@ -1131,12 +1122,12 @@ async def _get_joined_hosts(
11311122 else :
11321123 # The cache doesn't match the state group or prev state group,
11331124 # so we calculate the result from first principles.
1134- joined_users = await self .get_joined_users_from_state (
1125+ joined_user_ids = await self .get_joined_user_ids_from_state (
11351126 room_id , state , state_entry
11361127 )
11371128
11381129 cache .hosts_to_joined_users = {}
1139- for user_id in joined_users :
1130+ for user_id in joined_user_ids :
11401131 host = intern_string (get_domain_from_id (user_id ))
11411132 cache .hosts_to_joined_users .setdefault (host , set ()).add (user_id )
11421133
0 commit comments