1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from unittest .mock import Mock
15+ from unittest .mock import AsyncMock , Mock
1616
1717import pymacaroons
1818
3535from synapse .util import Clock
3636
3737from tests import unittest
38- from tests .test_utils import simple_async_mock
3938from tests .unittest import override_config
4039from tests .utils import mock_getRawHeaders
4140
@@ -60,16 +59,16 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
6059 # this is overridden for the appservice tests
6160 self .store .get_app_service_by_token = Mock (return_value = None )
6261
63- self .store .insert_client_ip = simple_async_mock ( None )
64- self .store .is_support_user = simple_async_mock ( False )
62+ self .store .insert_client_ip = AsyncMock ( return_value = None )
63+ self .store .is_support_user = AsyncMock ( return_value = False )
6564
6665 def test_get_user_by_req_user_valid_token (self ) -> None :
6766 user_info = TokenLookupResult (
6867 user_id = self .test_user , token_id = 5 , device_id = "device"
6968 )
70- self .store .get_user_by_access_token = simple_async_mock ( user_info )
71- self .store .mark_access_token_as_used = simple_async_mock ( None )
72- self .store .get_user_locked_status = simple_async_mock ( False )
69+ self .store .get_user_by_access_token = AsyncMock ( return_value = user_info )
70+ self .store .mark_access_token_as_used = AsyncMock ( return_value = None )
71+ self .store .get_user_locked_status = AsyncMock ( return_value = False )
7372
7473 request = Mock (args = {})
7574 request .args [b"access_token" ] = [self .test_token ]
@@ -78,7 +77,7 @@ def test_get_user_by_req_user_valid_token(self) -> None:
7877 self .assertEqual (requester .user .to_string (), self .test_user )
7978
8079 def test_get_user_by_req_user_bad_token (self ) -> None :
81- self .store .get_user_by_access_token = simple_async_mock ( None )
80+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
8281
8382 request = Mock (args = {})
8483 request .args [b"access_token" ] = [self .test_token ]
@@ -91,7 +90,7 @@ def test_get_user_by_req_user_bad_token(self) -> None:
9190
9291 def test_get_user_by_req_user_missing_token (self ) -> None :
9392 user_info = TokenLookupResult (user_id = self .test_user , token_id = 5 )
94- self .store .get_user_by_access_token = simple_async_mock ( user_info )
93+ self .store .get_user_by_access_token = AsyncMock ( return_value = user_info )
9594
9695 request = Mock (args = {})
9796 request .requestHeaders .getRawHeaders = mock_getRawHeaders ()
@@ -106,7 +105,7 @@ def test_get_user_by_req_appservice_valid_token(self) -> None:
106105 token = "foobar" , url = "a_url" , sender = self .test_user , ip_range_whitelist = None
107106 )
108107 self .store .get_app_service_by_token = Mock (return_value = app_service )
109- self .store .get_user_by_access_token = simple_async_mock ( None )
108+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
110109
111110 request = Mock (args = {})
112111 request .getClientAddress .return_value .host = "127.0.0.1"
@@ -125,7 +124,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self) -> None:
125124 ip_range_whitelist = IPSet (["192.168/16" ]),
126125 )
127126 self .store .get_app_service_by_token = Mock (return_value = app_service )
128- self .store .get_user_by_access_token = simple_async_mock ( None )
127+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
129128
130129 request = Mock (args = {})
131130 request .getClientAddress .return_value .host = "192.168.10.10"
@@ -144,7 +143,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None:
144143 ip_range_whitelist = IPSet (["192.168/16" ]),
145144 )
146145 self .store .get_app_service_by_token = Mock (return_value = app_service )
147- self .store .get_user_by_access_token = simple_async_mock ( None )
146+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
148147
149148 request = Mock (args = {})
150149 request .getClientAddress .return_value .host = "131.111.8.42"
@@ -158,7 +157,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self) -> None:
158157
159158 def test_get_user_by_req_appservice_bad_token (self ) -> None :
160159 self .store .get_app_service_by_token = Mock (return_value = None )
161- self .store .get_user_by_access_token = simple_async_mock ( None )
160+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
162161
163162 request = Mock (args = {})
164163 request .args [b"access_token" ] = [self .test_token ]
@@ -172,7 +171,7 @@ def test_get_user_by_req_appservice_bad_token(self) -> None:
172171 def test_get_user_by_req_appservice_missing_token (self ) -> None :
173172 app_service = Mock (token = "foobar" , url = "a_url" , sender = self .test_user )
174173 self .store .get_app_service_by_token = Mock (return_value = app_service )
175- self .store .get_user_by_access_token = simple_async_mock ( None )
174+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
176175
177176 request = Mock (args = {})
178177 request .requestHeaders .getRawHeaders = mock_getRawHeaders ()
@@ -190,8 +189,8 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self) -> None:
190189 app_service .is_interested_in_user = Mock (return_value = True )
191190 self .store .get_app_service_by_token = Mock (return_value = app_service )
192191 # This just needs to return a truth-y value.
193- self .store .get_user_by_id = simple_async_mock ( {"is_guest" : False })
194- self .store .get_user_by_access_token = simple_async_mock ( None )
192+ self .store .get_user_by_id = AsyncMock ( return_value = {"is_guest" : False })
193+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
195194
196195 request = Mock (args = {})
197196 request .getClientAddress .return_value .host = "127.0.0.1"
@@ -210,7 +209,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self) -> None:
210209 )
211210 app_service .is_interested_in_user = Mock (return_value = False )
212211 self .store .get_app_service_by_token = Mock (return_value = app_service )
213- self .store .get_user_by_access_token = simple_async_mock ( None )
212+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
214213
215214 request = Mock (args = {})
216215 request .getClientAddress .return_value .host = "127.0.0.1"
@@ -234,10 +233,10 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self) -> None:
234233 app_service .is_interested_in_user = Mock (return_value = True )
235234 self .store .get_app_service_by_token = Mock (return_value = app_service )
236235 # This just needs to return a truth-y value.
237- self .store .get_user_by_id = simple_async_mock ( {"is_guest" : False })
238- self .store .get_user_by_access_token = simple_async_mock ( None )
236+ self .store .get_user_by_id = AsyncMock ( return_value = {"is_guest" : False })
237+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
239238 # This also needs to just return a truth-y value
240- self .store .get_device = simple_async_mock ( {"hidden" : False })
239+ self .store .get_device = AsyncMock ( return_value = {"hidden" : False })
241240
242241 request = Mock (args = {})
243242 request .getClientAddress .return_value .host = "127.0.0.1"
@@ -266,10 +265,10 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None:
266265 app_service .is_interested_in_user = Mock (return_value = True )
267266 self .store .get_app_service_by_token = Mock (return_value = app_service )
268267 # This just needs to return a truth-y value.
269- self .store .get_user_by_id = simple_async_mock ( {"is_guest" : False })
270- self .store .get_user_by_access_token = simple_async_mock ( None )
268+ self .store .get_user_by_id = AsyncMock ( return_value = {"is_guest" : False })
269+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
271270 # This also needs to just return a falsey value
272- self .store .get_device = simple_async_mock ( None )
271+ self .store .get_device = AsyncMock ( return_value = None )
273272
274273 request = Mock (args = {})
275274 request .getClientAddress .return_value .host = "127.0.0.1"
@@ -283,18 +282,18 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self) -> None:
283282 self .assertEqual (failure .value .errcode , Codes .EXCLUSIVE )
284283
285284 def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau (self ) -> None :
286- self .store .get_user_by_access_token = simple_async_mock (
287- TokenLookupResult (
285+ self .store .get_user_by_access_token = AsyncMock (
286+ return_value = TokenLookupResult (
288287 user_id = "@baldrick:matrix.org" ,
289288 device_id = "device" ,
290289 token_id = 5 ,
291290 token_owner = "@admin:matrix.org" ,
292291 token_used = True ,
293292 )
294293 )
295- self .store .insert_client_ip = simple_async_mock ( None )
296- self .store .mark_access_token_as_used = simple_async_mock ( None )
297- self .store .get_user_locked_status = simple_async_mock ( False )
294+ self .store .insert_client_ip = AsyncMock ( return_value = None )
295+ self .store .mark_access_token_as_used = AsyncMock ( return_value = None )
296+ self .store .get_user_locked_status = AsyncMock ( return_value = False )
298297 request = Mock (args = {})
299298 request .getClientAddress .return_value .host = "127.0.0.1"
300299 request .args [b"access_token" ] = [self .test_token ]
@@ -304,18 +303,18 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> Non
304303
305304 def test_get_user_by_req__puppeted_token__tracking_puppeted_mau (self ) -> None :
306305 self .auth ._track_puppeted_user_ips = True
307- self .store .get_user_by_access_token = simple_async_mock (
308- TokenLookupResult (
306+ self .store .get_user_by_access_token = AsyncMock (
307+ return_value = TokenLookupResult (
309308 user_id = "@baldrick:matrix.org" ,
310309 device_id = "device" ,
311310 token_id = 5 ,
312311 token_owner = "@admin:matrix.org" ,
313312 token_used = True ,
314313 )
315314 )
316- self .store .get_user_locked_status = simple_async_mock ( False )
317- self .store .insert_client_ip = simple_async_mock ( None )
318- self .store .mark_access_token_as_used = simple_async_mock ( None )
315+ self .store .get_user_locked_status = AsyncMock ( return_value = False )
316+ self .store .insert_client_ip = AsyncMock ( return_value = None )
317+ self .store .mark_access_token_as_used = AsyncMock ( return_value = None )
319318 request = Mock (args = {})
320319 request .getClientAddress .return_value .host = "127.0.0.1"
321320 request .args [b"access_token" ] = [self .test_token ]
@@ -324,7 +323,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
324323 self .assertEqual (self .store .insert_client_ip .call_count , 2 )
325324
326325 def test_get_user_from_macaroon (self ) -> None :
327- self .store .get_user_by_access_token = simple_async_mock ( None )
326+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
328327
329328 user_id = "@baldrick:matrix.org"
330329 macaroon = pymacaroons .Macaroon (
@@ -342,8 +341,8 @@ def test_get_user_from_macaroon(self) -> None:
342341 )
343342
344343 def test_get_guest_user_from_macaroon (self ) -> None :
345- self .store .get_user_by_id = simple_async_mock ( {"is_guest" : True })
346- self .store .get_user_by_access_token = simple_async_mock ( None )
344+ self .store .get_user_by_id = AsyncMock ( return_value = {"is_guest" : True })
345+ self .store .get_user_by_access_token = AsyncMock ( return_value = None )
347346
348347 user_id = "@baldrick:matrix.org"
349348 macaroon = pymacaroons .Macaroon (
@@ -373,7 +372,7 @@ def test_blocking_mau(self) -> None:
373372
374373 self .auth_blocking ._limit_usage_by_mau = True
375374
376- self .store .get_monthly_active_count = simple_async_mock ( lots_of_users )
375+ self .store .get_monthly_active_count = AsyncMock ( return_value = lots_of_users )
377376
378377 e = self .get_failure (
379378 self .auth_blocking .check_auth_blocking (), ResourceLimitError
@@ -383,25 +382,27 @@ def test_blocking_mau(self) -> None:
383382 self .assertEqual (e .value .code , 403 )
384383
385384 # Ensure does not throw an error
386- self .store .get_monthly_active_count = simple_async_mock (small_number_of_users )
385+ self .store .get_monthly_active_count = AsyncMock (
386+ return_value = small_number_of_users
387+ )
387388 self .get_success (self .auth_blocking .check_auth_blocking ())
388389
389390 def test_blocking_mau__depending_on_user_type (self ) -> None :
390391 self .auth_blocking ._max_mau_value = 50
391392 self .auth_blocking ._limit_usage_by_mau = True
392393
393- self .store .get_monthly_active_count = simple_async_mock ( 100 )
394+ self .store .get_monthly_active_count = AsyncMock ( return_value = 100 )
394395 # Support users allowed
395396 self .get_success (
396397 self .auth_blocking .check_auth_blocking (user_type = UserTypes .SUPPORT )
397398 )
398- self .store .get_monthly_active_count = simple_async_mock ( 100 )
399+ self .store .get_monthly_active_count = AsyncMock ( return_value = 100 )
399400 # Bots not allowed
400401 self .get_failure (
401402 self .auth_blocking .check_auth_blocking (user_type = UserTypes .BOT ),
402403 ResourceLimitError ,
403404 )
404- self .store .get_monthly_active_count = simple_async_mock ( 100 )
405+ self .store .get_monthly_active_count = AsyncMock ( return_value = 100 )
405406 # Real users not allowed
406407 self .get_failure (self .auth_blocking .check_auth_blocking (), ResourceLimitError )
407408
@@ -412,9 +413,9 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(
412413 self .auth_blocking ._limit_usage_by_mau = True
413414 self .auth_blocking ._track_appservice_user_ips = False
414415
415- self .store .get_monthly_active_count = simple_async_mock ( 100 )
416- self .store .user_last_seen_monthly_active = simple_async_mock ( )
417- self .store .is_trial_user = simple_async_mock ( )
416+ self .store .get_monthly_active_count = AsyncMock ( return_value = 100 )
417+ self .store .user_last_seen_monthly_active = AsyncMock ( return_value = None )
418+ self .store .is_trial_user = AsyncMock ( return_value = False )
418419
419420 appservice = ApplicationService (
420421 "abcd" ,
@@ -443,9 +444,9 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(
443444 self .auth_blocking ._limit_usage_by_mau = True
444445 self .auth_blocking ._track_appservice_user_ips = True
445446
446- self .store .get_monthly_active_count = simple_async_mock ( 100 )
447- self .store .user_last_seen_monthly_active = simple_async_mock ( )
448- self .store .is_trial_user = simple_async_mock ( )
447+ self .store .get_monthly_active_count = AsyncMock ( return_value = 100 )
448+ self .store .user_last_seen_monthly_active = AsyncMock ( return_value = None )
449+ self .store .is_trial_user = AsyncMock ( return_value = False )
449450
450451 appservice = ApplicationService (
451452 "abcd" ,
@@ -473,7 +474,7 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(
473474 def test_reserved_threepid (self ) -> None :
474475 self .auth_blocking ._limit_usage_by_mau = True
475476 self .auth_blocking ._max_mau_value = 1
476- self .store .get_monthly_active_count = simple_async_mock ( 2 )
477+ self .store .get_monthly_active_count = AsyncMock ( return_value = 2 )
477478 threepid = {
"medium" :
"email" ,
"address" :
"[email protected] " }
478479 unknown_threepid = {
"medium" :
"email" ,
"address" :
"[email protected] " }
479480 self .auth_blocking ._mau_limits_reserved_threepids = [threepid ]
0 commit comments