1010from enum import IntEnum
1111from json import loads
1212from os import environ
13+ from urllib .parse import parse_qsl , urlencode , urlsplit , urlunsplit
1314
14- from httpx import AsyncClient , Client , Headers , Limits , ReadTimeout , Request , Response
15- from httpx import __version__ as httpx_version
15+ from niquests import AsyncSession , ReadTimeout , Request , Response , Session
16+ from niquests import __version__ as niquests_version
17+ from niquests .structures import CaseInsensitiveDict
1618from starlette .requests import HTTPConnection
1719
1820from . import options
@@ -49,6 +51,13 @@ class ServerVersion(typing.TypedDict):
4951 """Indicates if the subscription has extended support"""
5052
5153
54+ @dataclass
55+ class Limits :
56+ max_keepalive_connections : int | None = 20
57+ max_connections : int | None = 100
58+ keepalive_expiry : int | float | None = 5
59+
60+
5261@dataclass
5362class RuntimeOptions :
5463 xdebug_session : str
@@ -134,11 +143,11 @@ def __init__(self, **kwargs):
134143
135144
136145class NcSessionBase (ABC ):
137- adapter : AsyncClient | Client
138- adapter_dav : AsyncClient | Client
146+ adapter : AsyncSession | Session
147+ adapter_dav : AsyncSession | Session
139148 cfg : BasicConfig
140149 custom_headers : dict
141- response_headers : Headers
150+ response_headers : CaseInsensitiveDict
142151 _user : str
143152 _capabilities : dict
144153
@@ -150,7 +159,7 @@ def __init__(self, **kwargs):
150159 self .limits = Limits (max_keepalive_connections = 20 , max_connections = 20 , keepalive_expiry = 60.0 )
151160 self .init_adapter ()
152161 self .init_adapter_dav ()
153- self .response_headers = Headers ()
162+ self .response_headers = CaseInsensitiveDict ()
154163 self ._ocs_regexp = re .compile (r"/ocs/v[12]\.php/|/apps/groupfolders/" )
155164
156165 def init_adapter (self , restart = False ) -> None :
@@ -172,7 +181,7 @@ def init_adapter_dav(self, restart=False) -> None:
172181 self .adapter_dav .cookies .set ("XDEBUG_SESSION" , options .XDEBUG_SESSION )
173182
174183 @abstractmethod
175- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
184+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
176185 pass # pragma: no cover
177186
178187 @property
@@ -187,8 +196,8 @@ def ae_url_v2(self) -> str:
187196
188197
189198class NcSessionBasic (NcSessionBase , ABC ):
190- adapter : Client
191- adapter_dav : Client
199+ adapter : Session
200+ adapter_dav : Session
192201
193202 def ocs (
194203 self ,
@@ -206,9 +215,7 @@ def ocs(
206215 info = f"request: { method } { path } "
207216 nested_req = kwargs .pop ("nested_req" , False )
208217 try :
209- response = self .adapter .request (
210- method , path , content = content , json = json , params = params , files = files , ** kwargs
211- )
218+ response = self .adapter .request (method , path , data = content , json = json , params = params , files = files , ** kwargs )
212219 except ReadTimeout :
213220 raise NextcloudException (408 , info = info ) from None
214221
@@ -281,18 +288,18 @@ def _get_adapter_kwargs(self, dav: bool) -> dict[str, typing.Any]:
281288 return {
282289 "base_url" : self .cfg .dav_endpoint ,
283290 "timeout" : self .cfg .options .timeout_dav ,
284- "event_hooks" : {"request " : [], "response" : [self ._response_event ]},
291+ "event_hooks" : {"pre_request " : [], "response" : [self ._response_event ]},
285292 }
286293 return {
287294 "base_url" : self .cfg .endpoint ,
288295 "timeout" : self .cfg .options .timeout ,
289- "event_hooks" : {"request " : [self ._request_event_ocs ], "response" : [self ._response_event ]},
296+ "event_hooks" : {"pre_request " : [self ._request_event_ocs ], "response" : [self ._response_event ]},
290297 }
291298
292299 def _request_event_ocs (self , request : Request ) -> None :
293300 str_url = str (request .url )
294301 if re .search (self ._ocs_regexp , str_url ) is not None : # this is OCS call
295- request .url = request .url . copy_merge_params ({ "format" : "json" } )
302+ request .url = patch_param ( request .url , "format" , "json" )
296303 request .headers ["Accept" ] = "application/json"
297304
298305 def _response_event (self , response : Response ) -> None :
@@ -305,15 +312,15 @@ def _response_event(self, response: Response) -> None:
305312
306313 def download2fp (self , url_path : str , fp , dav : bool , params = None , ** kwargs ):
307314 adapter = self .adapter_dav if dav else self .adapter
308- with adapter .stream ( "GET" , url_path , params = params , headers = kwargs .get ("headers" )) as response :
315+ with adapter .get ( url_path , params = params , headers = kwargs .get ("headers" ), stream = True ) as response :
309316 check_error (response )
310- for data_chunk in response .iter_bytes (chunk_size = kwargs .get ("chunk_size" , 5 * 1024 * 1024 )):
317+ for data_chunk in response .iter_raw (chunk_size = kwargs .get ("chunk_size" , - 1 )):
311318 fp .write (data_chunk )
312319
313320
314321class AsyncNcSessionBasic (NcSessionBase , ABC ):
315- adapter : AsyncClient
316- adapter_dav : AsyncClient
322+ adapter : AsyncSession
323+ adapter_dav : AsyncSession
317324
318325 async def ocs (
319326 self ,
@@ -332,7 +339,7 @@ async def ocs(
332339 nested_req = kwargs .pop ("nested_req" , False )
333340 try :
334341 response = await self .adapter .request (
335- method , path , content = content , json = json , params = params , files = files , ** kwargs
342+ method , path , data = content , json = json , params = params , files = files , ** kwargs
336343 )
337344 except ReadTimeout :
338345 raise NextcloudException (408 , info = info ) from None
@@ -350,7 +357,7 @@ async def ocs(
350357 and ocs_meta ["statuscode" ] == 403
351358 and str (ocs_meta ["message" ]).lower ().find ("password confirmation is required" ) != - 1
352359 ):
353- await self .adapter .aclose ()
360+ await self .adapter .close ()
354361 self .init_adapter (restart = True )
355362 return await self .ocs (
356363 method , path , ** kwargs , content = content , json = json , params = params , nested_req = True
@@ -408,18 +415,18 @@ def _get_adapter_kwargs(self, dav: bool) -> dict[str, typing.Any]:
408415 return {
409416 "base_url" : self .cfg .dav_endpoint ,
410417 "timeout" : self .cfg .options .timeout_dav ,
411- "event_hooks" : {"request " : [], "response" : [self ._response_event ]},
418+ "event_hooks" : {"pre_request " : [], "response" : [self ._response_event ]},
412419 }
413420 return {
414421 "base_url" : self .cfg .endpoint ,
415422 "timeout" : self .cfg .options .timeout ,
416- "event_hooks" : {"request " : [self ._request_event_ocs ], "response" : [self ._response_event ]},
423+ "event_hooks" : {"pre_request " : [self ._request_event_ocs ], "response" : [self ._response_event ]},
417424 }
418425
419426 async def _request_event_ocs (self , request : Request ) -> None :
420427 str_url = str (request .url )
421428 if re .search (self ._ocs_regexp , str_url ) is not None : # this is OCS call
422- request .url = request .url . copy_merge_params ({ "format" : "json" } )
429+ request .url = patch_param ( request .url , "format" , "json" )
423430 request .headers ["Accept" ] = "application/json"
424431
425432 async def _response_event (self , response : Response ) -> None :
@@ -432,10 +439,12 @@ async def _response_event(self, response: Response) -> None:
432439
433440 async def download2fp (self , url_path : str , fp , dav : bool , params = None , ** kwargs ):
434441 adapter = self .adapter_dav if dav else self .adapter
435- async with adapter .stream ("GET" , url_path , params = params , headers = kwargs .get ("headers" )) as response :
436- check_error (response )
437- async for data_chunk in response .aiter_bytes (chunk_size = kwargs .get ("chunk_size" , 5 * 1024 * 1024 )):
438- fp .write (data_chunk )
442+ response = await adapter .get (url_path , params = params , headers = kwargs .get ("headers" ), stream = True )
443+
444+ check_error (response )
445+
446+ async for data_chunk in await response .iter_raw (chunk_size = kwargs .get ("chunk_size" , - 1 )):
447+ fp .write (data_chunk )
439448
440449
441450class NcSession (NcSessionBasic ):
@@ -445,15 +454,20 @@ def __init__(self, **kwargs):
445454 self .cfg = Config (** kwargs )
446455 super ().__init__ ()
447456
448- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
449- return Client (
450- follow_redirects = True ,
451- limits = self .limits ,
452- verify = self .cfg .options .nc_cert ,
453- ** self ._get_adapter_kwargs (dav ),
454- auth = self .cfg .auth ,
457+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
458+ session_kwargs = self ._get_adapter_kwargs (dav )
459+ hooks = session_kwargs .pop ("event_hooks" )
460+
461+ session = Session (
462+ keepalive_delay = self .limits .keepalive_expiry , pool_maxsize = self .limits .max_connections , ** session_kwargs
455463 )
456464
465+ session .auth = self .cfg .auth
466+ session .verify = self .cfg .options .nc_cert
467+ session .hooks .update (hooks )
468+
469+ return session
470+
457471
458472class AsyncNcSession (AsyncNcSessionBasic ):
459473 cfg : Config
@@ -462,21 +476,28 @@ def __init__(self, **kwargs):
462476 self .cfg = Config (** kwargs )
463477 super ().__init__ ()
464478
465- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
466- return AsyncClient (
467- follow_redirects = True ,
468- limits = self .limits ,
469- verify = self .cfg .options .nc_cert ,
470- ** self ._get_adapter_kwargs (dav ),
471- auth = self .cfg .auth ,
479+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
480+ session_kwargs = self ._get_adapter_kwargs (dav )
481+ hooks = session_kwargs .pop ("event_hooks" )
482+
483+ session = AsyncSession (
484+ keepalive_delay = self .limits .keepalive_expiry ,
485+ pool_maxsize = self .limits .max_connections ,
486+ ** session_kwargs ,
472487 )
473488
489+ session .verify = self .cfg .options .nc_cert
490+ session .auth = self .cfg .auth
491+ session .hooks .update (hooks )
492+
493+ return session
494+
474495
475496class NcSessionAppBasic (ABC ):
476497 cfg : AppConfig
477498 _user : str
478- adapter : AsyncClient | Client
479- adapter_dav : AsyncClient | Client
499+ adapter : AsyncSession | Session
500+ adapter_dav : AsyncSession | Session
480501
481502 def __init__ (self , ** kwargs ):
482503 self .cfg = AppConfig (** kwargs )
@@ -505,22 +526,29 @@ def sign_check(self, request: HTTPConnection) -> str:
505526class NcSessionApp (NcSessionAppBasic , NcSessionBasic ):
506527 cfg : AppConfig
507528
508- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
509- r = self ._get_adapter_kwargs (dav )
510- r ["event_hooks" ]["request" ].append (self ._add_auth )
511- return Client (
512- follow_redirects = True ,
513- limits = self .limits ,
514- verify = self .cfg .options .nc_cert ,
515- ** r ,
516- headers = {
517- "AA-VERSION" : self .cfg .aa_version ,
518- "EX-APP-ID" : self .cfg .app_name ,
519- "EX-APP-VERSION" : self .cfg .app_version ,
520- "user-agent" : f"ExApp/{ self .cfg .app_name } /{ self .cfg .app_version } (httpx/{ httpx_version } )" ,
521- },
529+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
530+ session_kwargs = self ._get_adapter_kwargs (dav )
531+ session_kwargs ["event_hooks" ]["pre_request" ].append (self ._add_auth )
532+
533+ hooks = session_kwargs .pop ("event_hooks" )
534+
535+ session = Session (
536+ keepalive_delay = self .limits .keepalive_expiry ,
537+ pool_maxsize = self .limits .max_connections ,
538+ ** session_kwargs ,
522539 )
523540
541+ session .verify = self .cfg .options .nc_cert
542+ session .headers = {
543+ "AA-VERSION" : self .cfg .aa_version ,
544+ "EX-APP-ID" : self .cfg .app_name ,
545+ "EX-APP-VERSION" : self .cfg .app_version ,
546+ "user-agent" : f"ExApp/{ self .cfg .app_name } /{ self .cfg .app_version } (niquests/{ niquests_version } )" ,
547+ }
548+ session .hooks .update (hooks )
549+
550+ return session
551+
524552 def _add_auth (self , request : Request ):
525553 request .headers .update (
526554 {"AUTHORIZATION-APP-API" : b64encode (f"{ self ._user } :{ self .cfg .app_secret } " .encode ("UTF=8" ))}
@@ -530,23 +558,39 @@ def _add_auth(self, request: Request):
530558class AsyncNcSessionApp (NcSessionAppBasic , AsyncNcSessionBasic ):
531559 cfg : AppConfig
532560
533- def _create_adapter (self , dav : bool = False ) -> AsyncClient | Client :
534- r = self ._get_adapter_kwargs (dav )
535- r ["event_hooks" ]["request" ].append (self ._add_auth )
536- return AsyncClient (
537- follow_redirects = True ,
538- limits = self .limits ,
539- verify = self .cfg .options .nc_cert ,
540- ** r ,
541- headers = {
542- "AA-VERSION" : self .cfg .aa_version ,
543- "EX-APP-ID" : self .cfg .app_name ,
544- "EX-APP-VERSION" : self .cfg .app_version ,
545- "User-Agent" : f"ExApp/{ self .cfg .app_name } /{ self .cfg .app_version } (httpx/{ httpx_version } )" ,
546- },
561+ def _create_adapter (self , dav : bool = False ) -> AsyncSession | Session :
562+ session_kwargs = self ._get_adapter_kwargs (dav )
563+ session_kwargs ["event_hooks" ]["pre_request" ].append (self ._add_auth )
564+
565+ hooks = session_kwargs .pop ("event_hooks" )
566+
567+ session = AsyncSession (
568+ keepalive_delay = self .limits .keepalive_expiry ,
569+ pool_maxsize = self .limits .max_connections ,
570+ ** session_kwargs ,
547571 )
572+ session .verify = self .cfg .options .nc_cert
573+ session .headers = {
574+ "AA-VERSION" : self .cfg .aa_version ,
575+ "EX-APP-ID" : self .cfg .app_name ,
576+ "EX-APP-VERSION" : self .cfg .app_version ,
577+ "User-Agent" : f"ExApp/{ self .cfg .app_name } /{ self .cfg .app_version } (niquests/{ niquests_version } )" ,
578+ }
579+ session .hooks .update (hooks )
580+
581+ return session
548582
549583 async def _add_auth (self , request : Request ):
550584 request .headers .update (
551585 {"AUTHORIZATION-APP-API" : b64encode (f"{ self ._user } :{ self .cfg .app_secret } " .encode ("UTF=8" ))}
552586 )
587+
588+
589+ def patch_param (url : str , key : str , value : str ) -> str :
590+ parts = urlsplit (url )
591+ query = dict (parse_qsl (parts .query , keep_blank_values = True ))
592+ query [key ] = value
593+
594+ new_query = urlencode (query , doseq = True )
595+
596+ return urlunsplit ((parts .scheme , parts .netloc , parts .path , new_query , parts .fragment ))
0 commit comments