|  | 
|  | 1 | +import logging | 
|  | 2 | +from typing import List, Tuple | 
|  | 3 | +from unittest.mock import Mock, patch | 
|  | 4 | + | 
|  | 5 | +from twisted.test.proto_helpers import MemoryReactor | 
|  | 6 | + | 
|  | 7 | +from synapse.api.constants import EventContentFields, EventTypes | 
|  | 8 | +from synapse.appservice import ApplicationService | 
|  | 9 | +from synapse.rest import admin | 
|  | 10 | +from synapse.rest.client import login, register, room, room_batch | 
|  | 11 | +from synapse.server import HomeServer | 
|  | 12 | +from synapse.types import JsonDict | 
|  | 13 | +from synapse.util import Clock | 
|  | 14 | + | 
|  | 15 | +from tests import unittest | 
|  | 16 | + | 
|  | 17 | +logger = logging.getLogger(__name__) | 
|  | 18 | + | 
|  | 19 | + | 
|  | 20 | +def _create_join_state_events_for_batch_send_request( | 
|  | 21 | +    virtual_user_ids: List[str], | 
|  | 22 | +    insert_time: int, | 
|  | 23 | +) -> List[JsonDict]: | 
|  | 24 | +    return [ | 
|  | 25 | +        { | 
|  | 26 | +            "type": EventTypes.Member, | 
|  | 27 | +            "sender": virtual_user_id, | 
|  | 28 | +            "origin_server_ts": insert_time, | 
|  | 29 | +            "content": { | 
|  | 30 | +                "membership": "join", | 
|  | 31 | +                "displayname": "display-name-for-%s" % (virtual_user_id,), | 
|  | 32 | +            }, | 
|  | 33 | +            "state_key": virtual_user_id, | 
|  | 34 | +        } | 
|  | 35 | +        for virtual_user_id in virtual_user_ids | 
|  | 36 | +    ] | 
|  | 37 | + | 
|  | 38 | + | 
|  | 39 | +def _create_message_events_for_batch_send_request( | 
|  | 40 | +    virtual_user_id: str, insert_time: int, count: int | 
|  | 41 | +) -> List[JsonDict]: | 
|  | 42 | +    return [ | 
|  | 43 | +        { | 
|  | 44 | +            "type": EventTypes.Message, | 
|  | 45 | +            "sender": virtual_user_id, | 
|  | 46 | +            "origin_server_ts": insert_time, | 
|  | 47 | +            "content": { | 
|  | 48 | +                "msgtype": "m.text", | 
|  | 49 | +                "body": "Historical %d" % (i), | 
|  | 50 | +                EventContentFields.MSC2716_HISTORICAL: True, | 
|  | 51 | +            }, | 
|  | 52 | +        } | 
|  | 53 | +        for i in range(count) | 
|  | 54 | +    ] | 
|  | 55 | + | 
|  | 56 | + | 
|  | 57 | +class RoomBatchTestCase(unittest.HomeserverTestCase): | 
|  | 58 | +    """Test importing batches of historical messages.""" | 
|  | 59 | + | 
|  | 60 | +    servlets = [ | 
|  | 61 | +        admin.register_servlets_for_client_rest_resource, | 
|  | 62 | +        room_batch.register_servlets, | 
|  | 63 | +        room.register_servlets, | 
|  | 64 | +        register.register_servlets, | 
|  | 65 | +        login.register_servlets, | 
|  | 66 | +    ] | 
|  | 67 | + | 
|  | 68 | +    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: | 
|  | 69 | +        config = self.default_config() | 
|  | 70 | + | 
|  | 71 | +        self.appservice = ApplicationService( | 
|  | 72 | +            token="i_am_an_app_service", | 
|  | 73 | +            hostname="test", | 
|  | 74 | +            id="1234", | 
|  | 75 | +            namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, | 
|  | 76 | +            # Note: this user does not have to match the regex above | 
|  | 77 | +            sender="@as_main:test", | 
|  | 78 | +        ) | 
|  | 79 | + | 
|  | 80 | +        mock_load_appservices = Mock(return_value=[self.appservice]) | 
|  | 81 | +        with patch( | 
|  | 82 | +            "synapse.storage.databases.main.appservice.load_appservices", | 
|  | 83 | +            mock_load_appservices, | 
|  | 84 | +        ): | 
|  | 85 | +            hs = self.setup_test_homeserver(config=config) | 
|  | 86 | +        return hs | 
|  | 87 | + | 
|  | 88 | +    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | 
|  | 89 | +        self.clock = clock | 
|  | 90 | +        self.storage = hs.get_storage() | 
|  | 91 | + | 
|  | 92 | +        self.virtual_user_id = self.register_appservice_user( | 
|  | 93 | +            "as_user_potato", self.appservice.token | 
|  | 94 | +        ) | 
|  | 95 | + | 
|  | 96 | +    def _create_test_room(self) -> Tuple[str, str, str, str]: | 
|  | 97 | +        room_id = self.helper.create_room_as( | 
|  | 98 | +            self.appservice.sender, tok=self.appservice.token | 
|  | 99 | +        ) | 
|  | 100 | + | 
|  | 101 | +        res_a = self.helper.send_event( | 
|  | 102 | +            room_id=room_id, | 
|  | 103 | +            type=EventTypes.Message, | 
|  | 104 | +            content={ | 
|  | 105 | +                "msgtype": "m.text", | 
|  | 106 | +                "body": "A", | 
|  | 107 | +            }, | 
|  | 108 | +            tok=self.appservice.token, | 
|  | 109 | +        ) | 
|  | 110 | +        event_id_a = res_a["event_id"] | 
|  | 111 | + | 
|  | 112 | +        res_b = self.helper.send_event( | 
|  | 113 | +            room_id=room_id, | 
|  | 114 | +            type=EventTypes.Message, | 
|  | 115 | +            content={ | 
|  | 116 | +                "msgtype": "m.text", | 
|  | 117 | +                "body": "B", | 
|  | 118 | +            }, | 
|  | 119 | +            tok=self.appservice.token, | 
|  | 120 | +        ) | 
|  | 121 | +        event_id_b = res_b["event_id"] | 
|  | 122 | + | 
|  | 123 | +        res_c = self.helper.send_event( | 
|  | 124 | +            room_id=room_id, | 
|  | 125 | +            type=EventTypes.Message, | 
|  | 126 | +            content={ | 
|  | 127 | +                "msgtype": "m.text", | 
|  | 128 | +                "body": "C", | 
|  | 129 | +            }, | 
|  | 130 | +            tok=self.appservice.token, | 
|  | 131 | +        ) | 
|  | 132 | +        event_id_c = res_c["event_id"] | 
|  | 133 | + | 
|  | 134 | +        return room_id, event_id_a, event_id_b, event_id_c | 
|  | 135 | + | 
|  | 136 | +    @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) | 
|  | 137 | +    def test_same_state_groups_for_whole_historical_batch(self): | 
|  | 138 | +        """Make sure that when using the `/batch_send` endpoint to import a | 
|  | 139 | +        bunch of historical messages, it re-uses the same `state_group` across | 
|  | 140 | +        the whole batch. This is an easy optimization to make sure we're getting | 
|  | 141 | +        right because the state for the whole batch is contained in | 
|  | 142 | +        `state_events_at_start` and can be shared across everything. | 
|  | 143 | +        """ | 
|  | 144 | + | 
|  | 145 | +        time_before_room = int(self.clock.time_msec()) | 
|  | 146 | +        room_id, event_id_a, _, _ = self._create_test_room() | 
|  | 147 | + | 
|  | 148 | +        channel = self.make_request( | 
|  | 149 | +            "POST", | 
|  | 150 | +            "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s" | 
|  | 151 | +            % (room_id, event_id_a), | 
|  | 152 | +            content={ | 
|  | 153 | +                "events": _create_message_events_for_batch_send_request( | 
|  | 154 | +                    self.virtual_user_id, time_before_room, 3 | 
|  | 155 | +                ), | 
|  | 156 | +                "state_events_at_start": _create_join_state_events_for_batch_send_request( | 
|  | 157 | +                    [self.virtual_user_id], time_before_room | 
|  | 158 | +                ), | 
|  | 159 | +            }, | 
|  | 160 | +            access_token=self.appservice.token, | 
|  | 161 | +        ) | 
|  | 162 | +        self.assertEqual(channel.code, 200, channel.result) | 
|  | 163 | + | 
|  | 164 | +        # Get the historical event IDs that we just imported | 
|  | 165 | +        historical_event_ids = channel.json_body["event_ids"] | 
|  | 166 | +        self.assertEqual(len(historical_event_ids), 3) | 
|  | 167 | + | 
|  | 168 | +        # Fetch the state_groups | 
|  | 169 | +        state_group_map = self.get_success( | 
|  | 170 | +            self.storage.state.get_state_groups_ids(room_id, historical_event_ids) | 
|  | 171 | +        ) | 
|  | 172 | + | 
|  | 173 | +        # We expect all of the historical events to be using the same state_group | 
|  | 174 | +        # so there should only be a single state_group here! | 
|  | 175 | +        self.assertEqual( | 
|  | 176 | +            len(state_group_map.keys()), | 
|  | 177 | +            1, | 
|  | 178 | +            "Expected a single state_group to be returned by saw state_groups=%s" | 
|  | 179 | +            % (state_group_map.keys(),), | 
|  | 180 | +        ) | 
0 commit comments