@@ -563,7 +563,9 @@ async def on_federation_query_client_keys(
563563 return ret
564564
565565 async def claim_local_one_time_keys (
566- self , local_query : List [Tuple [str , str , str ]]
566+ self ,
567+ local_query : List [Tuple [str , str , str ]],
568+ always_include_fallback_keys : bool ,
567569 ) -> Iterable [Dict [str , Dict [str , Dict [str , JsonDict ]]]]:
568570 """Claim one time keys for local users.
569571
@@ -573,6 +575,7 @@ async def claim_local_one_time_keys(
573575
574576 Args:
575577 local_query: An iterable of tuples of (user ID, device ID, algorithm).
578+ always_include_fallback_keys: True to always include fallback keys.
576579
577580 Returns:
578581 An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
@@ -583,24 +586,73 @@ async def claim_local_one_time_keys(
583586 # If the application services have not provided any keys via the C-S
584587 # API, query it directly for one-time keys.
585588 if self ._query_appservices_for_otks :
589+ # TODO Should this query for fallback keys of uploaded OTKs if
590+ # always_include_fallback_keys is True? The MSC is ambiguous.
586591 (
587592 appservice_results ,
588593 not_found ,
589594 ) = await self ._appservice_handler .claim_e2e_one_time_keys (not_found )
590595 else :
591- appservice_results = []
596+ appservice_results = {}
597+
598+ # Calculate which user ID / device ID / algorithm tuples to get fallback
599+ # keys for. This can be either only missing results *or* all results
600+ # (which don't already have a fallback key).
601+ if always_include_fallback_keys :
602+ # Build the fallback query as any part of the original query where
603+ # the appservice didn't respond with a fallback key.
604+ fallback_query = []
605+
606+ # Iterate each item in the original query and search the results
607+ # from the appservice for that user ID / device ID. If it is found,
608+ # check if any of the keys match the requested algorithm & are a
609+ # fallback key.
610+ for user_id , device_id , algorithm in local_query :
611+ # Check if the appservice responded for this query.
612+ as_result = appservice_results .get (user_id , {}).get (device_id , {})
613+ found_otk = False
614+ for key_id , key_json in as_result .items ():
615+ if key_id .startswith (f"{ algorithm } :" ):
616+ # A OTK or fallback key was found for this query.
617+ found_otk = True
618+ # A fallback key was found for this query, no need to
619+ # query further.
620+ if key_json .get ("fallback" , False ):
621+ break
622+
623+ else :
624+ # No fallback key was found from appservices, query for it.
625+ # Only mark the fallback key as used if no OTK was found
626+ # (from either the database or appservices).
627+ mark_as_used = not found_otk and not any (
628+ key_id .startswith (f"{ algorithm } :" )
629+ for key_id in otk_results .get (user_id , {})
630+ .get (device_id , {})
631+ .keys ()
632+ )
633+ fallback_query .append ((user_id , device_id , algorithm , mark_as_used ))
634+
635+ else :
636+ # All fallback keys get marked as used.
637+ fallback_query = [
638+ (user_id , device_id , algorithm , True )
639+ for user_id , device_id , algorithm in not_found
640+ ]
592641
593642 # For each user that does not have a one-time keys available, see if
594643 # there is a fallback key.
595- fallback_results = await self .store .claim_e2e_fallback_keys (not_found )
644+ fallback_results = await self .store .claim_e2e_fallback_keys (fallback_query )
596645
597646 # Return the results in order, each item from the input query should
598647 # only appear once in the combined list.
599- return (otk_results , * appservice_results , fallback_results )
648+ return (otk_results , appservice_results , fallback_results )
600649
601650 @trace
602651 async def claim_one_time_keys (
603- self , query : Dict [str , Dict [str , Dict [str , str ]]], timeout : Optional [int ]
652+ self ,
653+ query : Dict [str , Dict [str , Dict [str , str ]]],
654+ timeout : Optional [int ],
655+ always_include_fallback_keys : bool ,
604656 ) -> JsonDict :
605657 local_query : List [Tuple [str , str , str ]] = []
606658 remote_queries : Dict [str , Dict [str , Dict [str , str ]]] = {}
@@ -617,15 +669,19 @@ async def claim_one_time_keys(
617669 set_tag ("local_key_query" , str (local_query ))
618670 set_tag ("remote_key_query" , str (remote_queries ))
619671
620- results = await self .claim_local_one_time_keys (local_query )
672+ results = await self .claim_local_one_time_keys (
673+ local_query , always_include_fallback_keys
674+ )
621675
622676 # A map of user ID -> device ID -> key ID -> key.
623677 json_result : Dict [str , Dict [str , Dict [str , JsonDict ]]] = {}
624678 for result in results :
625679 for user_id , device_keys in result .items ():
626680 for device_id , keys in device_keys .items ():
627681 for key_id , key in keys .items ():
628- json_result .setdefault (user_id , {})[device_id ] = {key_id : key }
682+ json_result .setdefault (user_id , {}).setdefault (
683+ device_id , {}
684+ ).update ({key_id : key })
629685
630686 # Remote failures.
631687 failures : Dict [str , JsonDict ] = {}
0 commit comments