44import itertools
55import logging
66import secrets
7- from typing import Dict , FrozenSet , List , NamedTuple , Optional , Set , Tuple , Type , TypeVar , Union , cast
7+ from typing import Awaitable , Dict , FrozenSet , List , NamedTuple , Optional , Set , Tuple , Type , TypeVar , Union , \
8+ cast
89from typing_extensions import assert_never
910
1011try :
@@ -247,6 +248,8 @@ def __init__(self) -> None:
247248 self .__pre_key_refill_threshold : int
248249 self .__identity_key_pair : IdentityKeyPair
249250 self .__synchronizing : bool
251+ self .__async_framework : AsyncFramework
252+ self .__signed_pre_key_management_task : Awaitable [None ]
250253
251254 @classmethod
252255 async def create (
@@ -349,6 +352,7 @@ async def create(
349352 self .__pre_key_refill_threshold = pre_key_refill_threshold
350353 self .__identity_key_pair = await IdentityKeyPair .get (storage )
351354 self .__synchronizing = True
355+ self .__async_framework = async_framework
352356
353357 try :
354358 self .__own_device_id = (await self .__storage .load_primitive ("/own_device_id" , int )).from_just ()
@@ -473,16 +477,15 @@ async def create(
473477 await self .purge_backend (namespace )
474478
475479 # Start signed pre key rotation management "in the background"
480+ signed_pre_key_management_coro = self .__manage_signed_pre_key_rotation (
481+ signed_pre_key_rotation_period ,
482+ async_framework
483+ )
484+
476485 if async_framework is AsyncFramework .ASYNCIO :
477- asyncio .ensure_future (self .__manage_signed_pre_key_rotation (
478- signed_pre_key_rotation_period ,
479- async_framework
480- ))
486+ self .__signed_pre_key_management_task = asyncio .ensure_future (signed_pre_key_management_coro )
481487 elif async_framework is AsyncFramework .TWISTED :
482- defer .ensureDeferred (self .__manage_signed_pre_key_rotation (
483- signed_pre_key_rotation_period ,
484- async_framework
485- ))
488+ self .__signed_pre_key_management_task = defer .ensureDeferred (signed_pre_key_management_coro )
486489 else :
487490 assert_never (async_framework )
488491
@@ -492,6 +495,30 @@ async def create(
492495
493496 return self
494497
498+ async def shutdown (self ) -> None :
499+ """
500+ Gracefully quit internal tasks.
501+ """
502+
503+ if self .__async_framework is AsyncFramework .ASYNCIO :
504+ asyncio_task = cast (asyncio .Future [None ], self .__signed_pre_key_management_task )
505+ try :
506+ asyncio_task .cancel ()
507+ await asyncio_task
508+ except asyncio .CancelledError :
509+ pass
510+
511+ elif self .__async_framework is AsyncFramework .TWISTED :
512+ twisted_task = cast (defer .Deferred [None ], self .__signed_pre_key_management_task )
513+ try :
514+ twisted_task .cancel ()
515+ await twisted_task
516+ except defer .CancelledError :
517+ pass
518+
519+ else :
520+ assert_never (self .__async_framework )
521+
495522 async def purge_backend (self , namespace : str ) -> None :
496523 """
497524 Purge a backend, removing both the online data (bundle, device list entry) and the offline data that
0 commit comments