Skip to content

Commit d3e7b6d

Browse files
committed
Add exponential backoff for cubin file downloads and increase the default number of retries
1 parent f130e55 commit d3e7b6d

File tree

2 files changed

+32
-26
lines changed

2 files changed

+32
-26
lines changed

flashinfer/artifacts.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .jit.core import logger
2727
from .jit.cubin_loader import (
2828
FLASHINFER_CUBINS_REPOSITORY,
29-
get_cubin,
29+
download_file,
3030
safe_urljoin,
3131
FLASHINFER_CUBIN_DIR,
3232
)
@@ -125,26 +125,31 @@ def download_artifacts() -> None:
125125
# HTTPS connections.
126126
session = requests.Session()
127127

128-
with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"):
129-
cubin_files = list(get_cubin_file_list())
130-
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
131-
with tqdm_logging_redirect(
132-
total=len(cubin_files), desc="Downloading cubins"
133-
) as pbar:
134-
135-
def update_pbar_cb(_) -> None:
136-
pbar.update(1)
137-
138-
with ThreadPoolExecutor(num_threads) as pool:
139-
futures = []
140-
for name in cubin_files:
141-
fut = pool.submit(get_cubin, name, "", session)
142-
fut.add_done_callback(update_pbar_cb)
143-
futures.append(fut)
144-
145-
results = [fut.result() for fut in as_completed(futures)]
146-
147-
all_success = all(results)
128+
cubin_files = list(get_cubin_file_list())
129+
num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4"))
130+
with tqdm_logging_redirect(
131+
total=len(cubin_files), desc="Downloading cubins"
132+
) as pbar:
133+
134+
def update_pbar_cb(_) -> None:
135+
pbar.update(1)
136+
137+
with ThreadPoolExecutor(num_threads) as pool:
138+
futures = []
139+
for name in cubin_files:
140+
source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name)
141+
local_path = FLASHINFER_CUBIN_DIR / name
142+
# Ensure parent directory exists
143+
local_path.parent.mkdir(parents=True, exist_ok=True)
144+
fut = pool.submit(
145+
download_file, source, str(local_path), session=session
146+
)
147+
fut.add_done_callback(update_pbar_cb)
148+
futures.append(fut)
149+
150+
results = [fut.result() for fut in as_completed(futures)]
151+
152+
all_success = all(results)
148153
if not all_success:
149154
raise RuntimeError("Failed to download cubins")
150155

flashinfer/jit/cubin_loader.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def safe_urljoin(base, path):
4444
def download_file(
4545
source: str,
4646
local_path: str,
47-
retries: int = 3,
47+
retries: int = 4,
4848
delay: int = 5,
4949
timeout: int = 10,
5050
lock_timeout: int = 30,
@@ -57,7 +57,7 @@ def download_file(
5757
- source (str): The URL or local file path of the file to download.
5858
- local_path (str): The local file path to save the downloaded/copied file.
5959
- retries (int): Number of retry attempts for URL downloads (default: 4).
60-
- delay (int): Delay in seconds between retries (default: 5).
60+
- delay (int): Initial delay in seconds for exponential backoff (default: 5).
6161
- timeout (int): Timeout for the HTTP request in seconds (default: 10).
6262
- lock_timeout (int): Timeout in seconds for the file lock (default: 30).
6363
@@ -87,7 +87,7 @@ def download_file(
8787
logger.error(f"Failed to copy local file: {e}")
8888
return False
8989

90-
# Handle URL downloads
90+
# Handle URL downloads with exponential backoff
9191
for attempt in range(1, retries + 1):
9292
try:
9393
response = session.get(source, timeout=timeout)
@@ -107,8 +107,9 @@ def download_file(
107107
)
108108

109109
if attempt < retries:
110-
logger.info(f"Retrying in {delay} seconds...")
111-
time.sleep(delay)
110+
backoff_delay = delay * (2 ** (attempt - 1))
111+
logger.info(f"Retrying in {backoff_delay} seconds...")
112+
time.sleep(backoff_delay)
112113
else:
113114
logger.error("Max retries reached. Download failed.")
114115
return False

0 commit comments

Comments
 (0)