1414
1515
1616import json
17- from typing import Dict
17+ from typing import Dict , List , Set
1818from unittest .mock import ANY , Mock , call
1919
20- from twisted .internet import defer
2120from twisted .test .proto_helpers import MemoryReactor
2221from twisted .web .resource import Resource
2322
2423from synapse .api .constants import EduTypes
2524from synapse .api .errors import AuthError
2625from synapse .federation .transport .server import TransportLayerServer
26+ from synapse .handlers .typing import TypingWriterHandler
2727from synapse .server import HomeServer
2828from synapse .types import JsonDict , Requester , UserID , create_requester
2929from synapse .util import Clock
3030
3131from tests import unittest
32+ from tests .server import ThreadedMemoryReactorClock
3233from tests .test_utils import make_awaitable
3334from tests .unittest import override_config
3435
@@ -62,7 +63,11 @@ def _make_edu_transaction_json(edu_type: str, content: JsonDict) -> bytes:
6263
6364
6465class TypingNotificationsTestCase (unittest .HomeserverTestCase ):
65- def make_homeserver (self , reactor : MemoryReactor , clock : Clock ) -> HomeServer :
66+ def make_homeserver (
67+ self ,
68+ reactor : ThreadedMemoryReactorClock ,
69+ clock : Clock ,
70+ ) -> HomeServer :
6671 # we mock out the keyring so as to skip the authentication check on the
6772 # federation API call.
6873 mock_keyring = Mock (spec = ["verify_json_for_server" ])
@@ -75,8 +80,9 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
7580 # the tests assume that we are starting at unix time 1000
7681 reactor .pump ((1000 ,))
7782
83+ self .mock_hs_notifier = Mock ()
7884 hs = self .setup_test_homeserver (
79- notifier = Mock () ,
85+ notifier = self . mock_hs_notifier ,
8086 federation_http_client = mock_federation_client ,
8187 keyring = mock_keyring ,
8288 replication_streams = {},
@@ -90,32 +96,38 @@ def create_resource_dict(self) -> Dict[str, Resource]:
9096 return d
9197
9298 def prepare (self , reactor : MemoryReactor , clock : Clock , hs : HomeServer ) -> None :
93- mock_notifier = hs .get_notifier ()
94- self .on_new_event = mock_notifier .on_new_event
99+ self .on_new_event = self .mock_hs_notifier .on_new_event
95100
96- self .handler = hs .get_typing_handler ()
101+ # hs.get_typing_handler will return a TypingWriterHandler when calling it
102+ # from the main process, and a FollowerTypingHandler on workers.
103+ # We rely on methods only available on the former, so assert we have the
104+ # correct type here. We have to assign self.handler after the assert,
105+ # otherwise mypy will treat it as a FollowerTypingHandler
106+ handler = hs .get_typing_handler ()
107+ assert isinstance (handler , TypingWriterHandler )
108+ self .handler = handler
97109
98110 self .event_source = hs .get_event_sources ().sources .typing
99111
100112 self .datastore = hs .get_datastores ().main
113+
101114 self .datastore .get_destination_retry_timings = Mock (
102115 return_value = make_awaitable (None )
103116 )
104117
105- self .datastore .get_device_updates_by_remote = Mock (
118+ self .datastore .get_device_updates_by_remote = Mock ( # type: ignore[assignment]
106119 return_value = make_awaitable ((0 , []))
107120 )
108121
109- self .datastore .get_destination_last_successful_stream_ordering = Mock (
122+ self .datastore .get_destination_last_successful_stream_ordering = Mock ( # type: ignore[assignment]
110123 return_value = make_awaitable (None )
111124 )
112125
113- def get_received_txn_response (* args ):
114- return defer .succeed (None )
115-
116- self .datastore .get_received_txn_response = get_received_txn_response
126+ self .datastore .get_received_txn_response = Mock ( # type: ignore[assignment]
127+ return_value = make_awaitable (None )
128+ )
117129
118- self .room_members = []
130+ self .room_members : List [ UserID ] = []
119131
120132 async def check_user_in_room (room_id : str , requester : Requester ) -> None :
121133 if requester .user .to_string () not in [
@@ -124,47 +136,54 @@ async def check_user_in_room(room_id: str, requester: Requester) -> None:
124136 raise AuthError (401 , "User is not in the room" )
125137 return None
126138
127- hs .get_auth ().check_user_in_room = check_user_in_room
139+ hs .get_auth ().check_user_in_room = Mock ( # type: ignore[assignment]
140+ side_effect = check_user_in_room
141+ )
128142
129143 async def check_host_in_room (room_id : str , server_name : str ) -> bool :
130144 return room_id == ROOM_ID
131145
132- hs .get_event_auth_handler ().is_host_in_room = check_host_in_room
146+ hs .get_event_auth_handler ().is_host_in_room = Mock ( # type: ignore[assignment]
147+ side_effect = check_host_in_room
148+ )
133149
134- async def get_current_hosts_in_room (room_id : str ):
150+ async def get_current_hosts_in_room (room_id : str ) -> Set [ str ] :
135151 return {member .domain for member in self .room_members }
136152
137- hs .get_storage_controllers ().state .get_current_hosts_in_room = (
138- get_current_hosts_in_room
153+ hs .get_storage_controllers ().state .get_current_hosts_in_room = Mock ( # type: ignore[assignment]
154+ side_effect = get_current_hosts_in_room
139155 )
140156
141- hs .get_storage_controllers ().state .get_current_hosts_in_room_or_partial_state_approximation = (
142- get_current_hosts_in_room
157+ hs .get_storage_controllers ().state .get_current_hosts_in_room_or_partial_state_approximation = Mock ( # type: ignore[assignment]
158+ side_effect = get_current_hosts_in_room
143159 )
144160
145- async def get_users_in_room (room_id : str ):
161+ async def get_users_in_room (room_id : str ) -> Set [ str ] :
146162 return {str (u ) for u in self .room_members }
147163
148- self .datastore .get_users_in_room = get_users_in_room
164+ self .datastore .get_users_in_room = Mock ( side_effect = get_users_in_room )
149165
150- self .datastore .get_user_directory_stream_pos = Mock (
166+ self .datastore .get_user_directory_stream_pos = Mock ( # type: ignore[assignment]
151167 side_effect = (
152- # we deliberately return a non-None stream pos to avoid doing an initial_spam
168+ # we deliberately return a non-None stream pos to avoid
169+ # doing an initial_sync
153170 lambda : make_awaitable (1 )
154171 )
155172 )
156173
157- self .datastore .get_partial_current_state_deltas = Mock (return_value = (0 , None ))
174+ self .datastore .get_partial_current_state_deltas = Mock (return_value = (0 , None )) # type: ignore[assignment]
158175
159- self .datastore .get_to_device_stream_token = lambda : 0
160- self .datastore .get_new_device_msgs_for_remote = (
161- lambda * args , ** kargs : make_awaitable (([], 0 ))
176+ self .datastore .get_to_device_stream_token = Mock ( # type: ignore[assignment]
177+ side_effect = lambda : 0
178+ )
179+ self .datastore .get_new_device_msgs_for_remote = Mock ( # type: ignore[assignment]
180+ side_effect = lambda * args , ** kargs : make_awaitable (([], 0 ))
162181 )
163- self .datastore .delete_device_msgs_for_remote = (
164- lambda * args , ** kargs : make_awaitable (None )
182+ self .datastore .delete_device_msgs_for_remote = Mock ( # type: ignore[assignment]
183+ side_effect = lambda * args , ** kargs : make_awaitable (None )
165184 )
166- self .datastore .set_received_txn_response = (
167- lambda * args , ** kwargs : make_awaitable (None )
185+ self .datastore .set_received_txn_response = Mock ( # type: ignore[assignment]
186+ side_effect = lambda * args , ** kwargs : make_awaitable (None )
168187 )
169188
170189 def test_started_typing_local (self ) -> None :
@@ -186,7 +205,7 @@ def test_started_typing_local(self) -> None:
186205 self .assertEqual (self .event_source .get_current_key (), 1 )
187206 events = self .get_success (
188207 self .event_source .get_new_events (
189- user = U_APPLE , from_key = 0 , limit = None , room_ids = [ROOM_ID ], is_guest = False
208+ user = U_APPLE , from_key = 0 , limit = 0 , room_ids = [ROOM_ID ], is_guest = False
190209 )
191210 )
192211 self .assertEqual (
@@ -257,7 +276,7 @@ def test_started_typing_remote_recv(self) -> None:
257276 self .assertEqual (self .event_source .get_current_key (), 1 )
258277 events = self .get_success (
259278 self .event_source .get_new_events (
260- user = U_APPLE , from_key = 0 , limit = None , room_ids = [ROOM_ID ], is_guest = False
279+ user = U_APPLE , from_key = 0 , limit = 0 , room_ids = [ROOM_ID ], is_guest = False
261280 )
262281 )
263282 self .assertEqual (
@@ -298,7 +317,7 @@ def test_started_typing_remote_recv_not_in_room(self) -> None:
298317 self .event_source .get_new_events (
299318 user = U_APPLE ,
300319 from_key = 0 ,
301- limit = None ,
320+ limit = 0 ,
302321 room_ids = [OTHER_ROOM_ID ],
303322 is_guest = False ,
304323 )
@@ -351,7 +370,7 @@ def test_stopped_typing(self) -> None:
351370 self .assertEqual (self .event_source .get_current_key (), 1 )
352371 events = self .get_success (
353372 self .event_source .get_new_events (
354- user = U_APPLE , from_key = 0 , limit = None , room_ids = [ROOM_ID ], is_guest = False
373+ user = U_APPLE , from_key = 0 , limit = 0 , room_ids = [ROOM_ID ], is_guest = False
355374 )
356375 )
357376 self .assertEqual (
@@ -387,7 +406,7 @@ def test_typing_timeout(self) -> None:
387406 self .event_source .get_new_events (
388407 user = U_APPLE ,
389408 from_key = 0 ,
390- limit = None ,
409+ limit = 0 ,
391410 room_ids = [ROOM_ID ],
392411 is_guest = False ,
393412 )
@@ -412,7 +431,7 @@ def test_typing_timeout(self) -> None:
412431 self .event_source .get_new_events (
413432 user = U_APPLE ,
414433 from_key = 1 ,
415- limit = None ,
434+ limit = 0 ,
416435 room_ids = [ROOM_ID ],
417436 is_guest = False ,
418437 )
@@ -447,7 +466,7 @@ def test_typing_timeout(self) -> None:
447466 self .event_source .get_new_events (
448467 user = U_APPLE ,
449468 from_key = 0 ,
450- limit = None ,
469+ limit = 0 ,
451470 room_ids = [ROOM_ID ],
452471 is_guest = False ,
453472 )
0 commit comments