1+ import time
12from dataclasses import dataclass
23from enum import Enum
34from typing import Optional
89from . import get_session , hf_raise_for_status , validate_hf_hub_args
910
1011
# Margin, in seconds, subtracted from a token's lifetime: connection info that
# expires within this window is already treated as expired by `_is_expired`.
XET_CONNECTION_INFO_SAFETY_PERIOD = 60

# Upper bound on the number of cached entries. When the cache has reached this
# size, one existing entry is dropped before a new one is stored.
XET_CONNECTION_INFO_CACHE_SIZE = 1_000

# In-process cache of fetched connection info, keyed by the string returned
# from `_cache_key` (URL, authorization header and query parameters).
XET_CONNECTION_INFO_CACHE: dict[str, "XetConnectionInfo"] = dict()
15+
16+
1117class XetTokenType (str , Enum ):
1218 READ = "read"
1319 WRITE = "write"
@@ -167,6 +173,9 @@ def _fetch_xet_connection_info_with_url(
167173 """
168174 Requests the xet connection info from the supplied URL. This includes the
169175 access token, expiration time, and endpoint to use for the xet storage service.
176+
177+ Result is cached to avoid redundant requests.
178+
170179 Args:
171180 url: (`str`):
172181 The access token endpoint URL.
@@ -183,10 +192,44 @@ def _fetch_xet_connection_info_with_url(
183192 [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
184193 If the Hub API response is improperly formatted.
185194 """
195+ # Check cache first
196+ cache_key = _cache_key (url , headers , params )
197+ cached_info = XET_CONNECTION_INFO_CACHE .get (cache_key )
198+ if cached_info is not None :
199+ if not _is_expired (cached_info ):
200+ return cached_info
201+
202+ # Fetch from server
186203 resp = get_session ().get (headers = headers , url = url , params = params )
187204 hf_raise_for_status (resp )
188205
189206 metadata = parse_xet_connection_info_from_headers (resp .headers ) # type: ignore
190207 if metadata is None :
191208 raise ValueError ("Xet headers have not been correctly set by the server." )
209+
210+ # Delete expired cache entries
211+ for k , v in list (XET_CONNECTION_INFO_CACHE .items ()):
212+ if _is_expired (v ):
213+ XET_CONNECTION_INFO_CACHE .pop (k , None )
214+
215+ # Enforce cache size limit
216+ if len (XET_CONNECTION_INFO_CACHE ) >= XET_CONNECTION_INFO_CACHE_SIZE :
217+ XET_CONNECTION_INFO_CACHE .pop (next (iter (XET_CONNECTION_INFO_CACHE )))
218+
219+ # Update cache
220+ XET_CONNECTION_INFO_CACHE [cache_key ] = metadata
221+
192222 return metadata
223+
224+
225+ def _cache_key (url : str , headers : dict [str , str ], params : Optional [dict [str , str ]]) -> str :
226+ """Return a unique cache key for the given request parameters."""
227+ lower_headers = {k .lower (): v for k , v in headers .items ()} # casing is not guaranteed here
228+ auth_header = lower_headers .get ("authorization" , "" )
229+ params_str = "&" .join (f"{ k } ={ v } " for k , v in sorted ((params or {}).items (), key = lambda x : x [0 ]))
230+ return f"{ url } |{ auth_header } |{ params_str } "
231+
232+
def _is_expired(connection_info: XetConnectionInfo) -> bool:
    """Tell whether the given XET connection info should no longer be used.

    The info counts as expired once its ``expiration_unix_epoch`` is at or
    before the current time (whole seconds from ``time.time()``) plus
    ``XET_CONNECTION_INFO_SAFETY_PERIOD``, i.e. slightly ahead of the actual
    expiration.

    Args:
        connection_info: The connection info whose expiration is checked.

    Returns:
        ``True`` if the info is expired or inside the safety window,
        ``False`` otherwise.
    """
    expires_at = connection_info.expiration_unix_epoch
    # Shift "now" forward by the safety period so near-expiry info is rejected.
    deadline = int(time.time()) + XET_CONNECTION_INFO_SAFETY_PERIOD
    return expires_at <= deadline
0 commit comments