Skip to content

Commit 5828fd5

Browse files
committed
Migrate to using aiohttp
1 parent 9be4c08 commit 5828fd5

File tree

2 files changed

+511
-0
lines changed

2 files changed

+511
-0
lines changed

onvif/zeep_aiohttp.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
"""AIOHTTP transport for zeep."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import logging
7+
from typing import TYPE_CHECKING, Any
8+
9+
from zeep.transports import Transport
10+
from zeep.utils import get_version
11+
from zeep.wsdl.utils import etree_to_string
12+
13+
import aiohttp
14+
import httpx
15+
from aiohttp import ClientResponse, ClientSession, ClientTimeout, TCPConnector
16+
from requests import Response
17+
18+
if TYPE_CHECKING:
19+
from lxml.etree import _Element
20+
21+
_LOGGER = logging.getLogger(__name__)
22+
23+
24+
class AIOHTTPTransport(Transport):
25+
"""Async transport using aiohttp."""
26+
27+
def __init__(
28+
self,
29+
session: ClientSession | None = None,
30+
timeout: float = 300,
31+
operation_timeout: float | None = None,
32+
verify_ssl: bool = True,
33+
proxy: str | None = None,
34+
) -> None:
35+
"""
36+
Initialize the transport.
37+
38+
Args:
39+
session: The aiohttp ClientSession to use
40+
timeout: The default timeout for requests in seconds
41+
operation_timeout: The default timeout for operations in seconds
42+
verify_ssl: Whether to verify SSL certificates
43+
proxy: Proxy URL to use
44+
45+
"""
46+
super().__init__(
47+
cache=None,
48+
timeout=timeout,
49+
operation_timeout=operation_timeout,
50+
)
51+
52+
# Override parent's session with aiohttp session
53+
self.session = session
54+
self.timeout = timeout
55+
self.operation_timeout = operation_timeout
56+
self.verify_ssl = verify_ssl
57+
self.proxy = proxy
58+
self._close_session = session is None
59+
60+
async def __aenter__(self) -> AIOHTTPTransport:
61+
"""Enter async context."""
62+
if self.session is None:
63+
connector = TCPConnector(ssl=self.verify_ssl)
64+
timeout = ClientTimeout(total=self.timeout)
65+
self.session = ClientSession(connector=connector, timeout=timeout)
66+
return self
67+
68+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
69+
"""Exit async context."""
70+
if self._close_session and self.session:
71+
await self.session.close()
72+
self.session = None
73+
74+
async def aclose(self) -> None:
75+
"""Close the transport session."""
76+
if self.session:
77+
await self.session.close()
78+
79+
def _aiohttp_to_httpx_response(
80+
self, aiohttp_response: ClientResponse, content: bytes
81+
) -> httpx.Response:
82+
"""Convert aiohttp ClientResponse to httpx Response."""
83+
# Create httpx Response with the content
84+
httpx_response = httpx.Response(
85+
status_code=aiohttp_response.status,
86+
headers=httpx.Headers(aiohttp_response.headers),
87+
content=content,
88+
request=httpx.Request(
89+
method=aiohttp_response.method,
90+
url=str(aiohttp_response.url),
91+
),
92+
)
93+
94+
# Add encoding if available
95+
if aiohttp_response.charset:
96+
httpx_response._encoding = aiohttp_response.charset
97+
98+
# Store cookies if any
99+
if aiohttp_response.cookies:
100+
for cookie in aiohttp_response.cookies.values():
101+
httpx_response.cookies.set(
102+
cookie.key,
103+
cookie.value,
104+
domain=cookie.get("domain"),
105+
path=cookie.get("path"),
106+
)
107+
108+
return httpx_response
109+
110+
def _aiohttp_to_requests_response(
111+
self, aiohttp_response: ClientResponse, content: bytes
112+
) -> Response:
113+
"""Convert aiohttp ClientResponse directly to requests Response."""
114+
new = Response()
115+
new._content = content
116+
new.status_code = aiohttp_response.status
117+
new.headers = dict(aiohttp_response.headers)
118+
new.cookies = aiohttp_response.cookies
119+
new.encoding = aiohttp_response.charset
120+
return new
121+
122+
async def post(
123+
self, address: str, message: str, headers: dict[str, str]
124+
) -> httpx.Response:
125+
"""
126+
Perform async POST request.
127+
128+
Args:
129+
address: The URL to send the request to
130+
message: The message to send
131+
headers: HTTP headers to include
132+
133+
Returns:
134+
The httpx response object
135+
136+
"""
137+
if self.session is None:
138+
async with self:
139+
return await self._post(address, message, headers)
140+
return await self._post(address, message, headers)
141+
142+
async def _post(
143+
self, address: str, message: str, headers: dict[str, str]
144+
) -> httpx.Response:
145+
"""Internal POST implementation."""
146+
_LOGGER.debug("HTTP Post to %s:\n%s", address, message)
147+
148+
# Set default headers
149+
headers = headers or {}
150+
headers.setdefault("User-Agent", f"Zeep/{get_version()}")
151+
headers.setdefault("Content-Type", 'text/xml; charset="utf-8"')
152+
153+
# Determine timeout
154+
timeout = self.operation_timeout or self.timeout
155+
if timeout:
156+
client_timeout = ClientTimeout(total=timeout)
157+
else:
158+
client_timeout = None
159+
160+
# Handle both str and bytes
161+
if isinstance(message, str):
162+
data = message.encode("utf-8")
163+
else:
164+
data = message
165+
166+
try:
167+
response = await self.session.post(
168+
address,
169+
data=data,
170+
headers=headers,
171+
proxy=self.proxy,
172+
timeout=client_timeout,
173+
)
174+
response.raise_for_status()
175+
176+
# Read the content to log it
177+
content = await response.read()
178+
_LOGGER.debug(
179+
"HTTP Response from %s (status: %d):\n%s",
180+
address,
181+
response.status,
182+
content.decode("utf-8", errors="replace"),
183+
)
184+
185+
# Convert to httpx Response
186+
return self._aiohttp_to_httpx_response(response, content)
187+
188+
except TimeoutError as exc:
189+
raise TimeoutError(f"Request to {address} timed out") from exc
190+
except aiohttp.ClientError as exc:
191+
raise ConnectionError(f"Error connecting to {address}: {exc}") from exc
192+
193+
async def post_xml(
194+
self, address: str, envelope: _Element, headers: dict[str, str]
195+
) -> Response:
196+
"""
197+
Post XML envelope and return parsed response.
198+
199+
Args:
200+
address: The URL to send the request to
201+
envelope: The XML envelope to send
202+
headers: HTTP headers to include
203+
204+
Returns:
205+
A Response object compatible with zeep
206+
207+
"""
208+
message = etree_to_string(envelope)
209+
response = await self.post(address, message, headers)
210+
return self._httpx_to_requests_response(response)
211+
212+
async def get(
213+
self,
214+
address: str,
215+
params: dict[str, Any] | None = None,
216+
headers: dict[str, str] | None = None,
217+
) -> Response:
218+
"""
219+
Perform async GET request.
220+
221+
Args:
222+
address: The URL to send the request to
223+
params: Query parameters
224+
headers: HTTP headers to include
225+
226+
Returns:
227+
A Response object compatible with zeep
228+
229+
"""
230+
if self.session is None:
231+
async with self:
232+
return await self._get(address, params, headers)
233+
return await self._get(address, params, headers)
234+
235+
async def _get(
236+
self,
237+
address: str,
238+
params: dict[str, Any] | None = None,
239+
headers: dict[str, str] | None = None,
240+
) -> Response:
241+
"""Internal GET implementation."""
242+
_LOGGER.debug("HTTP Get from %s", address)
243+
244+
# Set default headers
245+
headers = headers or {}
246+
headers.setdefault("User-Agent", f"Zeep/{get_version()}")
247+
248+
# Determine timeout
249+
timeout = self.operation_timeout or self.timeout
250+
if timeout:
251+
client_timeout = ClientTimeout(total=timeout)
252+
else:
253+
client_timeout = None
254+
255+
try:
256+
response = await self.session.get(
257+
address,
258+
params=params,
259+
headers=headers,
260+
proxy=self.proxy,
261+
timeout=client_timeout,
262+
)
263+
response.raise_for_status()
264+
265+
# Read content
266+
content = await response.read()
267+
268+
_LOGGER.debug(
269+
"HTTP Response from %s (status: %d)",
270+
address,
271+
response.status,
272+
)
273+
274+
# Convert directly to requests.Response
275+
return self._aiohttp_to_requests_response(response, content)
276+
277+
except TimeoutError as exc:
278+
raise TimeoutError(f"Request to {address} timed out") from exc
279+
except aiohttp.ClientError as exc:
280+
raise ConnectionError(f"Error connecting to {address}: {exc}") from exc
281+
282+
def _httpx_to_requests_response(self, response: httpx.Response) -> Response:
283+
"""Convert an httpx.Response object to a requests.Response object"""
284+
body = response.read()
285+
286+
new = Response()
287+
new._content = body
288+
new.status_code = response.status_code
289+
new.headers = response.headers
290+
new.cookies = response.cookies
291+
new.encoding = response.encoding
292+
return new
293+
294+
def load(self, url: str) -> bytes:
295+
"""
296+
Load content from URL synchronously.
297+
298+
This method runs the async get method in a new event loop.
299+
300+
Args:
301+
url: The URL to load
302+
303+
Returns:
304+
The content as bytes
305+
306+
"""
307+
# Create a new event loop for sync operation
308+
loop = asyncio.new_event_loop()
309+
try:
310+
response = loop.run_until_complete(self.get(url))
311+
return response.content
312+
finally:
313+
loop.close()

0 commit comments

Comments
 (0)