|
26 | 26 | from .jit.core import logger
|
27 | 27 | from .jit.cubin_loader import (
|
28 | 28 | FLASHINFER_CUBINS_REPOSITORY,
|
29 |
| - get_cubin, |
| 29 | + download_file, |
30 | 30 | safe_urljoin,
|
31 | 31 | FLASHINFER_CUBIN_DIR,
|
32 | 32 | )
|
@@ -125,26 +125,31 @@ def download_artifacts() -> None:
|
125 | 125 | # HTTPS connections.
|
126 | 126 | session = requests.Session()
|
127 | 127 |
|
128 |
| - with temp_env_var("FLASHINFER_CUBIN_CHECKSUM_DISABLED", "1"): |
129 |
| - cubin_files = list(get_cubin_file_list()) |
130 |
| - num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4")) |
131 |
| - with tqdm_logging_redirect( |
132 |
| - total=len(cubin_files), desc="Downloading cubins" |
133 |
| - ) as pbar: |
134 |
| - |
135 |
| - def update_pbar_cb(_) -> None: |
136 |
| - pbar.update(1) |
137 |
| - |
138 |
| - with ThreadPoolExecutor(num_threads) as pool: |
139 |
| - futures = [] |
140 |
| - for name in cubin_files: |
141 |
| - fut = pool.submit(get_cubin, name, "", session) |
142 |
| - fut.add_done_callback(update_pbar_cb) |
143 |
| - futures.append(fut) |
144 |
| - |
145 |
| - results = [fut.result() for fut in as_completed(futures)] |
146 |
| - |
147 |
| - all_success = all(results) |
| 128 | + cubin_files = list(get_cubin_file_list()) |
| 129 | + num_threads = int(os.environ.get("FLASHINFER_CUBIN_DOWNLOAD_THREADS", "4")) |
| 130 | + with tqdm_logging_redirect( |
| 131 | + total=len(cubin_files), desc="Downloading cubins" |
| 132 | + ) as pbar: |
| 133 | + |
| 134 | + def update_pbar_cb(_) -> None: |
| 135 | + pbar.update(1) |
| 136 | + |
| 137 | + with ThreadPoolExecutor(num_threads) as pool: |
| 138 | + futures = [] |
| 139 | + for name in cubin_files: |
| 140 | + source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name) |
| 141 | + local_path = FLASHINFER_CUBIN_DIR / name |
| 142 | + # Ensure parent directory exists |
| 143 | + local_path.parent.mkdir(parents=True, exist_ok=True) |
| 144 | + fut = pool.submit( |
| 145 | + download_file, source, str(local_path), session=session |
| 146 | + ) |
| 147 | + fut.add_done_callback(update_pbar_cb) |
| 148 | + futures.append(fut) |
| 149 | + |
| 150 | + results = [fut.result() for fut in as_completed(futures)] |
| 151 | + |
| 152 | + all_success = all(results) |
148 | 153 | if not all_success:
|
149 | 154 | raise RuntimeError("Failed to download cubins")
|
150 | 155 |
|
|
0 commit comments