Skip to content

Commit 7912776

Browse files
authored
fix: use new default ssl context in all aiohttp requests (#271)
Adds the missing ssl kwarg usage in other parts of the code, missed in #268.
1 parent 63a3917 commit 7912776

File tree

5 files changed

+13
-12
lines changed

5 files changed

+13
-12
lines changed

examples/ai_horde_client/image/async_manual_client_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import aiohttp
99
from loguru import logger
1010

11-
from horde_sdk import ANON_API_KEY
11+
from horde_sdk import ANON_API_KEY, _default_sslcontext
1212
from horde_sdk.ai_horde_api import AIHordeAPIAsyncManualClient
1313
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusRequest
1414
from horde_sdk.generic_api.apimodels import RequestErrorResponse
@@ -90,7 +90,7 @@ async def main(apikey: str = ANON_API_KEY) -> None:
9090

9191
image_bytes = None
9292
# image_gen.img is a url, download it using aiohttp.
93-
async with aiohttp.ClientSession() as session, session.get(image_gen.img) as resp:
93+
async with aiohttp.ClientSession() as session, session.get(image_gen.img, ssl=_default_sslcontext) as resp:
9494
image_bytes = await resp.read()
9595

9696
if image_bytes is None:

horde_sdk/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
# isort: off
44
# We import dotenv first so that we can use it to load environment variables before importing anything else.
5+
import ssl
6+
import certifi
57
import dotenv
68

79
# If the current working directory contains a `.env` file, import the environment variables from it.
@@ -59,7 +61,7 @@ def _dev_env_var_warnings() -> None: # pragma: no cover
5961

6062

6163
_dev_env_var_warnings()
62-
64+
_default_sslcontext = ssl.create_default_context(cafile=certifi.where())
6365

6466
from horde_sdk.consts import (
6567
PAYLOAD_HTTP_METHODS,
@@ -109,4 +111,5 @@ def _dev_env_var_warnings() -> None: # pragma: no cover
109111
"PROGRESS_LOGGER_LABEL",
110112
"COMPLETE_LOGGER_LABEL",
111113
"HordeException",
114+
"_default_sslcontext",
112115
]

horde_sdk/ai_horde_api/ai_horde_clients.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import requests
1919
from loguru import logger
2020

21-
from horde_sdk import COMPLETE_LOGGER_LABEL, PROGRESS_LOGGER_LABEL
21+
from horde_sdk import COMPLETE_LOGGER_LABEL, PROGRESS_LOGGER_LABEL, _default_sslcontext
2222
from horde_sdk.ai_horde_api.apimodels import (
2323
AIHordeHeartbeatRequest,
2424
AIHordeHeartbeatResponse,
@@ -79,7 +79,6 @@
7979
GenericAsyncHordeAPISession,
8080
GenericHordeAPIManualClient,
8181
GenericHordeAPISession,
82-
_default_sslcontext,
8382
)
8483

8584

@@ -1290,7 +1289,7 @@ async def download_image_from_generation(self, generation: ImageGeneration) -> t
12901289

12911290
image_bytes: bytes | None = None
12921291
if urllib.parse.urlparse(generation.img).scheme in ["http", "https"]:
1293-
async with self._aiohttp_session.get(generation.img) as response:
1292+
async with self._aiohttp_session.get(generation.img, ssl=_default_sslcontext) as response:
12941293
if response.status != 200: # pragma: no cover
12951294
logger.error(f"Error downloading image: {response.status}")
12961295
response.raise_for_status()
@@ -1326,7 +1325,7 @@ async def download_image_from_url(self, url: str) -> PIL.Image.Image:
13261325
if self._aiohttp_session is None:
13271326
raise RuntimeError("No aiohttp session provided but an async request was made.")
13281327

1329-
async with self._aiohttp_session.get(url) as response:
1328+
async with self._aiohttp_session.get(url, ssl=_default_sslcontext) as response:
13301329
if response.status != 200: # pragma: no cover
13311330
logger.error(f"Error downloading image: {response.status}")
13321331
response.raise_for_status()

horde_sdk/generic_api/apimodels.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pydantic import BaseModel, ConfigDict, Field, field_validator
1414
from typing_extensions import override
1515

16+
from horde_sdk import _default_sslcontext
1617
from horde_sdk.consts import HTTPMethod, HTTPStatusCode
1718
from horde_sdk.generic_api.consts import ANON_API_KEY
1819
from horde_sdk.generic_api.endpoints import GENERIC_API_ENDPOINT_SUBPATH, url_with_path
@@ -256,7 +257,7 @@ class ResponseRequiringDownloadMixin(HordeAPIDataObject):
256257

257258
async def download_file_as_base64(self, client_session: aiohttp.ClientSession, url: str) -> str:
258259
"""Download a file and return the value as a base64 string."""
259-
async with client_session.get(url) as response:
260+
async with client_session.get(url, ssl=_default_sslcontext) as response:
260261
response.raise_for_status()
261262
return base64.b64encode(await response.read()).decode("utf-8")
262263

@@ -273,7 +274,7 @@ async def download_file_to_field_as_base64(
273274
url (str): The URL to download the file from.
274275
field_name (str): The name of the field to save the file to.
275276
"""
276-
async with client_session.get(url) as response:
277+
async with client_session.get(url, ssl=_default_sslcontext) as response:
277278
response.raise_for_status()
278279
setattr(self, field_name, base64.b64encode(await response.read()).decode("utf-8"))
279280

horde_sdk/generic_api/generic_clients.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,18 @@
44

55
import asyncio
66
import os
7-
import ssl
87
from abc import ABC
98
from ssl import SSLContext
109
from typing import Any, TypeVar
1110

1211
import aiohttp
13-
import certifi
1412
import requests
1513
from loguru import logger
1614
from pydantic import BaseModel, ValidationError
1715
from strenum import StrEnum
1816
from typing_extensions import override
1917

18+
from horde_sdk import _default_sslcontext
2019
from horde_sdk.ai_horde_api.exceptions import AIHordePayloadValidationError
2120
from horde_sdk.consts import HTTPMethod
2221
from horde_sdk.generic_api.apimodels import (
@@ -35,7 +34,6 @@
3534
GenericQueryFields,
3635
)
3736

38-
_default_sslcontext = ssl.create_default_context(cafile=certifi.where())
3937
"""The default SSL context to use for aiohttp requests."""
4038

4139

0 commit comments

Comments
 (0)