@@ -1111,7 +1111,7 @@ def get_device_stream_token(self) -> int:
11111111 ...
11121112
11131113 async def claim_e2e_one_time_keys (
1114- self , query_list : Iterable [Tuple [str , str , str , int ]]
1114+ self , query_list : Collection [Tuple [str , str , str , int ]]
11151115 ) -> Tuple [
11161116 Dict [str , Dict [str , Dict [str , JsonDict ]]], List [Tuple [str , str , str , int ]]
11171117 ]:
@@ -1121,131 +1121,63 @@ async def claim_e2e_one_time_keys(
11211121 query_list: An iterable of tuples of (user ID, device ID, algorithm).
11221122
11231123 Returns:
1124- A tuple pf :
1124+ A tuple (results, missing) of :
11251125 A map of user ID -> a map device ID -> a map of key ID -> JSON.
11261126
1127- A copy of the input which has not been fulfilled.
1127+ A copy of the input which has not been fulfilled. The returned counts
1128+ may be less than the input counts. In this case, the returned counts
1129+ are the number of claims that were not fulfilled.
11281130 """
1129-
1130- @trace
1131- def _claim_e2e_one_time_key_simple (
1132- txn : LoggingTransaction ,
1133- user_id : str ,
1134- device_id : str ,
1135- algorithm : str ,
1136- count : int ,
1137- ) -> List [Tuple [str , str ]]:
1138- """Claim OTK for device for DBs that don't support RETURNING.
1139-
1140- Returns:
1141- A tuple of key name (algorithm + key ID) and key JSON, if an
1142- OTK was found.
1143- """
1144-
1145- sql = """
1146- SELECT key_id, key_json FROM e2e_one_time_keys_json
1147- WHERE user_id = ? AND device_id = ? AND algorithm = ?
1148- LIMIT ?
1149- """
1150-
1151- txn .execute (sql , (user_id , device_id , algorithm , count ))
1152- otk_rows = list (txn )
1153- if not otk_rows :
1154- return []
1155-
1156- self .db_pool .simple_delete_many_txn (
1157- txn ,
1158- table = "e2e_one_time_keys_json" ,
1159- column = "key_id" ,
1160- values = [otk_row [0 ] for otk_row in otk_rows ],
1161- keyvalues = {
1162- "user_id" : user_id ,
1163- "device_id" : device_id ,
1164- "algorithm" : algorithm ,
1165- },
1166- )
1167- self ._invalidate_cache_and_stream (
1168- txn , self .count_e2e_one_time_keys , (user_id , device_id )
1169- )
1170-
1171- return [
1172- (f"{ algorithm } :{ key_id } " , key_json ) for key_id , key_json in otk_rows
1173- ]
1174-
1175- @trace
1176- def _claim_e2e_one_time_key_returning (
1177- txn : LoggingTransaction ,
1178- user_id : str ,
1179- device_id : str ,
1180- algorithm : str ,
1181- count : int ,
1182- ) -> List [Tuple [str , str ]]:
1183- """Claim OTK for device for DBs that support RETURNING.
1184-
1185- Returns:
1186- A tuple of key name (algorithm + key ID) and key JSON, if an
1187- OTK was found.
1188- """
1189-
1190- # We can use RETURNING to do the fetch and DELETE in once step.
1191- sql = """
1192- DELETE FROM e2e_one_time_keys_json
1193- WHERE user_id = ? AND device_id = ? AND algorithm = ?
1194- AND key_id IN (
1195- SELECT key_id FROM e2e_one_time_keys_json
1196- WHERE user_id = ? AND device_id = ? AND algorithm = ?
1197- LIMIT ?
1198- )
1199- RETURNING key_id, key_json
1200- """
1201-
1202- txn .execute (
1203- sql ,
1204- (user_id , device_id , algorithm , user_id , device_id , algorithm , count ),
1205- )
1206- otk_rows = list (txn )
1207- if not otk_rows :
1208- return []
1209-
1210- self ._invalidate_cache_and_stream (
1211- txn , self .count_e2e_one_time_keys , (user_id , device_id )
1212- )
1213-
1214- return [
1215- (f"{ algorithm } :{ key_id } " , key_json ) for key_id , key_json in otk_rows
1216- ]
1217-
12181131 results : Dict [str , Dict [str , Dict [str , JsonDict ]]] = {}
12191132 missing : List [Tuple [str , str , str , int ]] = []
1220- for user_id , device_id , algorithm , count in query_list :
1221- if self .database_engine .supports_returning :
1222- # If we support RETURNING clause we can use a single query that
1223- # allows us to use autocommit mode.
1224- _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
1225- db_autocommit = True
1226- else :
1227- _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
1228- db_autocommit = False
1133+ if isinstance (self .database_engine , PostgresEngine ):
1134+ # If we can use execute_values we can use a single batch query
1135+ # in autocommit mode.
1136+ unfulfilled_claim_counts : Dict [Tuple [str , str , str ], int ] = {}
1137+ for user_id , device_id , algorithm , count in query_list :
1138+ unfulfilled_claim_counts [user_id , device_id , algorithm ] = count
12291139
1230- claim_rows = await self .db_pool .runInteraction (
1140+ bulk_claims = await self .db_pool .runInteraction (
12311141 "claim_e2e_one_time_keys" ,
1232- _claim_e2e_one_time_key ,
1233- user_id ,
1234- device_id ,
1235- algorithm ,
1236- count ,
1237- db_autocommit = db_autocommit ,
1142+ self ._claim_e2e_one_time_keys_bulk ,
1143+ query_list ,
1144+ db_autocommit = True ,
12381145 )
1239- if claim_rows :
1146+
1147+ for user_id , device_id , algorithm , key_id , key_json in bulk_claims :
12401148 device_results = results .setdefault (user_id , {}).setdefault (
12411149 device_id , {}
12421150 )
1243- for claim_row in claim_rows :
1244- device_results [claim_row [0 ]] = json_decoder .decode (claim_row [1 ])
1151+ device_results [f"{ algorithm } :{ key_id } " ] = json_decoder .decode (key_json )
1152+ unfulfilled_claim_counts [(user_id , device_id , algorithm )] -= 1
1153+
12451154 # Did we get enough OTKs?
1246- count -= len (claim_rows )
1247- if count :
1248- missing .append ((user_id , device_id , algorithm , count ))
1155+ missing = [
1156+ (user , device , alg , count )
1157+ for (user , device , alg ), count in unfulfilled_claim_counts .items ()
1158+ if count > 0
1159+ ]
1160+ else :
1161+ for user_id , device_id , algorithm , count in query_list :
1162+ claim_rows = await self .db_pool .runInteraction (
1163+ "claim_e2e_one_time_keys" ,
1164+ self ._claim_e2e_one_time_key_simple ,
1165+ user_id ,
1166+ device_id ,
1167+ algorithm ,
1168+ count ,
1169+ db_autocommit = False ,
1170+ )
1171+ if claim_rows :
1172+ device_results = results .setdefault (user_id , {}).setdefault (
1173+ device_id , {}
1174+ )
1175+ for claim_row in claim_rows :
1176+ device_results [claim_row [0 ]] = json_decoder .decode (claim_row [1 ])
1177+ # Did we get enough OTKs?
1178+ count -= len (claim_rows )
1179+ if count :
1180+ missing .append ((user_id , device_id , algorithm , count ))
12491181
12501182 return results , missing
12511183
@@ -1362,6 +1294,99 @@ async def _claim_e2e_fallback_keys_simple(
13621294
13631295 return results
13641296
1297+ @trace
1298+ def _claim_e2e_one_time_key_simple (
1299+ self ,
1300+ txn : LoggingTransaction ,
1301+ user_id : str ,
1302+ device_id : str ,
1303+ algorithm : str ,
1304+ count : int ,
1305+ ) -> List [Tuple [str , str ]]:
1306+ """Claim OTK for device for DBs that don't support RETURNING.
1307+
1308+ Returns:
1309+ A tuple of key name (algorithm + key ID) and key JSON, if an
1310+ OTK was found.
1311+ """
1312+
1313+ sql = """
1314+ SELECT key_id, key_json FROM e2e_one_time_keys_json
1315+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
1316+ LIMIT ?
1317+ """
1318+
1319+ txn .execute (sql , (user_id , device_id , algorithm , count ))
1320+ otk_rows = list (txn )
1321+ if not otk_rows :
1322+ return []
1323+
1324+ self .db_pool .simple_delete_many_txn (
1325+ txn ,
1326+ table = "e2e_one_time_keys_json" ,
1327+ column = "key_id" ,
1328+ values = [otk_row [0 ] for otk_row in otk_rows ],
1329+ keyvalues = {
1330+ "user_id" : user_id ,
1331+ "device_id" : device_id ,
1332+ "algorithm" : algorithm ,
1333+ },
1334+ )
1335+ self ._invalidate_cache_and_stream (
1336+ txn , self .count_e2e_one_time_keys , (user_id , device_id )
1337+ )
1338+
1339+ return [(f"{ algorithm } :{ key_id } " , key_json ) for key_id , key_json in otk_rows ]
1340+
1341+ @trace
1342+ def _claim_e2e_one_time_keys_bulk (
1343+ self ,
1344+ txn : LoggingTransaction ,
1345+ query_list : Iterable [Tuple [str , str , str , int ]],
1346+ ) -> List [Tuple [str , str , str , str , str ]]:
1347+ """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.
1348+
1349+ Args:
1350+ query_list: Collection of tuples (user_id, device_id, algorithm, count)
1351+ as passed to claim_e2e_one_time_keys.
1352+
1353+ Returns:
1354+ A list of tuples (user_id, device_id, algorithm, key_id, key_json)
1355+ for each OTK claimed.
1356+ """
1357+ sql = """
1358+ WITH claims(user_id, device_id, algorithm, claim_count) AS (
1359+ VALUES ?
1360+ ), ranked_keys AS (
1361+ SELECT
1362+ user_id, device_id, algorithm, key_id, claim_count,
1363+ ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
1364+ FROM e2e_one_time_keys_json
1365+ JOIN claims USING (user_id, device_id, algorithm)
1366+ )
1367+ DELETE FROM e2e_one_time_keys_json k
1368+ WHERE (user_id, device_id, algorithm, key_id) IN (
1369+ SELECT user_id, device_id, algorithm, key_id
1370+ FROM ranked_keys
1371+ WHERE r <= claim_count
1372+ )
1373+ RETURNING user_id, device_id, algorithm, key_id, key_json;
1374+ """
1375+ otk_rows = cast (
1376+ List [Tuple [str , str , str , str , str ]], txn .execute_values (sql , query_list )
1377+ )
1378+
1379+ seen_user_device : Set [Tuple [str , str ]] = set ()
1380+ for user_id , device_id , _ , _ , _ in otk_rows :
1381+ if (user_id , device_id ) in seen_user_device :
1382+ continue
1383+ seen_user_device .add ((user_id , device_id ))
1384+ self ._invalidate_cache_and_stream (
1385+ txn , self .count_e2e_one_time_keys , (user_id , device_id )
1386+ )
1387+
1388+ return otk_rows
1389+
13651390
13661391class EndToEndKeyStore (EndToEndKeyWorkerStore , SQLBaseStore ):
13671392 def __init__ (
0 commit comments