Skip to content

Commit d612b3d

Browse files
committed
ensure cancel_wait_task properly re-raise
1 parent 81ab701 commit d612b3d

File tree

1 file changed

+14
-19
lines changed

1 file changed

+14
-19
lines changed

packages/common-library/src/common_library/async_tools.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from inspect import isawaitable
99
from typing import Any, ParamSpec, TypeVar, overload
1010

11+
from servicelib.logging_utils import log_context
12+
1113
_logger = logging.getLogger(__name__)
1214

1315
R = TypeVar("R")
@@ -90,33 +92,26 @@ async def cancel_wait_task(
9092
CancelledError: raised ONLY if owner is being cancelled.
9193
"""
9294

93-
cancelling = task.cancel()
94-
if not cancelling:
95-
return # task could not be cancelled (either already done or something else)
96-
97-
assert task.cancelling() # nosec
98-
assert not task.cancelled() # nosec
99-
assert not task.done() # nosec
100-
95+
task.cancel()
10196
try:
102-
103-
await asyncio.shield(
104-
# NOTE shield ensures that cancellation of the caller function won't stop you
105-
# from observing the cancellation/finalization of task.
106-
asyncio.wait_for(task, timeout=max_delay)
107-
)
97+
with log_context(
98+
_logger, logging.DEBUG, f"Cancelling task {task.get_name()!r}"
99+
):
100+
await asyncio.shield(
101+
# NOTE shield ensures that cancellation of the caller function won't stop you
102+
# from observing the cancellation/finalization of task.
103+
asyncio.wait_for(task, timeout=max_delay)
104+
)
108105

109106
except asyncio.CancelledError:
110107
assert task.done() # nosec
111-
if asyncio.current_task().cancelling() > 0:
108+
current_task = asyncio.current_task()
109+
assert current_task is not None # nosec
110+
if current_task.cancelling() > 0:
112111
# owner function is being cancelled -> propagate cancellation
113112
raise
114113

115114
# else: task cancellation is complete, we can safely ignore it
116-
_logger.debug(
117-
"Task %s cancellation is complete",
118-
task.get_name(),
119-
)
120115

121116

122117
def delayed_start(

0 commit comments

Comments
 (0)