Skip to content

Commit d1e3fc5

Browse files
committed
Allow raw dicts in StateStore.set_encryption_info
1 parent a209c3b commit d1e3fc5

File tree

6 files changed

+17
-10
lines changed

6 files changed

+17
-10
lines changed

mautrix/client/state_store/abstract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8-
from typing import Awaitable
8+
from typing import Any, Awaitable
99
from abc import ABC, abstractmethod
1010

1111
from mautrix.types import (
@@ -135,7 +135,7 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent
135135

136136
@abstractmethod
137137
async def set_encryption_info(
138-
self, room_id: RoomID, content: RoomEncryptionStateEventContent
138+
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, any]
139139
) -> None:
140140
pass
141141

mautrix/client/state_store/asyncpg/store.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8-
from typing import NamedTuple
8+
from typing import Any, NamedTuple
99

1010
from mautrix.types import (
1111
Member,
@@ -14,6 +14,7 @@
1414
PowerLevelStateEventContent,
1515
RoomEncryptionStateEventContent,
1616
RoomID,
17+
Serializable,
1718
UserID,
1819
)
1920
from mautrix.util.async_db import Database, Scheme
@@ -242,10 +243,12 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent
242243
return RoomEncryptionStateEventContent.parse_json(row["encryption"])
243244

244245
async def set_encryption_info(
245-
self, room_id: RoomID, content: RoomEncryptionStateEventContent
246+
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
246247
) -> None:
247248
q = (
248249
"INSERT INTO mx_room_state (room_id, is_encrypted, encryption) VALUES ($1, true, $2) "
249250
"ON CONFLICT (room_id) DO UPDATE SET is_encrypted=true, encryption=$2"
250251
)
251-
await self.db.execute(q, room_id, content.json())
252+
await self.db.execute(
253+
q, room_id, content.json() if isinstance(content, Serializable) else content
254+
)

mautrix/client/state_store/file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8-
from typing import IO
8+
from typing import IO, Any
99
from pathlib import Path
1010

1111
from mautrix.types import (
@@ -55,7 +55,7 @@ async def set_members(
5555
self._time_limited_flush()
5656

5757
async def set_encryption_info(
58-
self, room_id: RoomID, content: RoomEncryptionStateEventContent
58+
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
5959
) -> None:
6060
await super().set_encryption_info(room_id, content)
6161
self._time_limited_flush()

mautrix/client/state_store/memory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent
187187
return self.encryption.get(room_id)
188188

189189
async def set_encryption_info(
190-
self, room_id: RoomID, content: RoomEncryptionStateEventContent
190+
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
191191
) -> None:
192+
if not isinstance(content, RoomEncryptionStateEventContent):
193+
content = RoomEncryptionStateEventContent.deserialize(content)
192194
self.encryption[room_id] = content

mautrix/client/state_store/sqlalchemy/mx_room_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def python_type(self) -> Type[Serializable]:
3232

3333
def process_bind_param(self, value: Serializable, dialect) -> str | None:
3434
if value is not None:
35-
return json.dumps(value.serialize())
35+
return json.dumps(value.serialize() if isinstance(value, Serializable) else value)
3636
return None
3737

3838
def process_result_value(self, value: str, dialect) -> Serializable | None:

mautrix/client/state_store/sqlalchemy/sqlstatestore.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
66
from __future__ import annotations
77

8+
from typing import Any
9+
810
from mautrix.types import (
911
Member,
1012
Membership,
@@ -174,7 +176,7 @@ async def get_encryption_info(self, room_id: RoomID) -> RoomEncryptionStateEvent
174176
return room.encryption
175177

176178
async def set_encryption_info(
177-
self, room_id: RoomID, content: RoomEncryptionStateEventContent
179+
self, room_id: RoomID, content: RoomEncryptionStateEventContent | dict[str, Any]
178180
) -> None:
179181
if not content:
180182
raise ValueError("content is empty")

0 commit comments

Comments
 (0)