1313# limitations under the License.
1414import logging
1515from typing import Optional
16- from unittest .mock import Mock
17-
18- from twisted .internet import defer
19- from twisted .internet .defer import succeed
2016
2117from synapse .api .room_versions import RoomVersions
22- from synapse .events import FrozenEvent
18+ from synapse .events import EventBase
19+ from synapse .types import JsonDict
2320from synapse .visibility import filter_events_for_server
2421
25- import tests . unittest
26- from tests .utils import create_room , setup_test_homeserver
22+ from tests import unittest
23+ from tests .utils import create_room
2724
2825logger = logging .getLogger (__name__ )
2926
3027TEST_ROOM_ID = "!TEST:ROOM"
3128
3229
33- class FilterEventsForServerTestCase (tests .unittest .TestCase ):
34- @defer .inlineCallbacks
35- def setUp (self ):
36- self .hs = yield setup_test_homeserver (self .addCleanup )
30+ class FilterEventsForServerTestCase (unittest .HomeserverTestCase ):
31+ def setUp (self ) -> None :
32+ super (FilterEventsForServerTestCase , self ).setUp ()
3733 self .event_creation_handler = self .hs .get_event_creation_handler ()
3834 self .event_builder_factory = self .hs .get_event_builder_factory ()
3935 self .storage = self .hs .get_storage ()
4036
41- yield defer . ensureDeferred (create_room (self .hs , TEST_ROOM_ID , "@someone:ROOM" ))
37+ self . get_success (create_room (self .hs , TEST_ROOM_ID , "@someone:ROOM" ))
4238
43- @defer .inlineCallbacks
44- def test_filtering (self ):
39+ def test_filtering (self ) -> None :
4540 #
4641 # The events to be filtered consist of 10 membership events (it doesn't
4742 # really matter if they are joins or leaves, so let's make them joins).
@@ -51,18 +46,20 @@ def test_filtering(self):
5146 #
5247
5348 # before we do that, we persist some other events to act as state.
54- yield self .inject_visibility ( "@admin:hs" , "joined" )
49+ self .get_success ( self . _inject_visibility ( "@admin:hs" , "joined" ) )
5550 for i in range (0 , 10 ):
56- yield self .inject_room_member ( "@resident%i:hs" % i )
51+ self .get_success ( self . _inject_room_member ( "@resident%i:hs" % i ) )
5752
5853 events_to_filter = []
5954
6055 for i in range (0 , 10 ):
6156 user = "@user%i:%s" % (i , "test_server" if i == 5 else "other_server" )
62- evt = yield self .inject_room_member (user , extra_content = {"a" : "b" })
57+ evt = self .get_success (
58+ self ._inject_room_member (user , extra_content = {"a" : "b" })
59+ )
6360 events_to_filter .append (evt )
6461
65- filtered = yield defer . ensureDeferred (
62+ filtered = self . get_success (
6663 filter_events_for_server (self .storage , "test_server" , events_to_filter )
6764 )
6865
@@ -75,34 +72,31 @@ def test_filtering(self):
7572 self .assertEqual (events_to_filter [i ].event_id , filtered [i ].event_id )
7673 self .assertEqual (filtered [i ].content ["a" ], "b" )
7774
78- @defer .inlineCallbacks
79- def test_erased_user (self ):
75+ def test_erased_user (self ) -> None :
8076 # 4 message events, from erased and unerased users, with a membership
8177 # change in the middle of them.
8278 events_to_filter = []
8379
84- evt = yield self .inject_message ( "@unerased:local_hs" )
80+ evt = self .get_success ( self . _inject_message ( "@unerased:local_hs" ) )
8581 events_to_filter .append (evt )
8682
87- evt = yield self .inject_message ( "@erased:local_hs" )
83+ evt = self .get_success ( self . _inject_message ( "@erased:local_hs" ) )
8884 events_to_filter .append (evt )
8985
90- evt = yield self .inject_room_member ( "@joiner:remote_hs" )
86+ evt = self .get_success ( self . _inject_room_member ( "@joiner:remote_hs" ) )
9187 events_to_filter .append (evt )
9288
93- evt = yield self .inject_message ( "@unerased:local_hs" )
89+ evt = self .get_success ( self . _inject_message ( "@unerased:local_hs" ) )
9490 events_to_filter .append (evt )
9591
96- evt = yield self .inject_message ( "@erased:local_hs" )
92+ evt = self .get_success ( self . _inject_message ( "@erased:local_hs" ) )
9793 events_to_filter .append (evt )
9894
9995 # the erasey user gets erased
100- yield defer .ensureDeferred (
101- self .hs .get_datastore ().mark_user_erased ("@erased:local_hs" )
102- )
96+ self .get_success (self .hs .get_datastore ().mark_user_erased ("@erased:local_hs" ))
10397
10498 # ... and the filtering happens.
105- filtered = yield defer . ensureDeferred (
99+ filtered = self . get_success (
106100 filter_events_for_server (self .storage , "test_server" , events_to_filter )
107101 )
108102
@@ -123,8 +117,7 @@ def test_erased_user(self):
123117 for i in (1 , 4 ):
124118 self .assertNotIn ("body" , filtered [i ].content )
125119
126- @defer .inlineCallbacks
127- def inject_visibility (self , user_id , visibility ):
120+ def _inject_visibility (self , user_id : str , visibility : str ) -> EventBase :
128121 content = {"history_visibility" : visibility }
129122 builder = self .event_builder_factory .for_room_version (
130123 RoomVersions .V1 ,
@@ -137,18 +130,18 @@ def inject_visibility(self, user_id, visibility):
137130 },
138131 )
139132
140- event , context = yield defer . ensureDeferred (
133+ event , context = self . get_success (
141134 self .event_creation_handler .create_new_client_event (builder )
142135 )
143- yield defer .ensureDeferred (
144- self .storage .persistence .persist_event (event , context )
145- )
136+ self .get_success (self .storage .persistence .persist_event (event , context ))
146137 return event
147138
148- @defer .inlineCallbacks
149- def inject_room_member (
150- self , user_id , membership = "join" , extra_content : Optional [dict ] = None
151- ):
139+ def _inject_room_member (
140+ self ,
141+ user_id : str ,
142+ membership : str = "join" ,
143+ extra_content : Optional [JsonDict ] = None ,
144+ ) -> EventBase :
152145 content = {"membership" : membership }
153146 content .update (extra_content or {})
154147 builder = self .event_builder_factory .for_room_version (
@@ -162,17 +155,16 @@ def inject_room_member(
162155 },
163156 )
164157
165- event , context = yield defer . ensureDeferred (
158+ event , context = self . get_success (
166159 self .event_creation_handler .create_new_client_event (builder )
167160 )
168161
169- yield defer .ensureDeferred (
170- self .storage .persistence .persist_event (event , context )
171- )
162+ self .get_success (self .storage .persistence .persist_event (event , context ))
172163 return event
173164
174- @defer .inlineCallbacks
175- def inject_message (self , user_id , content = None ):
165+ def _inject_message (
166+ self , user_id : str , content : Optional [JsonDict ] = None
167+ ) -> EventBase :
176168 if content is None :
177169 content = {"body" : "testytest" , "msgtype" : "m.text" }
178170 builder = self .event_builder_factory .for_room_version (
@@ -185,164 +177,9 @@ def inject_message(self, user_id, content=None):
185177 },
186178 )
187179
188- event , context = yield defer . ensureDeferred (
180+ event , context = self . get_success (
189181 self .event_creation_handler .create_new_client_event (builder )
190182 )
191183
192- yield defer .ensureDeferred (
193- self .storage .persistence .persist_event (event , context )
194- )
184+ self .get_success (self .storage .persistence .persist_event (event , context ))
195185 return event
196-
197- @defer .inlineCallbacks
198- def test_large_room (self ):
199- # see what happens when we have a large room with hundreds of thousands
200- # of membership events
201-
202- # As above, the events to be filtered consist of 10 membership events,
203- # where one of them is for a user on the server we are filtering for.
204-
205- import cProfile
206- import pstats
207- import time
208-
209- # we stub out the store, because building up all that state the normal
210- # way is very slow.
211- test_store = _TestStore ()
212-
213- # our initial state is 100000 membership events and one
214- # history_visibility event.
215- room_state = []
216-
217- history_visibility_evt = FrozenEvent (
218- {
219- "event_id" : "$history_vis" ,
220- "type" : "m.room.history_visibility" ,
221- "sender" : "@resident_user_0:test.com" ,
222- "state_key" : "" ,
223- "room_id" : TEST_ROOM_ID ,
224- "content" : {"history_visibility" : "joined" },
225- }
226- )
227- room_state .append (history_visibility_evt )
228- test_store .add_event (history_visibility_evt )
229-
230- for i in range (0 , 100000 ):
231- user = "@resident_user_%i:test.com" % (i ,)
232- evt = FrozenEvent (
233- {
234- "event_id" : "$res_event_%i" % (i ,),
235- "type" : "m.room.member" ,
236- "state_key" : user ,
237- "sender" : user ,
238- "room_id" : TEST_ROOM_ID ,
239- "content" : {"membership" : "join" , "extra" : "zzz," },
240- }
241- )
242- room_state .append (evt )
243- test_store .add_event (evt )
244-
245- events_to_filter = []
246- for i in range (0 , 10 ):
247- user = "@user%i:%s" % (i , "test_server" if i == 5 else "other_server" )
248- evt = FrozenEvent (
249- {
250- "event_id" : "$evt%i" % (i ,),
251- "type" : "m.room.member" ,
252- "state_key" : user ,
253- "sender" : user ,
254- "room_id" : TEST_ROOM_ID ,
255- "content" : {"membership" : "join" , "extra" : "zzz" },
256- }
257- )
258- events_to_filter .append (evt )
259- room_state .append (evt )
260-
261- test_store .add_event (evt )
262- test_store .set_state_ids_for_event (
263- evt , {(e .type , e .state_key ): e .event_id for e in room_state }
264- )
265-
266- pr = cProfile .Profile ()
267- pr .enable ()
268-
269- logger .info ("Starting filtering" )
270- start = time .time ()
271-
272- storage = Mock ()
273- storage .main = test_store
274- storage .state = test_store
275-
276- filtered = yield defer .ensureDeferred (
277- filter_events_for_server (test_store , "test_server" , events_to_filter )
278- )
279- logger .info ("Filtering took %f seconds" , time .time () - start )
280-
281- pr .disable ()
282- with open ("filter_events_for_server.profile" , "w+" ) as f :
283- ps = pstats .Stats (pr , stream = f ).sort_stats ("cumulative" )
284- ps .print_stats ()
285-
286- # the result should be 5 redacted events, and 5 unredacted events.
287- for i in range (0 , 5 ):
288- self .assertEqual (events_to_filter [i ].event_id , filtered [i ].event_id )
289- self .assertNotIn ("extra" , filtered [i ].content )
290-
291- for i in range (5 , 10 ):
292- self .assertEqual (events_to_filter [i ].event_id , filtered [i ].event_id )
293- self .assertEqual (filtered [i ].content ["extra" ], "zzz" )
294-
295- test_large_room .skip = "Disabled by default because it's slow"
296-
297-
298- class _TestStore :
299- """Implements a few methods of the DataStore, so that we can test
300- filter_events_for_server
301-
302- """
303-
304- def __init__ (self ):
305- # data for get_events: a map from event_id to event
306- self .events = {}
307-
308- # data for get_state_ids_for_events mock: a map from event_id to
309- # a map from (type_state_key) -> event_id for the state at that
310- # event
311- self .state_ids_for_events = {}
312-
313- def add_event (self , event ):
314- self .events [event .event_id ] = event
315-
316- def set_state_ids_for_event (self , event , state ):
317- self .state_ids_for_events [event .event_id ] = state
318-
319- def get_state_ids_for_events (self , events , types ):
320- res = {}
321- include_memberships = False
322- for (type , state_key ) in types :
323- if type == "m.room.history_visibility" :
324- continue
325- if type != "m.room.member" or state_key is not None :
326- raise RuntimeError (
327- "Unimplemented: get_state_ids with type (%s, %s)"
328- % (type , state_key )
329- )
330- include_memberships = True
331-
332- if include_memberships :
333- for event_id in events :
334- res [event_id ] = self .state_ids_for_events [event_id ]
335-
336- else :
337- k = ("m.room.history_visibility" , "" )
338- for event_id in events :
339- hve = self .state_ids_for_events [event_id ][k ]
340- res [event_id ] = {k : hve }
341-
342- return succeed (res )
343-
344- def get_events (self , events ):
345- return succeed ({event_id : self .events [event_id ] for event_id in events })
346-
347- def are_users_erased (self , users ):
348- return succeed ({u : False for u in users })
0 commit comments