Skip to content

Commit 90aedf8

Browse files
authored
Merge pull request #140 from mautrix/crypto-store-schema-sync
Sync crypto store schema with mautrix-go
2 parents 9ff58f9 + 251b8fc commit 90aedf8

File tree

5 files changed

+211
-32
lines changed

5 files changed

+211
-32
lines changed

mautrix/crypto/store/asyncpg/store.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from mautrix.client.state_store import SyncStore
1414
from mautrix.client.state_store.asyncpg import PgStateStore
15+
from mautrix.errors import GroupSessionWithheldError
1516
from mautrix.types import (
1617
CrossSigner,
1718
CrossSigningUsage,
@@ -117,7 +118,7 @@ async def put_account(self, account: OlmAccount) -> None:
117118
await self.db.execute(
118119
q,
119120
self.account_id,
120-
self._device_id,
121+
self._device_id or "",
121122
account.shared,
122123
self._sync_token or "",
123124
pickle,
@@ -236,6 +237,10 @@ async def put_group_session(
236237
INSERT INTO crypto_megolm_inbound_session (
237238
session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id
238239
) VALUES ($1, $2, $3, $4, $5, $6, $7)
240+
ON CONFLICT (session_id, account_id) DO UPDATE
241+
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key,
242+
signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session,
243+
forwarding_chains=excluded.forwarding_chains
239244
"""
240245
try:
241246
await self.db.execute(
@@ -255,13 +260,15 @@ async def get_group_session(
255260
self, room_id: RoomID, session_id: SessionID
256261
) -> InboundGroupSession | None:
257262
q = """
258-
SELECT sender_key, signing_key, session, forwarding_chains
263+
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code
259264
FROM crypto_megolm_inbound_session
260265
WHERE room_id=$1 AND session_id=$2 AND account_id=$3
261266
"""
262267
row = await self.db.fetchrow(q, room_id, session_id, self.account_id)
263268
if row is None:
264269
return None
270+
if row["withheld_code"] is not None:
271+
raise GroupSessionWithheldError(session_id, row["withheld_code"])
265272
forwarding_chain = row["forwarding_chains"].split(",") if row["forwarding_chains"] else []
266273
return InboundGroupSession.from_pickle(
267274
row["session"],
@@ -275,16 +282,14 @@ async def get_group_session(
275282
async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
276283
q = """
277284
SELECT COUNT(session) FROM crypto_megolm_inbound_session
278-
WHERE room_id=$1 AND session_id=$2 AND account_id=$3
285+
WHERE room_id=$1 AND session_id=$2 AND account_id=$3 AND session IS NOT NULL
279286
"""
280287
count = await self.db.fetchval(q, room_id, session_id, self.account_id)
281288
return count > 0
282289

283290
async def add_outbound_group_session(self, session: OutboundGroupSession) -> None:
284291
pickle = session.pickle(self.pickle_key)
285-
max_age = session.max_age
286-
if self.db.scheme == Scheme.SQLITE:
287-
max_age = max_age.total_seconds()
292+
max_age = int(session.max_age.total_seconds() * 1000)
288293
q = """
289294
INSERT INTO crypto_megolm_outbound_session (
290295
room_id, session_id, session, shared, max_messages, message_count,
@@ -334,17 +339,14 @@ async def get_outbound_group_session(self, room_id: RoomID) -> OutboundGroupSess
334339
row = await self.db.fetchrow(q, room_id, self.account_id)
335340
if row is None:
336341
return None
337-
max_age = row["max_age"]
338-
if self.db.scheme == Scheme.SQLITE:
339-
max_age = timedelta(seconds=max_age)
340342
return OutboundGroupSession.from_pickle(
341343
row["session"],
342344
passphrase=self.pickle_key,
343345
room_id=row["room_id"],
344346
shared=row["shared"],
345347
max_messages=row["max_messages"],
346348
message_count=row["message_count"],
347-
max_age=max_age,
349+
max_age=timedelta(milliseconds=row["max_age"]),
348350
use_time=row["last_used"],
349351
creation_time=row["created_at"],
350352
)

mautrix/crypto/store/asyncpg/upgrade.py

Lines changed: 190 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022 Tulir Asokan
1+
# Copyright (c) 2023 Tulir Asokan
22
#
33
# This Source Code Form is subject to the terms of the Mozilla Public
44
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -16,15 +16,15 @@
1616
)
1717

1818

19-
@upgrade_table.register(description="Latest revision", upgrades_to=6)
20-
async def upgrade_blank_to_v4(conn: Connection) -> None:
19+
@upgrade_table.register(description="Latest revision", upgrades_to=9)
20+
async def upgrade_blank_to_latest(conn: Connection) -> None:
2121
await conn.execute(
2222
"""CREATE TABLE IF NOT EXISTS crypto_account (
23-
account_id TEXT PRIMARY KEY,
24-
device_id TEXT,
25-
shared BOOLEAN NOT NULL,
26-
sync_token TEXT NOT NULL,
27-
account bytea NOT NULL
23+
account_id TEXT PRIMARY KEY,
24+
device_id TEXT NOT NULL,
25+
shared BOOLEAN NOT NULL,
26+
sync_token TEXT NOT NULL,
27+
account bytea NOT NULL
2828
)"""
2929
)
3030
await conn.execute(
@@ -68,13 +68,15 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
6868
)
6969
await conn.execute(
7070
"""CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
71-
account_id TEXT,
72-
session_id CHAR(43),
73-
sender_key CHAR(43) NOT NULL,
74-
signing_key CHAR(43) NOT NULL,
75-
room_id TEXT NOT NULL,
76-
session bytea NOT NULL,
77-
forwarding_chains TEXT NOT NULL,
71+
account_id TEXT,
72+
session_id CHAR(43),
73+
sender_key CHAR(43) NOT NULL,
74+
signing_key CHAR(43),
75+
room_id TEXT NOT NULL,
76+
session bytea,
77+
forwarding_chains TEXT,
78+
withheld_code TEXT,
79+
withheld_reason TEXT,
7880
PRIMARY KEY (account_id, session_id)
7981
)"""
8082
)
@@ -87,7 +89,7 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
8789
shared BOOLEAN NOT NULL,
8890
max_messages INTEGER NOT NULL,
8991
message_count INTEGER NOT NULL,
90-
max_age INTERVAL NOT NULL,
92+
max_age BIGINT NOT NULL,
9193
created_at timestamp NOT NULL,
9294
last_used timestamp NOT NULL,
9395
PRIMARY KEY (account_id, room_id)
@@ -97,8 +99,10 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
9799
"""CREATE TABLE crypto_cross_signing_keys (
98100
user_id TEXT,
99101
usage TEXT,
100-
key CHAR(43),
101-
first_seen_key CHAR(43),
102+
key CHAR(43) NOT NULL,
103+
104+
first_seen_key CHAR(43) NOT NULL,
105+
102106
PRIMARY KEY (user_id, usage)
103107
)"""
104108
)
@@ -108,7 +112,7 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
108112
signed_key TEXT,
109113
signer_user_id TEXT,
110114
signer_key TEXT,
111-
signature TEXT,
115+
signature CHAR(88) NOT NULL,
112116
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
113117
)"""
114118
)
@@ -162,7 +166,7 @@ async def upgrade_v2(conn: Connection, scheme: Scheme) -> None:
162166
shared BOOLEAN NOT NULL,
163167
max_messages INTEGER NOT NULL,
164168
message_count INTEGER NOT NULL,
165-
max_age INTERVAL NOT NULL,
169+
max_age BIGINT NOT NULL,
166170
created_at timestamp NOT NULL,
167171
last_used timestamp NOT NULL,
168172
PRIMARY KEY (account_id, room_id)
@@ -250,3 +254,169 @@ async def upgrade_v6(conn: Connection) -> None:
250254
await conn.execute("UPDATE crypto_device SET trust=300 WHERE trust=1") # verified
251255
await conn.execute("UPDATE crypto_device SET trust=-100 WHERE trust=2") # blacklisted
252256
await conn.execute("UPDATE crypto_device SET trust=0 WHERE trust=3") # ignored -> unset
257+
258+
259+
@upgrade_table.register(
260+
description="Synchronize schema with mautrix-go", upgrades_to=9, transaction=False
261+
)
262+
async def upgrade_v9(conn: Connection, scheme: Scheme) -> None:
263+
if scheme == Scheme.POSTGRES:
264+
async with conn.transaction():
265+
await upgrade_v9_postgres(conn)
266+
else:
267+
await upgrade_v9_sqlite(conn)
268+
269+
270+
# These two are never used because the previous one jumps from 6 to 9.
271+
@upgrade_table.register
272+
async def upgrade_noop_7_to_8(_: Connection) -> None:
273+
pass
274+
275+
276+
@upgrade_table.register
277+
async def upgrade_noop_8_to_9(_: Connection) -> None:
278+
pass
279+
280+
281+
async def upgrade_v9_postgres(conn: Connection) -> None:
282+
await conn.execute("UPDATE crypto_account SET device_id='' WHERE device_id IS NULL")
283+
await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id SET NOT NULL")
284+
285+
await conn.execute(
286+
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN signing_key DROP NOT NULL"
287+
)
288+
await conn.execute(
289+
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN session DROP NOT NULL"
290+
)
291+
await conn.execute(
292+
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN forwarding_chains DROP NOT NULL"
293+
)
294+
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_code TEXT")
295+
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_reason TEXT")
296+
297+
await conn.execute("DELETE FROM crypto_cross_signing_keys WHERE key IS NULL")
298+
await conn.execute(
299+
"UPDATE crypto_cross_signing_keys SET first_seen_key=key WHERE first_seen_key IS NULL"
300+
)
301+
await conn.execute("ALTER TABLE crypto_cross_signing_keys ALTER COLUMN key SET NOT NULL")
302+
await conn.execute(
303+
"ALTER TABLE crypto_cross_signing_keys ALTER COLUMN first_seen_key SET NOT NULL"
304+
)
305+
306+
await conn.execute("DELETE FROM crypto_cross_signing_signatures WHERE signature IS NULL")
307+
await conn.execute(
308+
"ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature SET NOT NULL"
309+
)
310+
311+
await conn.execute(
312+
"ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN max_age TYPE BIGINT "
313+
"USING (EXTRACT(EPOCH from max_age)*1000)::int"
314+
)
315+
316+
317+
async def upgrade_v9_sqlite(conn: Connection) -> None:
318+
await conn.execute("PRAGMA foreign_keys = OFF")
319+
async with conn.transaction():
320+
await conn.execute(
321+
"""CREATE TABLE new_crypto_account (
322+
account_id TEXT PRIMARY KEY,
323+
device_id TEXT NOT NULL,
324+
shared BOOLEAN NOT NULL,
325+
sync_token TEXT NOT NULL,
326+
account bytea NOT NULL
327+
)"""
328+
)
329+
await conn.execute(
330+
"""
331+
INSERT INTO new_crypto_account (account_id, device_id, shared, sync_token, account)
332+
SELECT account_id, COALESCE(device_id, ''), shared, sync_token, account
333+
FROM crypto_account
334+
"""
335+
)
336+
await conn.execute("DROP TABLE crypto_account")
337+
await conn.execute("ALTER TABLE new_crypto_account RENAME TO crypto_account")
338+
339+
await conn.execute(
340+
"""CREATE TABLE new_crypto_megolm_inbound_session (
341+
account_id TEXT,
342+
session_id CHAR(43),
343+
sender_key CHAR(43) NOT NULL,
344+
signing_key CHAR(43),
345+
room_id TEXT NOT NULL,
346+
session bytea,
347+
forwarding_chains TEXT,
348+
withheld_code TEXT,
349+
withheld_reason TEXT,
350+
PRIMARY KEY (account_id, session_id)
351+
)"""
352+
)
353+
await conn.execute(
354+
"""
355+
INSERT INTO new_crypto_megolm_inbound_session (
356+
account_id, session_id, sender_key, signing_key, room_id, session,
357+
forwarding_chains
358+
)
359+
SELECT account_id, session_id, sender_key, signing_key, room_id, session,
360+
forwarding_chains
361+
FROM crypto_megolm_inbound_session
362+
"""
363+
)
364+
await conn.execute("DROP TABLE crypto_megolm_inbound_session")
365+
await conn.execute(
366+
"ALTER TABLE new_crypto_megolm_inbound_session RENAME TO crypto_megolm_inbound_session"
367+
)
368+
369+
await conn.execute("UPDATE crypto_megolm_outbound_session SET max_age=max_age*1000")
370+
371+
await conn.execute(
372+
"""CREATE TABLE new_crypto_cross_signing_keys (
373+
user_id TEXT,
374+
usage TEXT,
375+
key CHAR(43) NOT NULL,
376+
377+
first_seen_key CHAR(43) NOT NULL,
378+
379+
PRIMARY KEY (user_id, usage)
380+
)"""
381+
)
382+
await conn.execute(
383+
"""
384+
INSERT INTO new_crypto_cross_signing_keys (user_id, usage, key, first_seen_key)
385+
SELECT user_id, usage, key, COALESCE(first_seen_key, key)
386+
FROM crypto_cross_signing_keys
387+
WHERE key IS NOT NULL
388+
"""
389+
)
390+
await conn.execute("DROP TABLE crypto_cross_signing_keys")
391+
await conn.execute(
392+
"ALTER TABLE new_crypto_cross_signing_keys RENAME TO crypto_cross_signing_keys"
393+
)
394+
395+
await conn.execute(
396+
"""CREATE TABLE new_crypto_cross_signing_signatures (
397+
signed_user_id TEXT,
398+
signed_key TEXT,
399+
signer_user_id TEXT,
400+
signer_key TEXT,
401+
signature CHAR(88) NOT NULL,
402+
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
403+
)"""
404+
)
405+
await conn.execute(
406+
"""
407+
INSERT INTO new_crypto_cross_signing_signatures (
408+
signed_user_id, signed_key, signer_user_id, signer_key, signature
409+
)
410+
SELECT signed_user_id, signed_key, signer_user_id, signer_key, signature
411+
FROM crypto_cross_signing_signatures
412+
WHERE signature IS NOT NULL
413+
"""
414+
)
415+
await conn.execute("DROP TABLE crypto_cross_signing_signatures")
416+
await conn.execute(
417+
"ALTER TABLE new_crypto_cross_signing_signatures "
418+
"RENAME TO crypto_cross_signing_signatures"
419+
)
420+
421+
await conn.execute("PRAGMA foreign_key_check")
422+
await conn.execute("PRAGMA foreign_keys = ON")

mautrix/errors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
DeviceValidationError,
77
DuplicateMessageIndex,
88
EncryptionError,
9+
GroupSessionWithheldError,
910
MatchingSessionDecryptionError,
1011
MismatchingRoomError,
1112
SessionNotFound,

mautrix/errors/crypto.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ class MatchingSessionDecryptionError(DecryptionError):
3636
pass
3737

3838

39+
class GroupSessionWithheldError(DecryptionError):
40+
def __init__(self, session_id: SessionID, withheld_code: str) -> None:
41+
super().__init__(f"Session ID {session_id} was withheld ({withheld_code})")
42+
self.withheld_code = withheld_code
43+
44+
3945
class SessionNotFound(DecryptionError):
4046
def __init__(self, session_id: SessionID, sender_key: IdentityKey | None = None) -> None:
4147
super().__init__(

mautrix/util/async_db/upgrade.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
UpgradeWithoutScheme = Callable[[LoggingConnection], Awaitable[Optional[int]]]
2222

2323

24-
async def noop_upgrade(_: LoggingConnection) -> None:
24+
async def noop_upgrade(_: LoggingConnection, _2: Scheme) -> None:
2525
pass
2626

2727

@@ -178,6 +178,6 @@ def _find_upgrade_table(fn: Upgrade) -> UpgradeTable:
178178

179179
def register_upgrade(index: int = -1, description: str = "") -> Callable[[Upgrade], Upgrade]:
180180
def actually_register(fn: Upgrade) -> Upgrade:
181-
return _find_upgrade_table(fn).register(index, description, fn)
181+
return _find_upgrade_table(fn).register(fn, index=index, description=description)
182182

183183
return actually_register

0 commit comments

Comments
 (0)