Skip to content

Commit e1162cd

Browse files
committed
Avoid redundant call to the Xet connection info URL (#3534)
* Avoid redundant call to the Xet connection info URL * authorization casing * remove log * thread safety
1 parent 4f94171 commit e1162cd

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

src/huggingface_hub/utils/_xet.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
from dataclasses import dataclass
23
from enum import Enum
34
from typing import Optional
@@ -8,6 +9,11 @@
89
from . import get_session, hf_raise_for_status, validate_hf_hub_args
910

1011

12+
# Safety margin used by `_is_expired`: connection info whose expiration falls within
# this many seconds from now is already treated as expired.
XET_CONNECTION_INFO_SAFETY_PERIOD = 60 # seconds
13+
# Entry-count threshold for the connection info cache: when the cache holds at least
# this many entries, one existing entry is removed before a new one is stored.
XET_CONNECTION_INFO_CACHE_SIZE = 1_000
14+
# Cache of fetched Xet connection info, keyed by the string returned by `_cache_key`
# (derived from the request URL, authorization header and query params).
XET_CONNECTION_INFO_CACHE: dict[str, "XetConnectionInfo"] = {}
15+
16+
1117
class XetTokenType(str, Enum):
1218
READ = "read"
1319
WRITE = "write"
@@ -167,6 +173,9 @@ def _fetch_xet_connection_info_with_url(
167173
"""
168174
Requests the xet connection info from the supplied URL. This includes the
169175
access token, expiration time, and endpoint to use for the xet storage service.
176+
177+
Result is cached to avoid redundant requests.
178+
170179
Args:
171180
url: (`str`):
172181
The access token endpoint URL.
@@ -183,10 +192,44 @@ def _fetch_xet_connection_info_with_url(
183192
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
184193
If the Hub API response is improperly formatted.
185194
"""
195+
# Check cache first
196+
cache_key = _cache_key(url, headers, params)
197+
cached_info = XET_CONNECTION_INFO_CACHE.get(cache_key)
198+
if cached_info is not None:
199+
if not _is_expired(cached_info):
200+
return cached_info
201+
202+
# Fetch from server
186203
resp = get_session().get(headers=headers, url=url, params=params)
187204
hf_raise_for_status(resp)
188205

189206
metadata = parse_xet_connection_info_from_headers(resp.headers) # type: ignore
190207
if metadata is None:
191208
raise ValueError("Xet headers have not been correctly set by the server.")
209+
210+
# Delete expired cache entries
211+
for k, v in list(XET_CONNECTION_INFO_CACHE.items()):
212+
if _is_expired(v):
213+
XET_CONNECTION_INFO_CACHE.pop(k, None)
214+
215+
# Enforce cache size limit
216+
if len(XET_CONNECTION_INFO_CACHE) >= XET_CONNECTION_INFO_CACHE_SIZE:
217+
XET_CONNECTION_INFO_CACHE.pop(next(iter(XET_CONNECTION_INFO_CACHE)))
218+
219+
# Update cache
220+
XET_CONNECTION_INFO_CACHE[cache_key] = metadata
221+
192222
return metadata
223+
224+
225+
def _cache_key(url: str, headers: dict[str, str], params: Optional[dict[str, str]]) -> str:
226+
"""Return a unique cache key for the given request parameters."""
227+
lower_headers = {k.lower(): v for k, v in headers.items()} # casing is not guaranteed here
228+
auth_header = lower_headers.get("authorization", "")
229+
params_str = "&".join(f"{k}={v}" for k, v in sorted((params or {}).items(), key=lambda x: x[0]))
230+
return f"{url}|{auth_header}|{params_str}"
231+
232+
233+
def _is_expired(connection_info: XetConnectionInfo) -> bool:
    """Report whether the given Xet connection info must be considered expired.

    The expiration timestamp is compared against the current time pushed forward by
    `XET_CONNECTION_INFO_SAFETY_PERIOD`, so info that is about to expire is reported
    as expired as well.

    Args:
        connection_info (`XetConnectionInfo`):
            Connection info whose `expiration_unix_epoch` is checked.

    Returns:
        `bool`: `True` if the expiration is at or before "now + safety period".
    """
    expires_at = connection_info.expiration_unix_epoch
    cutoff = int(time.time()) + XET_CONNECTION_INFO_SAFETY_PERIOD
    return expires_at <= cutoff

0 commit comments

Comments
 (0)