Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 8a4c296

Browse files
authored
Clean up tests.test_visibility to remove legacy code. (#11495)
1 parent 49e1356 commit 8a4c296

File tree

3 files changed

+40
-203
lines changed

3 files changed

+40
-203
lines changed

changelog.d/11495.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Clean up `tests.test_visibility` to remove legacy code.

mypy.ini

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ exclude = (?x)
123123
|tests/test_server.py
124124
|tests/test_state.py
125125
|tests/test_terms_auth.py
126-
|tests/test_visibility.py
127126
|tests/unittest.py
128127
|tests/util/caches/test_cached_call.py
129128
|tests/util/caches/test_deferred_cache.py

tests/test_visibility.py

Lines changed: 39 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -13,35 +13,30 @@
1313
# limitations under the License.
1414
import logging
1515
from typing import Optional
16-
from unittest.mock import Mock
17-
18-
from twisted.internet import defer
19-
from twisted.internet.defer import succeed
2016

2117
from synapse.api.room_versions import RoomVersions
22-
from synapse.events import FrozenEvent
18+
from synapse.events import EventBase
19+
from synapse.types import JsonDict
2320
from synapse.visibility import filter_events_for_server
2421

25-
import tests.unittest
26-
from tests.utils import create_room, setup_test_homeserver
22+
from tests import unittest
23+
from tests.utils import create_room
2724

2825
logger = logging.getLogger(__name__)
2926

3027
TEST_ROOM_ID = "!TEST:ROOM"
3128

3229

33-
class FilterEventsForServerTestCase(tests.unittest.TestCase):
34-
@defer.inlineCallbacks
35-
def setUp(self):
36-
self.hs = yield setup_test_homeserver(self.addCleanup)
30+
class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
31+
def setUp(self) -> None:
32+
super(FilterEventsForServerTestCase, self).setUp()
3733
self.event_creation_handler = self.hs.get_event_creation_handler()
3834
self.event_builder_factory = self.hs.get_event_builder_factory()
3935
self.storage = self.hs.get_storage()
4036

41-
yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
37+
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
4238

43-
@defer.inlineCallbacks
44-
def test_filtering(self):
39+
def test_filtering(self) -> None:
4540
#
4641
# The events to be filtered consist of 10 membership events (it doesn't
4742
# really matter if they are joins or leaves, so let's make them joins).
@@ -51,18 +46,20 @@ def test_filtering(self):
5146
#
5247

5348
# before we do that, we persist some other events to act as state.
54-
yield self.inject_visibility("@admin:hs", "joined")
49+
self.get_success(self._inject_visibility("@admin:hs", "joined"))
5550
for i in range(0, 10):
56-
yield self.inject_room_member("@resident%i:hs" % i)
51+
self.get_success(self._inject_room_member("@resident%i:hs" % i))
5752

5853
events_to_filter = []
5954

6055
for i in range(0, 10):
6156
user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
62-
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
57+
evt = self.get_success(
58+
self._inject_room_member(user, extra_content={"a": "b"})
59+
)
6360
events_to_filter.append(evt)
6461

65-
filtered = yield defer.ensureDeferred(
62+
filtered = self.get_success(
6663
filter_events_for_server(self.storage, "test_server", events_to_filter)
6764
)
6865

@@ -75,34 +72,31 @@ def test_filtering(self):
7572
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
7673
self.assertEqual(filtered[i].content["a"], "b")
7774

78-
@defer.inlineCallbacks
79-
def test_erased_user(self):
75+
def test_erased_user(self) -> None:
8076
# 4 message events, from erased and unerased users, with a membership
8177
# change in the middle of them.
8278
events_to_filter = []
8379

84-
evt = yield self.inject_message("@unerased:local_hs")
80+
evt = self.get_success(self._inject_message("@unerased:local_hs"))
8581
events_to_filter.append(evt)
8682

87-
evt = yield self.inject_message("@erased:local_hs")
83+
evt = self.get_success(self._inject_message("@erased:local_hs"))
8884
events_to_filter.append(evt)
8985

90-
evt = yield self.inject_room_member("@joiner:remote_hs")
86+
evt = self.get_success(self._inject_room_member("@joiner:remote_hs"))
9187
events_to_filter.append(evt)
9288

93-
evt = yield self.inject_message("@unerased:local_hs")
89+
evt = self.get_success(self._inject_message("@unerased:local_hs"))
9490
events_to_filter.append(evt)
9591

96-
evt = yield self.inject_message("@erased:local_hs")
92+
evt = self.get_success(self._inject_message("@erased:local_hs"))
9793
events_to_filter.append(evt)
9894

9995
# the erasey user gets erased
100-
yield defer.ensureDeferred(
101-
self.hs.get_datastore().mark_user_erased("@erased:local_hs")
102-
)
96+
self.get_success(self.hs.get_datastore().mark_user_erased("@erased:local_hs"))
10397

10498
# ... and the filtering happens.
105-
filtered = yield defer.ensureDeferred(
99+
filtered = self.get_success(
106100
filter_events_for_server(self.storage, "test_server", events_to_filter)
107101
)
108102

@@ -123,8 +117,7 @@ def test_erased_user(self):
123117
for i in (1, 4):
124118
self.assertNotIn("body", filtered[i].content)
125119

126-
@defer.inlineCallbacks
127-
def inject_visibility(self, user_id, visibility):
120+
def _inject_visibility(self, user_id: str, visibility: str) -> EventBase:
128121
content = {"history_visibility": visibility}
129122
builder = self.event_builder_factory.for_room_version(
130123
RoomVersions.V1,
@@ -137,18 +130,18 @@ def inject_visibility(self, user_id, visibility):
137130
},
138131
)
139132

140-
event, context = yield defer.ensureDeferred(
133+
event, context = self.get_success(
141134
self.event_creation_handler.create_new_client_event(builder)
142135
)
143-
yield defer.ensureDeferred(
144-
self.storage.persistence.persist_event(event, context)
145-
)
136+
self.get_success(self.storage.persistence.persist_event(event, context))
146137
return event
147138

148-
@defer.inlineCallbacks
149-
def inject_room_member(
150-
self, user_id, membership="join", extra_content: Optional[dict] = None
151-
):
139+
def _inject_room_member(
140+
self,
141+
user_id: str,
142+
membership: str = "join",
143+
extra_content: Optional[JsonDict] = None,
144+
) -> EventBase:
152145
content = {"membership": membership}
153146
content.update(extra_content or {})
154147
builder = self.event_builder_factory.for_room_version(
@@ -162,17 +155,16 @@ def inject_room_member(
162155
},
163156
)
164157

165-
event, context = yield defer.ensureDeferred(
158+
event, context = self.get_success(
166159
self.event_creation_handler.create_new_client_event(builder)
167160
)
168161

169-
yield defer.ensureDeferred(
170-
self.storage.persistence.persist_event(event, context)
171-
)
162+
self.get_success(self.storage.persistence.persist_event(event, context))
172163
return event
173164

174-
@defer.inlineCallbacks
175-
def inject_message(self, user_id, content=None):
165+
def _inject_message(
166+
self, user_id: str, content: Optional[JsonDict] = None
167+
) -> EventBase:
176168
if content is None:
177169
content = {"body": "testytest", "msgtype": "m.text"}
178170
builder = self.event_builder_factory.for_room_version(
@@ -185,164 +177,9 @@ def inject_message(self, user_id, content=None):
185177
},
186178
)
187179

188-
event, context = yield defer.ensureDeferred(
180+
event, context = self.get_success(
189181
self.event_creation_handler.create_new_client_event(builder)
190182
)
191183

192-
yield defer.ensureDeferred(
193-
self.storage.persistence.persist_event(event, context)
194-
)
184+
self.get_success(self.storage.persistence.persist_event(event, context))
195185
return event
196-
197-
@defer.inlineCallbacks
198-
def test_large_room(self):
199-
# see what happens when we have a large room with hundreds of thousands
200-
# of membership events
201-
202-
# As above, the events to be filtered consist of 10 membership events,
203-
# where one of them is for a user on the server we are filtering for.
204-
205-
import cProfile
206-
import pstats
207-
import time
208-
209-
# we stub out the store, because building up all that state the normal
210-
# way is very slow.
211-
test_store = _TestStore()
212-
213-
# our initial state is 100000 membership events and one
214-
# history_visibility event.
215-
room_state = []
216-
217-
history_visibility_evt = FrozenEvent(
218-
{
219-
"event_id": "$history_vis",
220-
"type": "m.room.history_visibility",
221-
"sender": "@resident_user_0:test.com",
222-
"state_key": "",
223-
"room_id": TEST_ROOM_ID,
224-
"content": {"history_visibility": "joined"},
225-
}
226-
)
227-
room_state.append(history_visibility_evt)
228-
test_store.add_event(history_visibility_evt)
229-
230-
for i in range(0, 100000):
231-
user = "@resident_user_%i:test.com" % (i,)
232-
evt = FrozenEvent(
233-
{
234-
"event_id": "$res_event_%i" % (i,),
235-
"type": "m.room.member",
236-
"state_key": user,
237-
"sender": user,
238-
"room_id": TEST_ROOM_ID,
239-
"content": {"membership": "join", "extra": "zzz,"},
240-
}
241-
)
242-
room_state.append(evt)
243-
test_store.add_event(evt)
244-
245-
events_to_filter = []
246-
for i in range(0, 10):
247-
user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
248-
evt = FrozenEvent(
249-
{
250-
"event_id": "$evt%i" % (i,),
251-
"type": "m.room.member",
252-
"state_key": user,
253-
"sender": user,
254-
"room_id": TEST_ROOM_ID,
255-
"content": {"membership": "join", "extra": "zzz"},
256-
}
257-
)
258-
events_to_filter.append(evt)
259-
room_state.append(evt)
260-
261-
test_store.add_event(evt)
262-
test_store.set_state_ids_for_event(
263-
evt, {(e.type, e.state_key): e.event_id for e in room_state}
264-
)
265-
266-
pr = cProfile.Profile()
267-
pr.enable()
268-
269-
logger.info("Starting filtering")
270-
start = time.time()
271-
272-
storage = Mock()
273-
storage.main = test_store
274-
storage.state = test_store
275-
276-
filtered = yield defer.ensureDeferred(
277-
filter_events_for_server(test_store, "test_server", events_to_filter)
278-
)
279-
logger.info("Filtering took %f seconds", time.time() - start)
280-
281-
pr.disable()
282-
with open("filter_events_for_server.profile", "w+") as f:
283-
ps = pstats.Stats(pr, stream=f).sort_stats("cumulative")
284-
ps.print_stats()
285-
286-
# the result should be 5 redacted events, and 5 unredacted events.
287-
for i in range(0, 5):
288-
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
289-
self.assertNotIn("extra", filtered[i].content)
290-
291-
for i in range(5, 10):
292-
self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
293-
self.assertEqual(filtered[i].content["extra"], "zzz")
294-
295-
test_large_room.skip = "Disabled by default because it's slow"
296-
297-
298-
class _TestStore:
299-
"""Implements a few methods of the DataStore, so that we can test
300-
filter_events_for_server
301-
302-
"""
303-
304-
def __init__(self):
305-
# data for get_events: a map from event_id to event
306-
self.events = {}
307-
308-
# data for get_state_ids_for_events mock: a map from event_id to
309-
# a map from (type_state_key) -> event_id for the state at that
310-
# event
311-
self.state_ids_for_events = {}
312-
313-
def add_event(self, event):
314-
self.events[event.event_id] = event
315-
316-
def set_state_ids_for_event(self, event, state):
317-
self.state_ids_for_events[event.event_id] = state
318-
319-
def get_state_ids_for_events(self, events, types):
320-
res = {}
321-
include_memberships = False
322-
for (type, state_key) in types:
323-
if type == "m.room.history_visibility":
324-
continue
325-
if type != "m.room.member" or state_key is not None:
326-
raise RuntimeError(
327-
"Unimplemented: get_state_ids with type (%s, %s)"
328-
% (type, state_key)
329-
)
330-
include_memberships = True
331-
332-
if include_memberships:
333-
for event_id in events:
334-
res[event_id] = self.state_ids_for_events[event_id]
335-
336-
else:
337-
k = ("m.room.history_visibility", "")
338-
for event_id in events:
339-
hve = self.state_ids_for_events[event_id][k]
340-
res[event_id] = {k: hve}
341-
342-
return succeed(res)
343-
344-
def get_events(self, events):
345-
return succeed({event_id: self.events[event_id] for event_id in events})
346-
347-
def are_users_erased(self, users):
348-
return succeed({u: False for u in users})

0 commit comments

Comments
 (0)