2121 AsyncGenerator ,
2222 Awaitable ,
2323 Callable ,
24- Coroutine ,
2524 Optional ,
26- TypeVar ,
27- Protocol ,
2825)
2926
3027from google .api_core import exceptions , gapic_v1
4744# Types needed only for Type Hints
4845if TYPE_CHECKING : # pragma: NO COVER
4946 import datetime
47+ from typing_extensions import TypeVar , ParamSpec , Concatenate
5048
5149 from google .cloud .firestore_v1 .async_stream_generator import AsyncStreamGenerator
5250 from google .cloud .firestore_v1 .base_document import DocumentSnapshot
5351 from google .cloud .firestore_v1 .query_profile import ExplainOptions
5452
55-
56- T = TypeVar ( "T" , bound = Callable [..., Any ] )
53+ T = TypeVar ( "T" )
54+ P = ParamSpec ( "P" )
5755
5856
5957class AsyncTransaction (async_batch .AsyncWriteBatch , BaseTransaction ):
@@ -267,13 +265,13 @@ class _AsyncTransactional(_BaseTransactional):
267265 """
268266
269267 def __init__ (
270- self , to_wrap : Callable [... , Awaitable [T ]]
268+ self , to_wrap : Callable [Concatenate [ AsyncTransaction , P ] , Awaitable [T ]]
271269 ) -> None :
272270 super (_AsyncTransactional , self ).__init__ (to_wrap )
273271
274272 async def _pre_commit (
275- self , transaction : AsyncTransaction , * args : Any , ** kwargs : Any
276- ) -> Coroutine :
273+ self , transaction : AsyncTransaction , * args : P . args , ** kwargs : P . kwargs
274+ ) -> T :
277275 """Begin transaction and call the wrapped coroutine.
278276
279277 Args:
@@ -301,9 +299,7 @@ async def _pre_commit(
301299 self .retry_id = self .current_id
302300 return await self .to_wrap (transaction , * args , ** kwargs )
303301
304- async def __call__ (
305- self , transaction : AsyncTransaction , * args : Any , ** kwargs : Any
306- ) -> T :
302+ async def __call__ (self , transaction , * args : P .args , ** kwargs : P .kwargs ) -> T :
307303 """Execute the wrapped callable within a transaction.
308304
309305 Args:
@@ -330,7 +326,7 @@ async def __call__(
330326
331327 try :
332328 for attempt in range (transaction ._max_attempts ):
333- result = await self ._pre_commit (transaction , * args , ** kwargs )
329+ result : T = await self ._pre_commit (transaction , * args , ** kwargs )
334330 try :
335331 await transaction ._commit ()
336332 return result
@@ -354,12 +350,9 @@ async def __call__(
354350 raise
355351
356352
357- class WithAsyncTransaction (Protocol [T ]):
358- def __call__ (self , transaction : AsyncTransaction , * args : Any , ** kwargs : Any ) -> Awaitable [T ]: ...
359-
360353def async_transactional (
361- to_wrap : Callable [... , Awaitable [T ]]
362- ) -> WithAsyncTransaction [ T ]:
354+ to_wrap : Callable [Concatenate [ AsyncTransaction , P ] , Awaitable [T ]]
355+ ) -> Callable [ Concatenate [ AsyncTransaction , P ], Awaitable [ T ] ]:
363356 """Decorate a callable so that it runs in a transaction.
364357
365358 Args:
0 commit comments