|
18 | 18 | from twisted.internet.defer import Deferred, ensureDeferred |
19 | 19 | from twisted.test.proto_helpers import MemoryReactor |
20 | 20 |
|
| 21 | +from synapse.api.constants import EventTypes |
21 | 22 | from synapse.storage.state import StateFilter |
22 | | -from synapse.types import MutableStateMap, StateMap |
| 23 | +from synapse.types import StateMap |
23 | 24 | from synapse.util import Clock |
24 | 25 |
|
25 | 26 | from tests.unittest import HomeserverTestCase |
26 | 27 |
|
27 | 28 | if typing.TYPE_CHECKING: |
28 | 29 | from synapse.server import HomeServer |
29 | 30 |
|
| 31 | +# StateFilter for ALL non-m.room.member state events |
| 32 | +ALL_NON_MEMBERS_STATE_FILTER = StateFilter.freeze( |
| 33 | + types={EventTypes.Member: set()}, |
| 34 | + include_others=True, |
| 35 | +) |
| 36 | + |
| 37 | +FAKE_STATE = { |
| 38 | + (EventTypes.Member, "@alice:test"): "join", |
| 39 | + (EventTypes.Member, "@bob:test"): "leave", |
| 40 | + (EventTypes.Member, "@charlie:test"): "invite", |
| 41 | + ("test.type", "a"): "AAA", |
| 42 | + ("test.type", "b"): "BBB", |
| 43 | + ("other.event.type", "state.key"): "123", |
| 44 | +} |
| 45 | + |
30 | 46 |
|
31 | 47 | class StateGroupInflightCachingTestCase(HomeserverTestCase): |
32 | 48 | def prepare( |
@@ -65,24 +81,8 @@ def _complete_request_fake( |
65 | 81 | Assemble a fake database response and complete the database request. |
66 | 82 | """ |
67 | 83 |
|
68 | | - result: Dict[int, StateMap[str]] = {} |
69 | | - |
70 | | - for group in groups: |
71 | | - group_result: MutableStateMap[str] = {} |
72 | | - result[group] = group_result |
73 | | - |
74 | | - for state_type, state_keys in state_filter.types.items(): |
75 | | - if state_keys is None: |
76 | | - group_result[(state_type, "a")] = "xyz" |
77 | | - group_result[(state_type, "b")] = "xyz" |
78 | | - else: |
79 | | - for state_key in state_keys: |
80 | | - group_result[(state_type, state_key)] = "abc" |
81 | | - |
82 | | - if state_filter.include_others: |
83 | | - group_result[("other.event.type", "state.key")] = "123" |
84 | | - |
85 | | - d.callback(result) |
| 84 | + # Return a filtered copy of the fake state |
| 85 | + d.callback({group: state_filter.filter_state(FAKE_STATE) for group in groups}) |
86 | 86 |
|
87 | 87 | def test_duplicate_requests_deduplicated(self) -> None: |
88 | 88 | """ |
@@ -125,9 +125,159 @@ def test_duplicate_requests_deduplicated(self) -> None: |
125 | 125 | # Now we can complete the request |
126 | 126 | self._complete_request_fake(groups, sf, d) |
127 | 127 |
|
| 128 | + self.assertEqual(self.get_success(req1), FAKE_STATE) |
| 129 | + self.assertEqual(self.get_success(req2), FAKE_STATE) |
| 130 | + |
| 131 | + def test_smaller_request_deduplicated(self) -> None: |
| 132 | + """ |
| 133 | + Tests that duplicate requests for state are deduplicated. |
| 134 | +
|
| 135 | + This test: |
| 136 | + - requests some state (state group 42, 'all' state filter) |
| 137 | + - requests a subset of that state, before the first request finishes |
| 138 | + - checks to see that only one database query was made |
| 139 | + - completes the database query |
| 140 | + - checks that both requests see the correct retrieved state |
| 141 | + """ |
| 142 | + req1 = ensureDeferred( |
| 143 | + self.state_datastore._get_state_for_group_using_inflight_cache( |
| 144 | + 42, StateFilter.from_types((("test.type", None),)) |
| 145 | + ) |
| 146 | + ) |
| 147 | + self.pump(by=0.1) |
| 148 | + |
| 149 | + # This should have gone to the database |
| 150 | + self.assertEqual(len(self.get_state_group_calls), 1) |
| 151 | + self.assertFalse(req1.called) |
| 152 | + |
| 153 | + req2 = ensureDeferred( |
| 154 | + self.state_datastore._get_state_for_group_using_inflight_cache( |
| 155 | + 42, StateFilter.from_types((("test.type", "b"),)) |
| 156 | + ) |
| 157 | + ) |
| 158 | + self.pump(by=0.1) |
| 159 | + |
| 160 | + # No more calls should have gone to the database, because the second |
| 161 | + # request was already in the in-flight cache! |
| 162 | + self.assertEqual(len(self.get_state_group_calls), 1) |
| 163 | + self.assertFalse(req1.called) |
| 164 | + self.assertFalse(req2.called) |
| 165 | + |
| 166 | + groups, sf, d = self.get_state_group_calls[0] |
| 167 | + self.assertEqual(groups, (42,)) |
| 168 | + # The state filter is expanded internally for increased cache hit rate, |
| 169 | + # so we the database sees a wider state filter than requested. |
| 170 | + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) |
| 171 | + |
| 172 | + # Now we can complete the request |
| 173 | + self._complete_request_fake(groups, sf, d) |
| 174 | + |
| 175 | + self.assertEqual( |
| 176 | + self.get_success(req1), |
| 177 | + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, |
| 178 | + ) |
| 179 | + self.assertEqual(self.get_success(req2), {("test.type", "b"): "BBB"}) |
| 180 | + |
| 181 | + def test_partially_overlapping_request_deduplicated(self) -> None: |
| 182 | + """ |
| 183 | + Tests that partially-overlapping requests are partially deduplicated. |
| 184 | +
|
| 185 | + This test: |
| 186 | + - requests a single type of wildcard state |
| 187 | + (This is internally expanded to be all non-member state) |
| 188 | + - requests the entire state in parallel |
| 189 | + - checks to see that two database queries were made, but that the second |
| 190 | + one is only for member state. |
| 191 | + - completes the database queries |
| 192 | + - checks that both requests have the correct result. |
| 193 | + """ |
| 194 | + |
| 195 | + req1 = ensureDeferred( |
| 196 | + self.state_datastore._get_state_for_group_using_inflight_cache( |
| 197 | + 42, StateFilter.from_types((("test.type", None),)) |
| 198 | + ) |
| 199 | + ) |
| 200 | + self.pump(by=0.1) |
| 201 | + |
| 202 | + # This should have gone to the database |
| 203 | + self.assertEqual(len(self.get_state_group_calls), 1) |
| 204 | + self.assertFalse(req1.called) |
| 205 | + |
| 206 | + req2 = ensureDeferred( |
| 207 | + self.state_datastore._get_state_for_group_using_inflight_cache( |
| 208 | + 42, StateFilter.all() |
| 209 | + ) |
| 210 | + ) |
| 211 | + self.pump(by=0.1) |
| 212 | + |
| 213 | + # Because it only partially overlaps, this also went to the database |
| 214 | + self.assertEqual(len(self.get_state_group_calls), 2) |
| 215 | + self.assertFalse(req1.called) |
| 216 | + self.assertFalse(req2.called) |
| 217 | + |
| 218 | + # First request: |
| 219 | + groups, sf, d = self.get_state_group_calls[0] |
| 220 | + self.assertEqual(groups, (42,)) |
| 221 | + # The state filter is expanded internally for increased cache hit rate, |
| 222 | + # so we the database sees a wider state filter than requested. |
| 223 | + self.assertEqual(sf, ALL_NON_MEMBERS_STATE_FILTER) |
| 224 | + self._complete_request_fake(groups, sf, d) |
| 225 | + |
| 226 | + # Second request: |
| 227 | + groups, sf, d = self.get_state_group_calls[1] |
| 228 | + self.assertEqual(groups, (42,)) |
| 229 | + # The state filter is narrowed to only request membership state, because |
| 230 | + # the remainder of the state is already being queried in the first request! |
128 | 231 | self.assertEqual( |
129 | | - self.get_success(req1), {("other.event.type", "state.key"): "123"} |
| 232 | + sf, StateFilter.freeze({EventTypes.Member: None}, include_others=False) |
130 | 233 | ) |
| 234 | + self._complete_request_fake(groups, sf, d) |
| 235 | + |
| 236 | + # Check the results are correct |
131 | 237 | self.assertEqual( |
132 | | - self.get_success(req2), {("other.event.type", "state.key"): "123"} |
| 238 | + self.get_success(req1), |
| 239 | + {("test.type", "a"): "AAA", ("test.type", "b"): "BBB"}, |
133 | 240 | ) |
| 241 | + self.assertEqual(self.get_success(req2), FAKE_STATE) |
| 242 | + |
| 243 | + def test_in_flight_requests_stop_being_in_flight(self) -> None: |
| 244 | + """ |
| 245 | + Tests that in-flight request deduplication doesn't somehow 'hold on' |
| 246 | + to completed requests: once they're done, they're taken out of the |
| 247 | + in-flight cache. |
| 248 | + """ |
| 249 | + req1 = ensureDeferred( |
| 250 | + self.state_datastore._get_state_for_group_using_inflight_cache( |
| 251 | + 42, StateFilter.all() |
| 252 | + ) |
| 253 | + ) |
| 254 | + self.pump(by=0.1) |
| 255 | + |
| 256 | + # This should have gone to the database |
| 257 | + self.assertEqual(len(self.get_state_group_calls), 1) |
| 258 | + self.assertFalse(req1.called) |
| 259 | + |
| 260 | + # Complete the request right away. |
| 261 | + self._complete_request_fake(*self.get_state_group_calls[0]) |
| 262 | + self.assertTrue(req1.called) |
| 263 | + |
| 264 | + # Send off another request |
| 265 | + req2 = ensureDeferred( |
| 266 | + self.state_datastore._get_state_for_group_using_inflight_cache( |
| 267 | + 42, StateFilter.all() |
| 268 | + ) |
| 269 | + ) |
| 270 | + self.pump(by=0.1) |
| 271 | + |
| 272 | + # It should have gone to the database again, because the previous request |
| 273 | + # isn't in-flight and therefore isn't available for deduplication. |
| 274 | + self.assertEqual(len(self.get_state_group_calls), 2) |
| 275 | + self.assertFalse(req2.called) |
| 276 | + |
| 277 | + # Complete the request right away. |
| 278 | + self._complete_request_fake(*self.get_state_group_calls[1]) |
| 279 | + self.assertTrue(req2.called) |
| 280 | + groups, sf, d = self.get_state_group_calls[0] |
| 281 | + |
| 282 | + self.assertEqual(self.get_success(req1), FAKE_STATE) |
| 283 | + self.assertEqual(self.get_success(req2), FAKE_STATE) |
0 commit comments