Skip to content

Commit c9c39b8

Browse files
morgandiverrez, Benvii, and Wauplin
authored
Bug - [InferenceClient] - use proxy set in var env (#2421)
* set trust_env in aiohttp at True
* handle trust_env and set trust_env parameter in aiohttp client for Async and request client for Sync
* Update src/huggingface_hub/inference/_generated/_async_client.py (Co-authored-by: Benjamin BERNARD <[email protected]>)
* do not modify InferenceClient
* Add trust_env parameter only in AsyncInferenceClient + respect proxy
* document proxies and trust_env parameters
* remove newlines

---------

Co-authored-by: Benjamin BERNARD <[email protected]>
Co-authored-by: Lucain Pouget <[email protected]>
1 parent 9a9b8c1 commit c9c39b8

File tree

3 files changed

+89
-24
lines changed

3 files changed

+89
-24
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ class InferenceClient:
148148
Values in this dictionary will override the default values.
149149
cookies (`Dict[str, str]`, `optional`):
150150
Additional cookies to send to the server.
151+
proxies (`Any`, `optional`):
152+
Proxies to use for the request.
151153
base_url (`str`, `optional`):
152154
Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
153155
follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696

9797
if TYPE_CHECKING:
9898
import numpy as np
99+
from aiohttp import ClientSession
99100
from PIL.Image import Image
100101

101102
logger = logging.getLogger(__name__)
@@ -133,6 +134,10 @@ class AsyncInferenceClient:
133134
Values in this dictionary will override the default values.
134135
cookies (`Dict[str, str]`, `optional`):
135136
Additional cookies to send to the server.
137+
trust_env ('bool', 'optional'):
138+
Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).
139+
proxies (`Any`, `optional`):
140+
Proxies to use for the request.
136141
base_url (`str`, `optional`):
137142
Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
138143
follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
@@ -150,6 +155,7 @@ def __init__(
150155
timeout: Optional[float] = None,
151156
headers: Optional[Dict[str, str]] = None,
152157
cookies: Optional[Dict[str, str]] = None,
158+
trust_env: bool = False,
153159
proxies: Optional[Any] = None,
154160
# OpenAI compatibility
155161
base_url: Optional[str] = None,
@@ -176,6 +182,7 @@ def __init__(
176182
self.headers.update(headers)
177183
self.cookies = cookies
178184
self.timeout = timeout
185+
self.trust_env = trust_env
179186
self.proxies = proxies
180187

181188
# OpenAI compatibility
@@ -265,7 +272,7 @@ async def post(
265272
warnings.warn("Ignoring `json` as `data` is passed as binary.")
266273

267274
# Set Accept header if relevant
268-
headers = self.headers.copy()
275+
headers = dict()
269276
if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
270277
headers["Accept"] = "image/png"
271278

@@ -275,9 +282,7 @@ async def post(
275282
with _open_as_binary(data) as data_as_binary:
276283
# Do not use context manager as we don't want to close the connection immediately when returning
277284
# a stream
278-
client = aiohttp.ClientSession(
279-
headers=headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout)
280-
)
285+
client = self._get_client_session(headers=headers)
281286

282287
try:
283288
response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
@@ -1299,8 +1304,8 @@ def _unpack_response(framework: str, items: List[Dict]) -> None:
12991304
models_by_task.setdefault(model["task"], []).append(model["model_id"])
13001305

13011306
async def _fetch_framework(framework: str) -> None:
1302-
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
1303-
response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
1307+
async with self._get_client_session() as client:
1308+
response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies)
13041309
response.raise_for_status()
13051310
_unpack_response(framework, await response.json())
13061311

@@ -2581,6 +2586,20 @@ async def zero_shot_image_classification(
25812586
)
25822587
return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response)
25832588

2589+
def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
2590+
aiohttp = _import_aiohttp()
2591+
client_headers = self.headers.copy()
2592+
if headers is not None:
2593+
client_headers.update(headers)
2594+
2595+
# Return a new aiohttp ClientSession with correct settings.
2596+
return aiohttp.ClientSession(
2597+
headers=client_headers,
2598+
cookies=self.cookies,
2599+
timeout=aiohttp.ClientTimeout(self.timeout),
2600+
trust_env=self.trust_env,
2601+
)
2602+
25842603
def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
25852604
model = model or self.model or self.base_url
25862605

@@ -2687,8 +2706,8 @@ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, A
26872706
else:
26882707
url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
26892708

2690-
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
2691-
response = await client.get(url)
2709+
async with self._get_client_session() as client:
2710+
response = await client.get(url, proxy=self.proxies)
26922711
response.raise_for_status()
26932712
return await response.json()
26942713

@@ -2724,8 +2743,8 @@ async def health_check(self, model: Optional[str] = None) -> bool:
27242743
)
27252744
url = model.rstrip("/") + "/health"
27262745

2727-
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
2728-
response = await client.get(url)
2746+
async with self._get_client_session() as client:
2747+
response = await client.get(url, proxy=self.proxies)
27292748
return response.status == 200
27302749

27312750
async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
@@ -2766,8 +2785,8 @@ async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
27662785
raise NotImplementedError("Model status is only available for Inference API endpoints.")
27672786
url = f"{INFERENCE_ENDPOINT}/status/{model}"
27682787

2769-
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
2770-
response = await client.get(url)
2788+
async with self._get_client_session() as client:
2789+
response = await client.get(url, proxy=self.proxies)
27712790
response.raise_for_status()
27722791
response_data = await response.json()
27732792

utils/generate_async_inference_client.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def generate_async_client_code(code: str) -> str:
6868
# Adapt /info and /health endpoints
6969
code = _adapt_info_and_health_endpoints(code)
7070

71+
# Add _get_client_session
72+
code = _add_get_client_session(code)
73+
7174
# Adapt the proxy client (for client.chat.completions.create)
7275
code = _adapt_proxy_client(code)
7376

@@ -186,7 +189,7 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
186189
warnings.warn("Ignoring `json` as `data` is passed as binary.")
187190
188191
# Set Accept header if relevant
189-
headers = self.headers.copy()
192+
headers = dict()
190193
if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
191194
headers["Accept"] = "image/png"
192195
@@ -196,9 +199,7 @@ def _rename_to_AsyncInferenceClient(code: str) -> str:
196199
with _open_as_binary(data) as data_as_binary:
197200
# Do not use context manager as we don't want to close the connection immediately when returning
198201
# a stream
199-
client = aiohttp.ClientSession(
200-
headers=headers, cookies=self.cookies, timeout=aiohttp.ClientTimeout(self.timeout)
201-
)
202+
client = self._get_client_session(headers=headers)
202203
203204
try:
204205
response = await client.post(url, json=json, data=data_as_binary, proxy=self.proxies)
@@ -420,8 +421,8 @@ def _adapt_get_model_status(code: str) -> str:
420421
response_data = response.json()"""
421422

422423
async_snippet = """
423-
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
424-
response = await client.get(url)
424+
async with self._get_client_session() as client:
425+
response = await client.get(url, proxy=self.proxies)
425426
response.raise_for_status()
426427
response_data = await response.json()"""
427428

@@ -437,8 +438,8 @@ def _adapt_list_deployed_models(code: str) -> str:
437438

438439
async_snippet = """
439440
async def _fetch_framework(framework: str) -> None:
440-
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
441-
response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
441+
async with self._get_client_session() as client:
442+
response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies)
442443
response.raise_for_status()
443444
_unpack_response(framework, await response.json())
444445
@@ -456,8 +457,8 @@ def _adapt_info_and_health_endpoints(code: str) -> str:
456457
return response.json()"""
457458

458459
info_async_snippet = """
459-
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
460-
response = await client.get(url)
460+
async with self._get_client_session() as client:
461+
response = await client.get(url, proxy=self.proxies)
461462
response.raise_for_status()
462463
return await response.json()"""
463464

@@ -468,20 +469,63 @@ def _adapt_info_and_health_endpoints(code: str) -> str:
468469
return response.status_code == 200"""
469470

470471
health_async_snippet = """
471-
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
472-
response = await client.get(url)
472+
async with self._get_client_session() as client:
473+
response = await client.get(url, proxy=self.proxies)
473474
return response.status == 200"""
474475

475476
return code.replace(health_sync_snippet, health_async_snippet)
476477

477478

479+
def _add_get_client_session(code: str) -> str:
480+
# Add trust_env as parameter
481+
code = _add_before(code, "proxies: Optional[Any] = None,", "trust_env: bool = False,")
482+
code = _add_before(code, "\n self.proxies = proxies\n", "\n self.trust_env = trust_env")
483+
484+
# Document `trust_env` parameter
485+
code = _add_before(
486+
code,
487+
"\n proxies (`Any`, `optional`):",
488+
"""
489+
trust_env ('bool', 'optional'):
490+
Trust environment settings for proxy configuration if the parameter is `True` (`False` by default).""",
491+
)
492+
493+
# insert `_get_client_session` before `_resolve_url` method
494+
client_session_code = """
495+
496+
def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession":
497+
aiohttp = _import_aiohttp()
498+
client_headers = self.headers.copy()
499+
if headers is not None:
500+
client_headers.update(headers)
501+
502+
# Return a new aiohttp ClientSession with correct settings.
503+
return aiohttp.ClientSession(
504+
headers=client_headers,
505+
cookies=self.cookies,
506+
timeout=aiohttp.ClientTimeout(self.timeout),
507+
trust_env=self.trust_env,
508+
)
509+
510+
"""
511+
code = _add_before(code, "\n def _resolve_url(", client_session_code)
512+
513+
return code
514+
515+
478516
def _adapt_proxy_client(code: str) -> str:
479517
return code.replace(
480518
"def __init__(self, client: InferenceClient):",
481519
"def __init__(self, client: AsyncInferenceClient):",
482520
)
483521

484522

523+
def _add_before(code: str, pattern: str, addition: str) -> str:
524+
index = code.find(pattern)
525+
assert index != -1, f"Pattern '{pattern}' not found in code."
526+
return code[:index] + addition + code[index:]
527+
528+
485529
if __name__ == "__main__":
486530
parser = argparse.ArgumentParser()
487531
parser.add_argument(

0 commit comments

Comments (0)