Skip to content

Commit b426579

Browse files
committed
moved back to ParamSpec implementation
1 parent 0689151 commit b426579

File tree

1 file changed

+10
-17
lines changed

1 file changed

+10
-17
lines changed

google/cloud/firestore_v1/async_transaction.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@
2121
AsyncGenerator,
2222
Awaitable,
2323
Callable,
24-
Coroutine,
2524
Optional,
26-
TypeVar,
27-
Protocol,
2825
)
2926

3027
from google.api_core import exceptions, gapic_v1
@@ -47,13 +44,14 @@
4744
# Types needed only for Type Hints
4845
if TYPE_CHECKING: # pragma: NO COVER
4946
import datetime
47+
from typing_extensions import TypeVar, ParamSpec, Concatenate
5048

5149
from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
5250
from google.cloud.firestore_v1.base_document import DocumentSnapshot
5351
from google.cloud.firestore_v1.query_profile import ExplainOptions
5452

55-
56-
T = TypeVar("T", bound=Callable[..., Any])
53+
T = TypeVar("T")
54+
P = ParamSpec("P")
5755

5856

5957
class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction):
@@ -267,13 +265,13 @@ class _AsyncTransactional(_BaseTransactional):
267265
"""
268266

269267
def __init__(
270-
self, to_wrap: Callable[..., Awaitable[T]]
268+
self, to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]
271269
) -> None:
272270
super(_AsyncTransactional, self).__init__(to_wrap)
273271

274272
async def _pre_commit(
275-
self, transaction: AsyncTransaction, *args: Any, **kwargs: Any
276-
) -> Coroutine:
273+
self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs
274+
) -> T:
277275
"""Begin transaction and call the wrapped coroutine.
278276
279277
Args:
@@ -301,9 +299,7 @@ async def _pre_commit(
301299
self.retry_id = self.current_id
302300
return await self.to_wrap(transaction, *args, **kwargs)
303301

304-
async def __call__(
305-
self, transaction: AsyncTransaction, *args: Any, **kwargs: Any
306-
) -> T:
302+
async def __call__(self, transaction, *args: P.args, **kwargs: P.kwargs) -> T:
307303
"""Execute the wrapped callable within a transaction.
308304
309305
Args:
@@ -330,7 +326,7 @@ async def __call__(
330326

331327
try:
332328
for attempt in range(transaction._max_attempts):
333-
result = await self._pre_commit(transaction, *args, **kwargs)
329+
result: T = await self._pre_commit(transaction, *args, **kwargs)
334330
try:
335331
await transaction._commit()
336332
return result
@@ -354,12 +350,9 @@ async def __call__(
354350
raise
355351

356352

357-
class WithAsyncTransaction(Protocol[T]):
358-
def __call__(self, transaction: AsyncTransaction, *args: Any, **kwargs: Any) -> Awaitable[T]: ...
359-
360353
def async_transactional(
361-
to_wrap: Callable[..., Awaitable[T]]
362-
) -> WithAsyncTransaction[T]:
354+
to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]
355+
) -> Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]:
363356
"""Decorate a callable so that it runs in a transaction.
364357
365358
Args:

0 commit comments

Comments
 (0)