Skip to content

Commit 4d4441b

Browse files
committed
Sync crypto store schema with mautrix-go
1 parent 0b67487 commit 4d4441b

File tree

4 files changed

+188
-21
lines changed

4 files changed

+188
-21
lines changed

mautrix/crypto/store/asyncpg/store.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from mautrix.client.state_store import SyncStore
1414
from mautrix.client.state_store.asyncpg import PgStateStore
15+
from mautrix.errors import GroupSessionWithheldError
1516
from mautrix.types import (
1617
CrossSigner,
1718
CrossSigningUsage,
@@ -117,7 +118,7 @@ async def put_account(self, account: OlmAccount) -> None:
117118
await self.db.execute(
118119
q,
119120
self.account_id,
120-
self._device_id,
121+
self._device_id or "",
121122
account.shared,
122123
self._sync_token or "",
123124
pickle,
@@ -236,6 +237,10 @@ async def put_group_session(
236237
INSERT INTO crypto_megolm_inbound_session (
237238
session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id
238239
) VALUES ($1, $2, $3, $4, $5, $6, $7)
240+
ON CONFLICT (session_id, account_id) DO UPDATE
241+
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key,
242+
signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session,
243+
forwarding_chains=excluded.forwarding_chains
239244
"""
240245
try:
241246
await self.db.execute(
@@ -255,13 +260,15 @@ async def get_group_session(
255260
self, room_id: RoomID, session_id: SessionID
256261
) -> InboundGroupSession | None:
257262
q = """
258-
SELECT sender_key, signing_key, session, forwarding_chains
263+
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code
259264
FROM crypto_megolm_inbound_session
260265
WHERE room_id=$1 AND session_id=$2 AND account_id=$3
261266
"""
262267
row = await self.db.fetchrow(q, room_id, session_id, self.account_id)
263268
if row is None:
264269
return None
270+
if row["withheld_code"] is not None:
271+
raise GroupSessionWithheldError(session_id, row["withheld_code"])
265272
forwarding_chain = row["forwarding_chains"].split(",") if row["forwarding_chains"] else []
266273
return InboundGroupSession.from_pickle(
267274
row["session"],
@@ -275,7 +282,7 @@ async def get_group_session(
275282
async def has_group_session(self, room_id: RoomID, session_id: SessionID) -> bool:
276283
q = """
277284
SELECT COUNT(session) FROM crypto_megolm_inbound_session
278-
WHERE room_id=$1 AND session_id=$2 AND account_id=$3
285+
WHERE room_id=$1 AND session_id=$2 AND account_id=$3 AND session IS NOT NULL
279286
"""
280287
count = await self.db.fetchval(q, room_id, session_id, self.account_id)
281288
return count > 0

mautrix/crypto/store/asyncpg/upgrade.py

Lines changed: 171 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022 Tulir Asokan
1+
# Copyright (c) 2023 Tulir Asokan
22
#
33
# This Source Code Form is subject to the terms of the Mozilla Public
44
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -16,15 +16,15 @@
1616
)
1717

1818

19-
@upgrade_table.register(description="Latest revision", upgrades_to=6)
20-
async def upgrade_blank_to_v4(conn: Connection) -> None:
19+
@upgrade_table.register(description="Latest revision", upgrades_to=8)
20+
async def upgrade_blank_to_latest(conn: Connection) -> None:
2121
await conn.execute(
2222
"""CREATE TABLE IF NOT EXISTS crypto_account (
23-
account_id TEXT PRIMARY KEY,
24-
device_id TEXT,
25-
shared BOOLEAN NOT NULL,
26-
sync_token TEXT NOT NULL,
27-
account bytea NOT NULL
23+
account_id TEXT PRIMARY KEY,
24+
device_id TEXT NOT NULL,
25+
shared BOOLEAN NOT NULL,
26+
sync_token TEXT NOT NULL,
27+
account bytea NOT NULL
2828
)"""
2929
)
3030
await conn.execute(
@@ -68,16 +68,19 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
6868
)
6969
await conn.execute(
7070
"""CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
71-
account_id TEXT,
72-
session_id CHAR(43),
73-
sender_key CHAR(43) NOT NULL,
74-
signing_key CHAR(43) NOT NULL,
75-
room_id TEXT NOT NULL,
76-
session bytea NOT NULL,
77-
forwarding_chains TEXT NOT NULL,
71+
account_id TEXT,
72+
session_id CHAR(43),
73+
sender_key CHAR(43) NOT NULL,
74+
signing_key CHAR(43),
75+
room_id TEXT NOT NULL,
76+
session bytea,
77+
forwarding_chains TEXT,
78+
withheld_code TEXT,
79+
withheld_reason TEXT,
7880
PRIMARY KEY (account_id, session_id)
7981
)"""
8082
)
83+
# TODO chnge max_age to BIGINT
8184
await conn.execute(
8285
"""CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
8386
account_id TEXT,
@@ -97,8 +100,10 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
97100
"""CREATE TABLE crypto_cross_signing_keys (
98101
user_id TEXT,
99102
usage TEXT,
100-
key CHAR(43),
101-
first_seen_key CHAR(43),
103+
key CHAR(43) NOT NULL,
104+
105+
first_seen_key CHAR(43) NOT NULL,
106+
102107
PRIMARY KEY (user_id, usage)
103108
)"""
104109
)
@@ -108,7 +113,7 @@ async def upgrade_blank_to_v4(conn: Connection) -> None:
108113
signed_key TEXT,
109114
signer_user_id TEXT,
110115
signer_key TEXT,
111-
signature TEXT,
116+
signature CHAR(88) NOT NULL,
112117
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
113118
)"""
114119
)
@@ -250,3 +255,151 @@ async def upgrade_v6(conn: Connection) -> None:
250255
await conn.execute("UPDATE crypto_device SET trust=300 WHERE trust=1") # verified
251256
await conn.execute("UPDATE crypto_device SET trust=-100 WHERE trust=2") # blacklisted
252257
await conn.execute("UPDATE crypto_device SET trust=0 WHERE trust=3") # ignored -> unset
258+
259+
260+
@upgrade_table.register(
261+
description="Synchronize schema with mautrix-go", upgrades_to=8, transaction=False
262+
)
263+
async def upgrade_v8(conn: Connection, scheme: Scheme) -> None:
264+
if scheme == Scheme.POSTGRES:
265+
async with conn.transaction():
266+
await upgrade_v8_postgres(conn)
267+
else:
268+
await upgrade_v8_sqlite(conn)
269+
270+
271+
async def upgrade_v8_postgres(conn: Connection) -> None:
272+
await conn.execute("UPDATE crypto_account SET device_id='' WHERE device_id IS NULL")
273+
await conn.execute("ALTER TABLE crypto_account ALTER COLUMN device_id SET NOT NULL")
274+
275+
await conn.execute(
276+
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN signing_key DROP NOT NULL"
277+
)
278+
await conn.execute(
279+
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN session DROP NOT NULL"
280+
)
281+
await conn.execute(
282+
"ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN forwarding_chains DROP NOT NULL"
283+
)
284+
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_code TEXT")
285+
await conn.execute("ALTER TABLE crypto_megolm_inbound_session ADD COLUMN withheld_reason TEXT")
286+
287+
await conn.execute("DELETE FROM crypto_cross_signing_keys WHERE key IS NULL")
288+
await conn.execute(
289+
"UPDATE crypto_cross_signing_keys SET first_seen_key=key WHERE first_seen_key IS NULL"
290+
)
291+
await conn.execute("ALTER TABLE crypto_cross_signing_keys ALTER COLUMN key SET NOT NULL")
292+
await conn.execute(
293+
"ALTER TABLE crypto_cross_signing_keys ALTER COLUMN first_seen_key SET NOT NULL"
294+
)
295+
296+
await conn.execute("DELETE FROM crypto_cross_signing_signatures WHERE signature IS NULL")
297+
await conn.execute(
298+
"ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature SET NOT NULL"
299+
)
300+
301+
302+
async def upgrade_v8_sqlite(conn: Connection) -> None:
303+
await conn.execute("PRAGMA foreign_keys = OFF")
304+
async with conn.transaction():
305+
await conn.execute(
306+
"""CREATE TABLE new_crypto_account (
307+
account_id TEXT PRIMARY KEY,
308+
device_id TEXT NOT NULL,
309+
shared BOOLEAN NOT NULL,
310+
sync_token TEXT NOT NULL,
311+
account bytea NOT NULL
312+
)"""
313+
)
314+
await conn.execute(
315+
"""
316+
INSERT INTO new_crypto_account (account_id, device_id, shared, sync_token, account)
317+
SELECT account_id, COALESCE(device_id, ''), shared, sync_token, account
318+
FROM crypto_account
319+
"""
320+
)
321+
await conn.execute("DROP TABLE crypto_account")
322+
await conn.execute("ALTER TABLE new_crypto_account RENAME TO crypto_account")
323+
324+
await conn.execute(
325+
"""CREATE TABLE new_crypto_megolm_inbound_session (
326+
account_id TEXT,
327+
session_id CHAR(43),
328+
sender_key CHAR(43) NOT NULL,
329+
signing_key CHAR(43),
330+
room_id TEXT NOT NULL,
331+
session bytea,
332+
forwarding_chains TEXT,
333+
withheld_code TEXT,
334+
withheld_reason TEXT,
335+
PRIMARY KEY (account_id, session_id)
336+
)"""
337+
)
338+
await conn.execute(
339+
"""
340+
INSERT INTO new_crypto_megolm_inbound_session (
341+
account_id, session_id, sender_key, signing_key, room_id, session,
342+
forwarding_chains
343+
)
344+
SELECT account_id, session_id, sender_key, signing_key, room_id, session,
345+
forwarding_chains
346+
FROM crypto_megolm_inbound_session
347+
"""
348+
)
349+
await conn.execute("DROP TABLE crypto_megolm_inbound_session")
350+
await conn.execute(
351+
"ALTER TABLE new_crypto_megolm_inbound_session RENAME TO crypto_megolm_inbound_session"
352+
)
353+
354+
await conn.execute(
355+
"""CREATE TABLE new_crypto_cross_signing_keys (
356+
user_id TEXT,
357+
usage TEXT,
358+
key CHAR(43) NOT NULL,
359+
360+
first_seen_key CHAR(43) NOT NULL,
361+
362+
PRIMARY KEY (user_id, usage)
363+
)"""
364+
)
365+
await conn.execute(
366+
"""
367+
INSERT INTO new_crypto_cross_signing_keys (user_id, usage, key, first_seen_key)
368+
SELECT user_id, usage, key, COALESCE(first_seen_key, key)
369+
FROM crypto_cross_signing_keys
370+
WHERE key IS NOT NULL
371+
"""
372+
)
373+
await conn.execute("DROP TABLE crypto_cross_signing_keys")
374+
await conn.execute(
375+
"ALTER TABLE new_crypto_cross_signing_keys RENAME TO crypto_cross_signing_keys"
376+
)
377+
378+
await conn.execute(
379+
"""CREATE TABLE new_crypto_cross_signing_signatures (
380+
signed_user_id TEXT,
381+
signed_key TEXT,
382+
signer_user_id TEXT,
383+
signer_key TEXT,
384+
signature CHAR(88) NOT NULL,
385+
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
386+
)"""
387+
)
388+
await conn.execute(
389+
"""
390+
INSERT INTO new_crypto_cross_signing_signatures (
391+
signed_user_id, signed_key, signer_user_id, signer_key, signature
392+
)
393+
SELECT signed_user_id, signed_key, signer_user_id, signer_key, signature
394+
FROM crypto_cross_signing_signatures
395+
WHERE signature IS NOT NULL
396+
"""
397+
)
398+
await conn.execute("DROP TABLE crypto_cross_signing_signatures")
399+
await conn.execute(
400+
"ALTER TABLE new_crypto_cross_signing_signatures "
401+
"RENAME TO crypto_cross_signing_signatures"
402+
)
403+
404+
await conn.execute("PRAGMA foreign_key_check")
405+
await conn.execute("PRAGMA foreign_keys = ON")

mautrix/errors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
DeviceValidationError,
77
DuplicateMessageIndex,
88
EncryptionError,
9+
GroupSessionWithheldError,
910
MatchingSessionDecryptionError,
1011
MismatchingRoomError,
1112
SessionNotFound,

mautrix/errors/crypto.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ class MatchingSessionDecryptionError(DecryptionError):
3636
pass
3737

3838

39+
class GroupSessionWithheldError(DecryptionError):
40+
def __init__(self, session_id: SessionID, withheld_code: str) -> None:
41+
super().__init__(f"Session ID {session_id} was withheld ({withheld_code})")
42+
self.withheld_code = withheld_code
43+
44+
3945
class SessionNotFound(DecryptionError):
4046
def __init__(self, session_id: SessionID, sender_key: IdentityKey | None = None) -> None:
4147
super().__init__(

0 commit comments

Comments
 (0)