11from getpass import getpass
2- import requests
32from abc import ABC
43import base64
54import json
65import time
7- from copy import copy , deepcopy
6+ from copy import copy
7+ from typing import Callable
8+ from functools import wraps
9+
10+ import requests
11+
812from ebrains_drive .utils import on_401_raise_unauthorized
913from ebrains_drive .exceptions import ClientHttpError , TokenExpired , Unauthorized
1014from ebrains_drive .repos import Repos
@@ -71,18 +75,6 @@ def put(self, *args, **kwargs):
7175 def delete (self , * args , ** kwargs ):
7276 return self .send_request ("DELETE" , * args , ** kwargs )
7377
74- def _exchange_oidc_for_seafile_token (self ):
75- url = self .server .rstrip ("/" ) + "/api2/account/token/"
76- headers = {"Authorization" : f"Bearer { self ._token } " }
77-
78- resp = self .session .get (url , headers = headers )
79-
80- if resp .status_code != 200 :
81- raise Exception (f"Failed to exchange OIDC token for Seafile token: { resp .status_code } { resp .text } " )
82-
83- self ._seafile_token = resp .text .strip ()
84- return self ._seafile_token
85-
8678 def send_request (self , method : str , url : str , * args , ** kwargs ):
8779 if not url .startswith ("http" ):
8880 # sanity checks.
@@ -94,31 +86,62 @@ def send_request(self, method: str, url: str, *args, **kwargs):
9486 # We cannot deepcopy the whole thing, because some values (e.g. BufferedReader objects)
9587 # cannot be pickled
9688 kwargs = copy (kwargs )
97- headers = kwargs .pop ("headers" , {}).copy ()
89+ headers : dict = kwargs .pop ("headers" , {}).copy ()
90+ token_auth = kwargs .pop ("token_auth" , None )
9891
99- if self ._seafile_token :
100- headers .setdefault ("Authorization" , "Token " + self ._seafile_token )
101- else :
102- headers .setdefault ("Authorization" , "Bearer " + self ._token )
92+ auth_header = f"Token { token_auth } " if token_auth else f"Bearer { self ._token } "
93+ headers .setdefault ("Authorization" , auth_header )
10394
10495 expected = kwargs .pop ("expected" , 200 )
10596 if not hasattr (expected , "__iter__" ):
10697 expected = (expected ,)
10798
10899 resp = self .session .request (method , url , headers = headers , * args , ** kwargs )
109100
110- if resp .status_code == 401 and not self ._seafile_token :
111- self ._seafile_token = self ._exchange_oidc_for_seafile_token ()
112-
113- headers ["Authorization" ] = "Token " + self ._seafile_token
114- resp = self .session .request (method , url , headers = headers , * args , ** kwargs )
115-
116101 if resp .status_code not in expected :
117102 msg = f"Expected { expected } , but got { resp .status_code } "
118103 raise ClientHttpError (resp .status_code , msg )
119104
120105 return resp
121106
107+
108+ def wrap_exchange_seafile_token ():
109+ def exchange_oidc_for_seafile (self : "DriveApiClient" ):
110+
111+ url = self .server .rstrip ("/" ) + "/api2/account/token/"
112+ headers = {"Authorization" : f"Bearer { self ._token } " }
113+
114+ resp = self .session .get (url , headers = headers )
115+ resp .raise_for_status ()
116+
117+ return resp .text .strip ()
118+
119+ def outer (fn : Callable ):
120+ @wraps (fn )
121+ def inner (self , * args , ** kwargs ):
122+ assert isinstance (self , DriveApiClient ), f"seafile exchange can only decorate DriveApiClient"
123+
124+ kwargs = copy (kwargs )
125+
126+ if self ._seafile_token is None :
127+ self ._seafile_token = exchange_oidc_for_seafile (self )
128+
129+ retry_counter = 1
130+ while retry_counter >= 0 :
131+ try :
132+ kwargs ["token_auth" ] = self ._seafile_token
133+ return fn (self , * args , ** kwargs )
134+ except ClientHttpError as e :
135+ if e .code == 401 :
136+ self ._seafile_token = exchange_oidc_for_seafile (self )
137+ retry_counter -= 1
138+ continue
139+ raise e from e
140+
141+ return inner
142+ return outer
143+
144+
122145class DriveApiClient (ClientBase ):
123146 """Wraps seafile web api"""
124147
@@ -152,6 +175,7 @@ def __str__(self):
152175
153176 __repr__ = __str__
154177
178+ @wrap_exchange_seafile_token ()
155179 def send_request (self , method : str , url : str , * args , ** kwargs ):
156180 if not url .startswith ("http" ):
157181 assert not self .server .endswith ("/" )
@@ -162,7 +186,7 @@ def send_request(self, method: str, url: str, *args, **kwargs):
162186 return super ().send_request (method , url , * args , ** kwargs )
163187
164188
165- _I_AM_A_PUBLIC_BUCKET = "_I_AM_A_PUBLIC_BUCKET"
189+ _I_AM_A_PUBLIC_BUCKET = object ()
166190
167191
168192class BucketApiClient (ClientBase ):
@@ -235,7 +259,7 @@ def delete_bucket(self, bucket_name: str, *, delete_wiki=False):
235259
236260 def send_request (self , method : str , url : str , * args , ** kwargs ):
237261
238- if self ._token != _I_AM_A_PUBLIC_BUCKET :
262+ if self ._token is not _I_AM_A_PUBLIC_BUCKET :
239263 hdr , info , sig = self ._token .split ("." )
240264 info_json = base64 .b64decode (info + "==" ).decode ("utf-8" )
241265
@@ -246,7 +270,7 @@ def send_request(self, method: str, url: str, *args, **kwargs):
246270 if now_tc_seconds > exp_utc_seconds :
247271 raise TokenExpired
248272
249- if self ._token == _I_AM_A_PUBLIC_BUCKET :
273+ if self ._token is _I_AM_A_PUBLIC_BUCKET :
250274 headers = kwargs .get ("headers" , {})
251275 headers ["Authorization" ] = None
252276 kwargs ["headers" ] = headers
0 commit comments