|
1 | | -from typing import Final |
| 1 | +import datetime |
| 2 | +import logging |
| 3 | +from asyncio import CancelledError |
| 4 | +from collections.abc import AsyncGenerator, Awaitable |
| 5 | +from typing import Any, Final |
2 | 6 |
|
| 7 | +from attr import dataclass |
3 | 8 | from models_library.api_schemas_rpc_async_jobs.async_jobs import ( |
4 | 9 | AsyncJobGet, |
5 | 10 | AsyncJobId, |
|
9 | 14 | ) |
10 | 15 | from models_library.rabbitmq_basic_types import RPCMethodName, RPCNamespace |
11 | 16 | from pydantic import NonNegativeInt, TypeAdapter |
| 17 | +from tenacity import ( |
| 18 | + AsyncRetrying, |
| 19 | + TryAgain, |
| 20 | + before_sleep_log, |
| 21 | + retry, |
| 22 | + retry_if_exception_type, |
| 23 | + stop_after_delay, |
| 24 | + wait_fixed, |
| 25 | + wait_random_exponential, |
| 26 | +) |
12 | 27 |
|
| 28 | +from ....long_running_tasks._constants import DEFAULT_POLL_INTERVAL_S |
| 29 | +from ....rabbitmq import RemoteMethodNotRegisteredError |
13 | 30 | from ... import RabbitMQRPCClient |
14 | 31 |
|
15 | 32 | _DEFAULT_TIMEOUT_S: Final[NonNegativeInt] = 30 |
16 | 33 |
|
17 | 34 | _RPC_METHOD_NAME_ADAPTER = TypeAdapter(RPCMethodName) |
18 | 35 |
|
| 36 | +_logger = logging.getLogger(__name__) |
| 37 | + |
19 | 38 |
|
20 | 39 | async def cancel( |
21 | 40 | rabbitmq_rpc_client: RabbitMQRPCClient, |
@@ -103,3 +122,110 @@ async def submit( |
103 | 122 | ) |
104 | 123 | assert isinstance(_result, AsyncJobGet) # nosec |
105 | 124 | return _result |
| 125 | + |
| 126 | + |
| 127 | +_DEFAULT_RPC_RETRY_POLICY: dict[str, Any] = { |
| 128 | + "retry": retry_if_exception_type(RemoteMethodNotRegisteredError), |
| 129 | + "wait": wait_random_exponential(max=20), |
| 130 | + "stop": stop_after_delay(60), |
| 131 | + "reraise": True, |
| 132 | + "before_sleep": before_sleep_log(_logger, logging.INFO), |
| 133 | +} |
| 134 | + |
| 135 | + |
| 136 | +@retry(**_DEFAULT_RPC_RETRY_POLICY) |
| 137 | +async def _wait_for_completion( |
| 138 | + rabbitmq_rpc_client: RabbitMQRPCClient, |
| 139 | + *, |
| 140 | + rpc_namespace: RPCNamespace, |
| 141 | + job_id: AsyncJobId, |
| 142 | + job_id_data: AsyncJobNameData, |
| 143 | + client_timeout: int, |
| 144 | +) -> AsyncGenerator[AsyncJobStatus, None]: |
| 145 | + try: |
| 146 | + async for attempt in AsyncRetrying( |
| 147 | + stop=stop_after_delay(client_timeout), |
| 148 | + reraise=True, |
| 149 | + retry=retry_if_exception_type(TryAgain), |
| 150 | + before_sleep=before_sleep_log(_logger, logging.DEBUG), |
| 151 | + wait=wait_fixed(DEFAULT_POLL_INTERVAL_S), |
| 152 | + ): |
| 153 | + with attempt: |
| 154 | + job_status = await status( |
| 155 | + rabbitmq_rpc_client, |
| 156 | + rpc_namespace=rpc_namespace, |
| 157 | + job_id=job_id, |
| 158 | + job_id_data=job_id_data, |
| 159 | + ) |
| 160 | + yield job_status |
| 161 | + if not job_status.done: |
| 162 | + msg = f"{job_status.job_id=}: '{job_status.progress=}'" |
| 163 | + raise TryAgain(msg) # noqa: TRY301 |
| 164 | + |
| 165 | + except TryAgain as exc: |
| 166 | + # this is a timeout |
| 167 | + msg = f"Long running task {job_id=}, calling to timed-out after {client_timeout} seconds" |
| 168 | + raise TimeoutError(msg) from exc |
| 169 | + |
| 170 | + |
| 171 | +@dataclass(frozen=True) |
| 172 | +class AsyncJobComposedResult: |
| 173 | + status: AsyncJobStatus |
| 174 | + _result: Awaitable[Any] | None = None |
| 175 | + |
| 176 | + @property |
| 177 | + def done(self) -> bool: |
| 178 | + return self.status.done |
| 179 | + |
| 180 | + async def result(self) -> Any: |
| 181 | + if not self._result: |
| 182 | + msg = "No result ready!" |
| 183 | + raise ValueError(msg) |
| 184 | + return await self._result |
| 185 | + |
| 186 | + |
| 187 | +async def submit_and_wait( |
| 188 | + rabbitmq_rpc_client: RabbitMQRPCClient, |
| 189 | + *, |
| 190 | + rpc_namespace: RPCNamespace, |
| 191 | + method_name: str, |
| 192 | + job_id_data: AsyncJobNameData, |
| 193 | + client_timeout: datetime.timedelta, |
| 194 | + **kwargs, |
| 195 | +) -> AsyncGenerator[AsyncJobComposedResult, None]: |
| 196 | + async_job_rpc_get = None |
| 197 | + try: |
| 198 | + async_job_rpc_get = await submit( |
| 199 | + rabbitmq_rpc_client, |
| 200 | + rpc_namespace=rpc_namespace, |
| 201 | + method_name=method_name, |
| 202 | + job_id_data=job_id_data, |
| 203 | + **kwargs, |
| 204 | + ) |
| 205 | + async for job_status in _wait_for_completion( |
| 206 | + rabbitmq_rpc_client, |
| 207 | + rpc_namespace=rpc_namespace, |
| 208 | + job_id=async_job_rpc_get.job_id, |
| 209 | + job_id_data=job_id_data, |
| 210 | + client_timeout=client_timeout, |
| 211 | + ): |
| 212 | + yield AsyncJobComposedResult(job_status) |
| 213 | + |
| 214 | + yield AsyncJobComposedResult( |
| 215 | + job_status, |
| 216 | + result( |
| 217 | + rabbitmq_rpc_client, |
| 218 | + rpc_namespace=rpc_namespace, |
| 219 | + job_id=async_job_rpc_get.job_id, |
| 220 | + job_id_data=job_id_data, |
| 221 | + ), |
| 222 | + ) |
| 223 | + except (TimeoutError, CancelledError): |
| 224 | + if async_job_rpc_get is not None: |
| 225 | + await cancel( |
| 226 | + rabbitmq_rpc_client, |
| 227 | + rpc_namespace=rpc_namespace, |
| 228 | + job_id=async_job_rpc_get.job_id, |
| 229 | + job_id_data=job_id_data, |
| 230 | + ) |
| 231 | + raise |
0 commit comments