|
1 | | -# Copyright (c) 2022 Tulir Asokan |
| 1 | +# Copyright (c) 2023 Tulir Asokan |
2 | 2 | # |
3 | 3 | # This Source Code Form is subject to the terms of the Mozilla Public |
4 | 4 | # License, v. 2.0. If a copy of the MPL was not distributed with this |
|
12 | 12 |
|
13 | 13 | import olm |
14 | 14 |
|
| 15 | +from mautrix.errors import MForbidden, MNotFound |
15 | 16 | from mautrix.types import ( |
16 | 17 | DeviceID, |
17 | 18 | EncryptionKeyAlgorithm, |
| 19 | + EventType, |
18 | 20 | IdentityKey, |
19 | 21 | KeyID, |
20 | 22 | RequestedKeyInfo, |
| 23 | + RoomEncryptionStateEventContent, |
21 | 24 | RoomID, |
| 25 | + RoomKeyEventContent, |
22 | 26 | SessionID, |
23 | 27 | SigningKey, |
24 | 28 | TrustState, |
@@ -82,6 +86,34 @@ def _mark_session_received(self, session_id: SessionID) -> None: |
82 | 86 | except KeyError: |
83 | 87 | return |
84 | 88 |
|
| 89 | + async def _fill_encryption_info(self, evt: RoomKeyEventContent) -> None: |
| 90 | + encryption_info = await self.state_store.get_encryption_info(evt.room_id) |
| 91 | + if not encryption_info: |
| 92 | + self.log.warning( |
| 93 | + f"Encryption info for {evt.room_id} not found in state store, fetching from server" |
| 94 | + ) |
| 95 | + try: |
| 96 | + encryption_info = await self.client.get_state_event( |
| 97 | + evt.room_id, EventType.ROOM_ENCRYPTION |
| 98 | + ) |
| 99 | + except (MNotFound, MForbidden) as e: |
| 100 | + self.log.warning( |
| 101 | + f"Failed to get encryption info for {evt.room_id} from server: {e}," |
| 102 | + " using defaults" |
| 103 | + ) |
| 104 | + encryption_info = RoomEncryptionStateEventContent() |
| 105 | + if not encryption_info: |
| 106 | + self.log.warning( |
| 107 | + f"Didn't find encryption info for {evt.room_id} on server either," |
| 108 | + " using defaults" |
| 109 | + ) |
| 110 | + encryption_info = RoomEncryptionStateEventContent() |
| 111 | + |
| 112 | + if not evt.beeper_max_age_ms: |
| 113 | + evt.beeper_max_age_ms = encryption_info.rotation_period_ms |
| 114 | + if not evt.beeper_max_messages: |
| 115 | + evt.beeper_max_messages = encryption_info.rotation_period_msgs |
| 116 | + |
85 | 117 |
|
86 | 118 | canonical_json = functools.partial( |
87 | 119 | json.dumps, ensure_ascii=False, separators=(",", ":"), sort_keys=True |
|
0 commit comments