1717import abc
1818from typing import TYPE_CHECKING , Dict , Iterable , List , Optional , Tuple
1919
20+ import attr
2021from canonicaljson import encode_canonical_json
2122
2223from twisted .enterprise .adbapi import Connection
3334 from synapse .handlers .e2e_keys import SignatureListItem
3435
3536
37+ @attr .s
38+ class DeviceKeyLookupResult :
39+ """The type returned by _get_e2e_device_keys_and_signatures_txn"""
40+
41+ display_name = attr .ib (type = Optional [str ])
42+
43+ # the key data from e2e_device_keys_json. Typically includes fields like
44+ # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
45+ # key) and "signatures" (a signature of the structure by the ed25519 key)
46+ key_json = attr .ib (type = Optional [str ])
47+
48+ # cross-signing sigs
49+ signatures = attr .ib (type = Optional [Dict ], default = None )
50+
51+
3652class EndToEndKeyWorkerStore (SQLBaseStore ):
3753 async def get_e2e_device_keys_for_federation_query (
3854 self , user_id : str
@@ -61,17 +77,17 @@ def _get_e2e_device_keys_for_federation_query_txn(
6177 for device_id , device in user_devices .items ():
6278 result = {"device_id" : device_id }
6379
64- key_json = device .get ( " key_json" , None )
80+ key_json = device .key_json
6581 if key_json :
6682 result ["keys" ] = db_to_json (key_json )
6783
68- if "signatures" in device :
69- for sig_user_id , sigs in device [ " signatures" ] .items ():
84+ if device . signatures :
85+ for sig_user_id , sigs in device . signatures .items ():
7086 result ["keys" ].setdefault ("signatures" , {}).setdefault (
7187 sig_user_id , {}
7288 ).update (sigs )
7389
74- device_display_name = device .get ( "device_display_name" , None )
90+ device_display_name = device .display_name
7591 if device_display_name :
7692 result ["device_display_name" ] = device_display_name
7793
@@ -109,13 +125,13 @@ async def get_e2e_device_keys_for_cs_api(
109125 for user_id , device_keys in results .items ():
110126 rv [user_id ] = {}
111127 for device_id , device_info in device_keys .items ():
112- r = db_to_json (device_info .pop ( " key_json" ) )
128+ r = db_to_json (device_info .key_json )
113129 r ["unsigned" ] = {}
114- display_name = device_info [ "device_display_name" ]
130+ display_name = device_info . display_name
115131 if display_name is not None :
116132 r ["unsigned" ]["device_display_name" ] = display_name
117- if "signatures" in device_info :
118- for sig_user_id , sigs in device_info [ " signatures" ] .items ():
133+ if device_info . signatures :
134+ for sig_user_id , sigs in device_info . signatures .items ():
119135 r .setdefault ("signatures" , {}).setdefault (
120136 sig_user_id , {}
121137 ).update (sigs )
@@ -126,7 +142,7 @@ async def get_e2e_device_keys_for_cs_api(
126142 @trace
127143 def _get_e2e_device_keys_and_signatures_txn (
128144 self , txn , query_list , include_all_devices = False , include_deleted_devices = False
129- ) -> Dict [str , Dict [str , Optional [Dict ]]]:
145+ ) -> Dict [str , Dict [str , Optional [DeviceKeyLookupResult ]]]:
130146 set_tag ("include_all_devices" , include_all_devices )
131147 set_tag ("include_deleted_devices" , include_deleted_devices )
132148
@@ -161,7 +177,7 @@ def _get_e2e_device_keys_and_signatures_txn(
161177
162178 sql = (
163179 "SELECT user_id, device_id, "
164- " d.display_name AS device_display_name , "
180+ " d.display_name, "
165181 " k.key_json"
166182 " FROM devices d"
167183 " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@@ -172,13 +188,14 @@ def _get_e2e_device_keys_and_signatures_txn(
172188 )
173189
174190 txn .execute (sql , query_params )
175- rows = self .db_pool .cursor_to_dict (txn )
176191
177- result = {}
178- for row in rows :
192+ result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
193+ for ( user_id , device_id , display_name , key_json ) in txn :
179194 if include_deleted_devices :
180- deleted_devices .remove ((row ["user_id" ], row ["device_id" ]))
181- result .setdefault (row ["user_id" ], {})[row ["device_id" ]] = row
195+ deleted_devices .remove ((user_id , device_id ))
196+ result .setdefault (user_id , {})[device_id ] = DeviceKeyLookupResult (
197+ display_name , key_json
198+ )
182199
183200 if include_deleted_devices :
184201 for user_id , device_id in deleted_devices :
@@ -209,7 +226,10 @@ def _get_e2e_device_keys_and_signatures_txn(
209226 # note that target_device_result will be None for deleted devices.
210227 continue
211228
212- target_device_signatures = target_device_result .setdefault ("signatures" , {})
229+ target_device_signatures = target_device_result .signatures
230+ if target_device_signatures is None :
231+ target_device_signatures = target_device_result .signatures = {}
232+
213233 signing_user_signatures = target_device_signatures .setdefault (
214234 signing_user_id , {}
215235 )
0 commit comments