9696
9797if TYPE_CHECKING :
9898 import numpy as np
99+ from aiohttp import ClientSession
99100 from PIL .Image import Image
100101
101102logger = logging .getLogger (__name__ )
@@ -133,6 +134,10 @@ class AsyncInferenceClient:
133134 Values in this dictionary will override the default values.
134135 cookies (`Dict[str, str]`, `optional`):
135136 Additional cookies to send to the server.
137+ trust_env (`bool`, `optional`):
138+ Whether to trust environment settings (e.g. proxy configuration from environment variables). Defaults to `False`.
139+ proxies (`Any`, `optional`):
140+ Proxies to use for the request.
136141 base_url (`str`, `optional`):
137142 Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`]
138143 follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None.
@@ -150,6 +155,7 @@ def __init__(
150155 timeout : Optional [float ] = None ,
151156 headers : Optional [Dict [str , str ]] = None ,
152157 cookies : Optional [Dict [str , str ]] = None ,
158+ trust_env : bool = False ,
153159 proxies : Optional [Any ] = None ,
154160 # OpenAI compatibility
155161 base_url : Optional [str ] = None ,
@@ -176,6 +182,7 @@ def __init__(
176182 self .headers .update (headers )
177183 self .cookies = cookies
178184 self .timeout = timeout
185+ self .trust_env = trust_env
179186 self .proxies = proxies
180187
181188 # OpenAI compatibility
@@ -265,7 +272,7 @@ async def post(
265272 warnings .warn ("Ignoring `json` as `data` is passed as binary." )
266273
267274 # Set Accept header if relevant
268- headers = self . headers . copy ()
275+ headers = dict ()
269276 if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers :
270277 headers ["Accept" ] = "image/png"
271278
@@ -275,9 +282,7 @@ async def post(
275282 with _open_as_binary (data ) as data_as_binary :
276283 # Do not use context manager as we don't want to close the connection immediately when returning
277284 # a stream
278- client = aiohttp .ClientSession (
279- headers = headers , cookies = self .cookies , timeout = aiohttp .ClientTimeout (self .timeout )
280- )
285+ client = self ._get_client_session (headers = headers )
281286
282287 try :
283288 response = await client .post (url , json = json , data = data_as_binary , proxy = self .proxies )
@@ -1299,8 +1304,8 @@ def _unpack_response(framework: str, items: List[Dict]) -> None:
12991304 models_by_task .setdefault (model ["task" ], []).append (model ["model_id" ])
13001305
13011306 async def _fetch_framework (framework : str ) -> None :
1302- async with _import_aiohttp (). ClientSession ( headers = self .headers ) as client :
1303- response = await client .get (f"{ INFERENCE_ENDPOINT } /framework/{ framework } " )
1307+ async with self ._get_client_session ( ) as client :
1308+ response = await client .get (f"{ INFERENCE_ENDPOINT } /framework/{ framework } " , proxy = self . proxies )
13041309 response .raise_for_status ()
13051310 _unpack_response (framework , await response .json ())
13061311
@@ -2581,6 +2586,20 @@ async def zero_shot_image_classification(
25812586 )
25822587 return ZeroShotImageClassificationOutputElement .parse_obj_as_list (response )
25832588
2589+ def _get_client_session (self , headers : Optional [Dict ] = None ) -> "ClientSession" :
2590+ aiohttp = _import_aiohttp ()
2591+ client_headers = self .headers .copy ()
2592+ if headers is not None :
2593+ client_headers .update (headers )
2594+
2595+ # Return a new aiohttp ClientSession with correct settings.
2596+ return aiohttp .ClientSession (
2597+ headers = client_headers ,
2598+ cookies = self .cookies ,
2599+ timeout = aiohttp .ClientTimeout (self .timeout ),
2600+ trust_env = self .trust_env ,
2601+ )
2602+
25842603 def _resolve_url (self , model : Optional [str ] = None , task : Optional [str ] = None ) -> str :
25852604 model = model or self .model or self .base_url
25862605
@@ -2687,8 +2706,8 @@ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, A
26872706 else :
26882707 url = f"{ INFERENCE_ENDPOINT } /models/{ model } /info"
26892708
2690- async with _import_aiohttp (). ClientSession ( headers = self .headers ) as client :
2691- response = await client .get (url )
2709+ async with self ._get_client_session ( ) as client :
2710+ response = await client .get (url , proxy = self . proxies )
26922711 response .raise_for_status ()
26932712 return await response .json ()
26942713
@@ -2724,8 +2743,8 @@ async def health_check(self, model: Optional[str] = None) -> bool:
27242743 )
27252744 url = model .rstrip ("/" ) + "/health"
27262745
2727- async with _import_aiohttp (). ClientSession ( headers = self .headers ) as client :
2728- response = await client .get (url )
2746+ async with self ._get_client_session ( ) as client :
2747+ response = await client .get (url , proxy = self . proxies )
27292748 return response .status == 200
27302749
27312750 async def get_model_status (self , model : Optional [str ] = None ) -> ModelStatus :
@@ -2766,8 +2785,8 @@ async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
27662785 raise NotImplementedError ("Model status is only available for Inference API endpoints." )
27672786 url = f"{ INFERENCE_ENDPOINT } /status/{ model } "
27682787
2769- async with _import_aiohttp (). ClientSession ( headers = self .headers ) as client :
2770- response = await client .get (url )
2788+ async with self ._get_client_session ( ) as client :
2789+ response = await client .get (url , proxy = self . proxies )
27712790 response .raise_for_status ()
27722791 response_data = await response .json ()
27732792
0 commit comments