11import asyncio
2+ import logging
23from collections .abc import AsyncGenerator , Coroutine
34from dataclasses import dataclass
45from typing import Any , Final , TypeAlias
56
67from aiohttp import ClientConnectionError , ClientSession
78from tenacity import TryAgain , retry
89from tenacity .asyncio import AsyncRetrying
10+ from tenacity .before_sleep import before_sleep_log
911from tenacity .retry import retry_if_exception_type
1012from tenacity .stop import stop_after_delay
1113from tenacity .wait import wait_random_exponential
1214from yarl import URL
1315
14- from ...rest_responses import unwrap_envelope
16+ from ...rest_responses import unwrap_envelope_if_required
1517from .. import status
1618from .server import TaskGet , TaskId , TaskProgress , TaskStatus
1719
20+ _logger = logging .getLogger (__name__ )
21+
1822RequestBody : TypeAlias = Any
1923
2024_MINUTE : Final [int ] = 60 # in secs
2529 "wait" : wait_random_exponential (max = 20 ),
2630 "stop" : stop_after_delay (60 ),
2731 "reraise" : True ,
32+ "before_sleep" : before_sleep_log (_logger , logging .INFO ),
2833}
2934
3035
3136@retry (** _DEFAULT_AIOHTTP_RETRY_POLICY )
3237async def _start (session : ClientSession , url : URL , json : RequestBody | None ) -> TaskGet :
3338 async with session .post (url , json = json ) as response :
3439 response .raise_for_status ()
35- data , error = unwrap_envelope (await response .json ())
36- assert not error # nosec
37- assert data is not None # nosec
40+ data = unwrap_envelope_if_required (await response .json ())
3841 return TaskGet .model_validate (data )
3942
4043
@@ -54,9 +57,7 @@ async def _wait_for_completion(
5457 with attempt :
5558 async with session .get (status_url ) as response :
5659 response .raise_for_status ()
57- data , error = unwrap_envelope (await response .json ())
58- assert not error # nosec
59- assert data is not None # nosec
60+ data = unwrap_envelope_if_required (await response .json ())
6061 task_status = TaskStatus .model_validate (data )
6162 yield task_status .task_progress
6263 if not task_status .done :
@@ -81,20 +82,14 @@ async def _task_result(session: ClientSession, result_url: URL) -> Any:
8182 async with session .get (result_url ) as response :
8283 response .raise_for_status ()
8384 if response .status != status .HTTP_204_NO_CONTENT :
84- data , error = unwrap_envelope (await response .json ())
85- assert not error # nosec
86- assert data # nosec
87- return data
85+ return unwrap_envelope_if_required (await response .json ())
8886 return None
8987
9088
9189@retry (** _DEFAULT_AIOHTTP_RETRY_POLICY )
9290async def _abort_task (session : ClientSession , abort_url : URL ) -> None :
9391 async with session .delete (abort_url ) as response :
9492 response .raise_for_status ()
95- data , error = unwrap_envelope (await response .json ())
96- assert not error # nosec
97- assert not data # nosec
9893
9994
10095@dataclass (frozen = True )
0 commit comments