@@ -755,81 +755,145 @@ async def claim_e2e_one_time_keys(
755755 """
756756
757757 @trace
758- def _claim_e2e_one_time_keys (txn ):
759- sql = (
760- "SELECT key_id, key_json FROM e2e_one_time_keys_json"
761- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
762- " LIMIT 1"
758+ def _claim_e2e_one_time_key_simple (
759+ txn , user_id : str , device_id : str , algorithm : str
760+ ) -> Optional [Tuple [str , str ]]:
761+ """Claim OTK for device for DBs that don't support RETURNING.
762+
763+ Returns:
764+ A tuple of key name (algorithm + key ID) and key JSON, if an
765+ OTK was found.
766+ """
767+
768+ sql = """
769+ SELECT key_id, key_json FROM e2e_one_time_keys_json
770+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
771+ LIMIT 1
772+ """
773+
774+ txn .execute (sql , (user_id , device_id , algorithm ))
775+ otk_row = txn .fetchone ()
776+ if otk_row is None :
777+ return None
778+
779+ key_id , key_json = otk_row
780+
781+ self .db_pool .simple_delete_one_txn (
782+ txn ,
783+ table = "e2e_one_time_keys_json" ,
784+ keyvalues = {
785+ "user_id" : user_id ,
786+ "device_id" : device_id ,
787+ "algorithm" : algorithm ,
788+ "key_id" : key_id ,
789+ },
763790 )
764- fallback_sql = (
765- "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
766- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
767- " LIMIT 1"
791+ self ._invalidate_cache_and_stream (
792+ txn , self .count_e2e_one_time_keys , (user_id , device_id )
768793 )
769- result = {}
770- delete = []
771- used_fallbacks = []
772- for user_id , device_id , algorithm in query_list :
773- user_result = result .setdefault (user_id , {})
774- device_result = user_result .setdefault (device_id , {})
775- txn .execute (sql , (user_id , device_id , algorithm ))
776- otk_row = txn .fetchone ()
777- if otk_row is not None :
778- key_id , key_json = otk_row
779- device_result [algorithm + ":" + key_id ] = key_json
780- delete .append ((user_id , device_id , algorithm , key_id ))
781- else :
782- # no one-time key available, so see if there's a fallback
783- # key
784- txn .execute (fallback_sql , (user_id , device_id , algorithm ))
785- fallback_row = txn .fetchone ()
786- if fallback_row is not None :
787- key_id , key_json , used = fallback_row
788- device_result [algorithm + ":" + key_id ] = key_json
789- if not used :
790- used_fallbacks .append (
791- (user_id , device_id , algorithm , key_id )
792- )
793-
794- # drop any one-time keys that were claimed
795- sql = (
796- "DELETE FROM e2e_one_time_keys_json"
797- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
798- " AND key_id = ?"
794+
795+ return f"{ algorithm } :{ key_id } " , key_json
796+
797+ @trace
798+ def _claim_e2e_one_time_key_returning (
799+ txn , user_id : str , device_id : str , algorithm : str
800+ ) -> Optional [Tuple [str , str ]]:
801+ """Claim OTK for device for DBs that support RETURNING.
802+
803+ Returns:
804+ A tuple of key name (algorithm + key ID) and key JSON, if an
805+ OTK was found.
806+ """
807+
808+ # We can use RETURNING to do the fetch and DELETE in once step.
809+ sql = """
810+ DELETE FROM e2e_one_time_keys_json
811+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
812+ AND key_id IN (
813+ SELECT key_id FROM e2e_one_time_keys_json
814+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
815+ LIMIT 1
816+ )
817+ RETURNING key_id, key_json
818+ """
819+
820+ txn .execute (
821+ sql , (user_id , device_id , algorithm , user_id , device_id , algorithm )
799822 )
800- for user_id , device_id , algorithm , key_id in delete :
801- log_kv (
802- {
803- "message" : "Executing claim e2e_one_time_keys transaction on database."
804- }
805- )
806- txn .execute (sql , (user_id , device_id , algorithm , key_id ))
807- log_kv ({"message" : "finished executing and invalidating cache" })
808- self ._invalidate_cache_and_stream (
809- txn , self .count_e2e_one_time_keys , (user_id , device_id )
823+ otk_row = txn .fetchone ()
824+ if otk_row is None :
825+ return None
826+
827+ key_id , key_json = otk_row
828+ return f"{ algorithm } :{ key_id } " , key_json
829+
830+ results = {}
831+ for user_id , device_id , algorithm in query_list :
832+ if self .database_engine .supports_returning :
833+ # If we support RETURNING clause we can use a single query that
834+ # allows us to use autocommit mode.
835+ _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
836+ db_autocommit = True
837+ else :
838+ _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
839+ db_autocommit = False
840+
841+ row = await self .db_pool .runInteraction (
842+ "claim_e2e_one_time_keys" ,
843+ _claim_e2e_one_time_key ,
844+ user_id ,
845+ device_id ,
846+ algorithm ,
847+ db_autocommit = db_autocommit ,
848+ )
849+ if row :
850+ device_results = results .setdefault (user_id , {}).setdefault (
851+ device_id , {}
810852 )
811- # mark fallback keys as used
812- for user_id , device_id , algorithm , key_id in used_fallbacks :
813- self .db_pool .simple_update_txn (
814- txn ,
815- "e2e_fallback_keys_json" ,
816- {
853+ device_results [row [0 ]] = row [1 ]
854+ continue
855+
856+ # No one-time key available, so see if there's a fallback
857+ # key
858+ row = await self .db_pool .simple_select_one (
859+ table = "e2e_fallback_keys_json" ,
860+ keyvalues = {
861+ "user_id" : user_id ,
862+ "device_id" : device_id ,
863+ "algorithm" : algorithm ,
864+ },
865+ retcols = ("key_id" , "key_json" , "used" ),
866+ desc = "_get_fallback_key" ,
867+ allow_none = True ,
868+ )
869+ if row is None :
870+ continue
871+
872+ key_id = row ["key_id" ]
873+ key_json = row ["key_json" ]
874+ used = row ["used" ]
875+
876+ # Mark fallback key as used if not already.
877+ if not used :
878+ await self .db_pool .simple_update_one (
879+ table = "e2e_fallback_keys_json" ,
880+ keyvalues = {
817881 "user_id" : user_id ,
818882 "device_id" : device_id ,
819883 "algorithm" : algorithm ,
820884 "key_id" : key_id ,
821885 },
822- {"used" : True },
886+ updatevalues = {"used" : True },
887+ desc = "_get_fallback_key_set_used" ,
823888 )
824- self ._invalidate_cache_and_stream (
825- txn , self . get_e2e_unused_fallback_key_types , (user_id , device_id )
889+ await self .invalidate_cache_and_stream (
890+ " get_e2e_unused_fallback_key_types" , (user_id , device_id )
826891 )
827892
828- return result
893+ device_results = results .setdefault (user_id , {}).setdefault (device_id , {})
894+ device_results [f"{ algorithm } :{ key_id } " ] = key_json
829895
830- return await self .db_pool .runInteraction (
831- "claim_e2e_one_time_keys" , _claim_e2e_one_time_keys
832- )
896+ return results
833897
834898
835899class EndToEndKeyStore (EndToEndKeyWorkerStore , SQLBaseStore ):
0 commit comments