1212 ClientSession ,
1313 ClientTimeout ,
1414)
15+ from multidict import MultiDict
1516from yarl import URL
1617
1718from .const import DEFAULT_TIMEOUT , ResponseType
2728 SupervisorTimeoutError ,
2829)
2930from .models .base import Response , ResultType
31+ from .utils .aiohttp import ChunkAsyncStreamIterator
3032
3133VERSION = metadata .version (__package__ )
3234
@@ -53,12 +55,33 @@ class _SupervisorClient:
5355 session : ClientSession | None = None
5456 _close_session : bool = field (default = False , init = False )
5557
58+ async def _raise_on_status (self , response : ClientResponse ) -> None :
59+ """Raise appropriate exception on status."""
60+ if response .status >= HTTPStatus .BAD_REQUEST .value :
61+ exc_type : type [SupervisorError ] = SupervisorError
62+ match response .status :
63+ case HTTPStatus .BAD_REQUEST :
64+ exc_type = SupervisorBadRequestError
65+ case HTTPStatus .UNAUTHORIZED :
66+ exc_type = SupervisorAuthenticationError
67+ case HTTPStatus .FORBIDDEN :
68+ exc_type = SupervisorForbiddenError
69+ case HTTPStatus .NOT_FOUND :
70+ exc_type = SupervisorNotFoundError
71+ case HTTPStatus .SERVICE_UNAVAILABLE :
72+ exc_type = SupervisorServiceUnavailableError
73+
74+ if is_json (response ):
75+ result = Response .from_json (await response .text ())
76+ raise exc_type (result .message , result .job_id )
77+ raise exc_type ()
78+
5679 async def _request (
5780 self ,
5881 method : HTTPMethod ,
5982 uri : str ,
6083 * ,
61- params : dict [str , str ] | None ,
84+ params : dict [str , str ] | MultiDict [ str ] | None ,
6285 response_type : ResponseType ,
6386 json : dict [str , Any ] | None = None ,
6487 data : Any = None ,
@@ -94,42 +117,28 @@ async def _request(
94117 self ._close_session = True
95118
96119 try :
97- async with self .session .request (
120+ response = await self .session .request (
98121 method .value ,
99122 url ,
100123 timeout = timeout ,
101124 headers = headers ,
102125 params = params ,
103126 json = json ,
104127 data = data ,
105- ) as response :
106- if response .status >= HTTPStatus .BAD_REQUEST .value :
107- exc_type : type [SupervisorError ] = SupervisorError
108- match response .status :
109- case HTTPStatus .BAD_REQUEST :
110- exc_type = SupervisorBadRequestError
111- case HTTPStatus .UNAUTHORIZED :
112- exc_type = SupervisorAuthenticationError
113- case HTTPStatus .FORBIDDEN :
114- exc_type = SupervisorForbiddenError
115- case HTTPStatus .NOT_FOUND :
116- exc_type = SupervisorNotFoundError
117- case HTTPStatus .SERVICE_UNAVAILABLE :
118- exc_type = SupervisorServiceUnavailableError
119-
120- if is_json (response ):
121- result = Response .from_json (await response .text ())
122- raise exc_type (result .message , result .job_id )
123- raise exc_type ()
124-
125- match response_type :
126- case ResponseType .JSON :
127- is_json (response , raise_on_fail = True )
128- return Response .from_json (await response .text ())
129- case ResponseType .TEXT :
130- return Response (ResultType .OK , await response .text ())
131- case _:
132- return Response (ResultType .OK )
128+ )
129+ await self ._raise_on_status (response )
130+ match response_type :
131+ case ResponseType .JSON :
132+ is_json (response , raise_on_fail = True )
133+ return Response .from_json (await response .text ())
134+ case ResponseType .TEXT :
135+ return Response (ResultType .OK , await response .text ())
136+ case ResponseType .STREAM :
137+ return Response (
138+ ResultType .OK , ChunkAsyncStreamIterator (response .content )
139+ )
140+ case _:
141+ return Response (ResultType .OK )
133142
134143 except (UnicodeDecodeError , ClientResponseError ) as err :
135144 raise SupervisorResponseError (
@@ -146,7 +155,7 @@ async def get(
146155 self ,
147156 uri : str ,
148157 * ,
149- params : dict [str , str ] | None = None ,
158+ params : dict [str , str ] | MultiDict [ str ] | None = None ,
150159 response_type : ResponseType = ResponseType .JSON ,
151160 timeout : ClientTimeout | None = DEFAULT_TIMEOUT ,
152161 ) -> Response :
@@ -163,7 +172,7 @@ async def post(
163172 self ,
164173 uri : str ,
165174 * ,
166- params : dict [str , str ] | None = None ,
175+ params : dict [str , str ] | MultiDict [ str ] | None = None ,
167176 response_type : ResponseType = ResponseType .NONE ,
168177 json : dict [str , Any ] | None = None ,
169178 data : Any = None ,
@@ -184,7 +193,7 @@ async def put(
184193 self ,
185194 uri : str ,
186195 * ,
187- params : dict [str , str ] | None = None ,
196+ params : dict [str , str ] | MultiDict [ str ] | None = None ,
188197 json : dict [str , Any ] | None = None ,
189198 timeout : ClientTimeout | None = DEFAULT_TIMEOUT ,
190199 ) -> Response :
@@ -202,7 +211,7 @@ async def delete(
202211 self ,
203212 uri : str ,
204213 * ,
205- params : dict [str , str ] | None = None ,
214+ params : dict [str , str ] | MultiDict [ str ] | None = None ,
206215 timeout : ClientTimeout | None = DEFAULT_TIMEOUT ,
207216 ) -> Response :
208217 """Handle a DELETE request to Supervisor."""
0 commit comments