1 | 1 | from __future__ import annotations |
2 | 2 |
3 | 3 | from dataclasses import dataclass |
4 | | -from typing import Literal, Any |
| 4 | +from typing import Literal |
5 | 5 |
6 | 6 | import os |
7 | 7 | import json |
8 | | -import requests |
9 | | -import threading |
10 | | -import logging |
11 | | -from urllib.parse import urlparse |
12 | 8 |
13 | 9 |
14 | 10 | def fill_templated_filename(filename: str, output_type: str | None) -> str: |
@@ -111,15 +107,9 @@ class SafetensorRemote: |
111 | 107 | print(data) |
112 | 108 | """ |
113 | 109 |
114 | | - logger = logging.getLogger("safetensor_remote") |
115 | | - |
116 | 110 | BASE_DOMAIN = "https://huggingface.co" |
117 | 111 | ALIGNMENT = 8 # bytes |
118 | 112 |
119 | | - # start using multithread download for files larger than 100MB |
120 | | - MULTITHREAD_THREDSHOLD = 100 * 1024 * 1024 |
121 | | - MULTITHREAD_COUNT = 8 # number of threads |
122 | | - |
123 | 113 | @classmethod |
124 | 114 | def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: |
125 | 115 | """ |
@@ -221,153 +211,47 @@ def get_metadata(cls, url: str) -> tuple[dict, int]: |
221 | 211 | except json.JSONDecodeError as e: |
222 | 212 | raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}") |
223 | 213 |
224 | | - @classmethod |
225 | | - def _get_request_headers(cls) -> dict[str, str]: |
226 | | - """Prepare common headers for requests.""" |
227 | | - headers = {"User-Agent": "convert_hf_to_gguf"} |
228 | | - if os.environ.get("HF_TOKEN"): |
229 | | - headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" |
230 | | - return headers |
231 | | - |
232 | 214 | @classmethod |
233 | 215 | def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes: |
234 | 216 | """ |
235 | | - Get raw byte data from a remote file by range using single or multi-threaded download. |
236 | | - |
237 | | - If size is -1, it attempts to read from 'start' to the end of the file (single-threaded only). |
238 | | - If size is >= MULTITHREAD_THREDSHOLD, it uses multiple threads. |
239 | | - Otherwise, it uses a single request. |
| 217 | + Get raw byte data from a remote file by range. |
| 218 | + If size is -1 (the default), read from start to the end of the file. |
240 | 219 | """ |
| 220 | + import requests |
| 221 | + from urllib.parse import urlparse |
| 222 | + |
241 | 223 | parsed_url = urlparse(url) |
242 | 224 | if not parsed_url.scheme or not parsed_url.netloc: |
243 | 225 | raise ValueError(f"Invalid URL: {url}") |
244 | 226 |
245 | | - common_headers = cls._get_request_headers() |
246 | | - |
247 | | - # --- Multithreading Path --- |
248 | | - if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1: |
249 | | - cls.logger.info(f"Using {cls.MULTITHREAD_COUNT} threads to download range of {size / (1024*1024):.2f} MB") |
250 | | - num_threads = cls.MULTITHREAD_COUNT |
251 | | - results: list[Any] = [None] * num_threads # Store results or exceptions |
252 | | - threads: list[threading.Thread] = [] |
253 | | - |
254 | | - def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict): |
255 | | - """Worker function for thread.""" |
256 | | - thread_headers = headers.copy() |
257 | | - # Range header is inclusive end byte |
258 | | - range_end = chunk_start + chunk_size - 1 |
259 | | - thread_headers["Range"] = f"bytes={chunk_start}-{range_end}" |
260 | | - try: |
261 | | - # Using stream=False should make requests wait for content download |
262 | | - response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120) # Added timeout |
263 | | - response.raise_for_status() # Check for HTTP errors |
264 | | - |
265 | | - content = response.content |
266 | | - if len(content) != chunk_size: |
267 | | - # This is a critical check |
268 | | - raise IOError( |
269 | | - f"Thread {index}: Downloaded chunk size mismatch for range {thread_headers['Range']}. " |
270 | | - f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. URL: {chunk_url}" |
271 | | - ) |
272 | | - result_list[index] = content |
273 | | - except Exception as e: |
274 | | - # Store exception to be raised by the main thread |
275 | | - # print(f"Thread {index} error downloading range {thread_headers.get('Range', 'N/A')}: {e}") # Optional debug print |
276 | | - result_list[index] = e |
277 | | - |
278 | | - # Calculate chunk sizes and create/start threads |
279 | | - base_chunk_size = size // num_threads |
280 | | - remainder = size % num_threads |
281 | | - current_offset = start |
282 | | - |
283 | | - for i in range(num_threads): |
284 | | - chunk_size = base_chunk_size + (1 if i < remainder else 0) |
285 | | - if chunk_size == 0: # Should not happen if size >= threshold but handle defensively |
286 | | - results[i] = b"" # Store empty bytes for this "chunk" |
287 | | - continue |
288 | | - |
289 | | - thread = threading.Thread( |
290 | | - target=download_chunk, |
291 | | - args=(url, current_offset, chunk_size, i, results, common_headers), |
292 | | - daemon=True # Allow main thread to exit even if daemon threads are stuck (though join prevents this) |
293 | | - ) |
294 | | - threads.append(thread) |
295 | | - thread.start() |
296 | | - current_offset += chunk_size # Move offset for the next chunk |
297 | | - |
298 | | - # Wait for all threads to complete |
299 | | - for i, thread in enumerate(threads): |
300 | | - thread.join() # Wait indefinitely for each thread |
301 | | - |
302 | | - # Check results for errors and concatenate chunks |
303 | | - final_data_parts = [] |
304 | | - for i in range(num_threads): |
305 | | - result = results[i] |
306 | | - if isinstance(result, Exception): |
307 | | - # Raise the first exception encountered |
308 | | - raise result |
309 | | - elif result is None: |
310 | | - # This indicates a thread finished without setting its result or exception (unexpected) |
311 | | - # Check if it was supposed to download anything |
312 | | - expected_chunk_size = base_chunk_size + (1 if i < remainder else 0) |
313 | | - if expected_chunk_size > 0: |
314 | | - raise RuntimeError(f"Thread {i} finished without providing data or exception for a non-zero chunk.") |
315 | | - else: |
316 | | - final_data_parts.append(b"") # Append empty bytes for zero-size chunk |
317 | | - else: |
318 | | - final_data_parts.append(result) |
319 | | - |
320 | | - # Combine the byte chunks |
321 | | - final_data = b"".join(final_data_parts) |
322 | | - |
323 | | - # Final validation: Does the combined size match the requested size? |
324 | | - if len(final_data) != size: |
325 | | - raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start+size-1}") |
326 | | - |
327 | | - return final_data |
328 | | - |
329 | | - # --- Single-threaded Path --- |
330 | | - else: |
331 | | - # print(f"Using single thread for size {size}") # Optional debug print |
332 | | - headers = common_headers.copy() |
333 | | - if size > -1: |
334 | | - # Range header uses inclusive end byte |
335 | | - range_end = start + size - 1 |
336 | | - headers["Range"] = f"bytes={start}-{range_end}" |
337 | | - elif start > 0: |
338 | | - # Request from start offset to the end of the file |
339 | | - headers["Range"] = f"bytes={start}-" |
340 | | - # If start=0 and size=-1, no Range header is needed (get full file) |
341 | | - |
342 | | - response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120) # Added timeout |
343 | | - response.raise_for_status() |
344 | | - content = response.content |
345 | | - |
346 | | - # Validate downloaded size if a specific size was requested |
347 | | - if size > -1 and len(content) != size: |
348 | | - # Check status code - 206 Partial Content is expected for successful range requests |
349 | | - status_code = response.status_code |
350 | | - content_range = response.headers.get('Content-Range') |
351 | | - raise IOError( |
352 | | - f"Single thread downloaded size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), " |
353 | | - f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}" |
354 | | - ) |
355 | | - |
356 | | - return content |
| 227 | + headers = {} |
| 228 | + if os.environ.get("HF_TOKEN"): |
| 229 | + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" |
| 230 | + if size > -1: |
| 231 | + headers["Range"] = f"bytes={start}-{start + size - 1}" # Range end byte is inclusive |
| 232 | + elif start > 0: |
| 233 | + headers["Range"] = f"bytes={start}-" # read from start offset to end of file |
| 234 | + response = requests.get(url, allow_redirects=True, headers=headers) |
| 235 | + response.raise_for_status() |
| 236 | + return response.content if size < 0 else response.content[:size] |
357 | 237 |
358 | 238 | @classmethod |
359 | 239 | def check_file_exist(cls, url: str) -> bool: |
360 | 240 | """ |
361 | 241 | Check if a file exists at the given URL. |
362 | 242 | Returns True if the file exists, False otherwise. |
363 | 243 | """ |
| 244 | + import requests |
| 245 | + from urllib.parse import urlparse |
| 246 | + |
364 | 247 | parsed_url = urlparse(url) |
365 | 248 | if not parsed_url.scheme or not parsed_url.netloc: |
366 | 249 | raise ValueError(f"Invalid URL: {url}") |
367 | 250 |
368 | 251 | try: |
369 | | - headers = cls._get_request_headers() |
370 | | - headers["Range"] = "bytes=0-0" # Request a small range to check existence |
| 252 | + headers = {"Range": "bytes=0-0"} |
| 253 | + if os.environ.get("HF_TOKEN"): |
| 254 | + headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}" |
371 | 255 | response = requests.head(url, allow_redirects=True, headers=headers) |
372 | 256 | # Success (2xx) or redirect (3xx) |
373 | 257 | return 200 <= response.status_code < 400 |
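For context, a minimal sketch of how the simplified download path is driven after this change. The model URL is hypothetical, the import path assumes this module is `gguf.utility`, and `get_metadata` is assumed to return the header dict together with the offset where tensor data begins; a real run needs network access and, for gated repos, `HF_TOKEN` set in the environment:

```python
from gguf.utility import SafetensorRemote

# hypothetical safetensors file hosted on Hugging Face
url = "https://huggingface.co/user/model/resolve/main/model.safetensors"

if SafetensorRemote.check_file_exist(url):
    # parse the safetensors header: metadata dict + start of the data section
    metadata, data_start = SafetensorRemote.get_metadata(url)
    # fetch the first 1 KiB of tensor data with a single ranged GET
    chunk = SafetensorRemote.get_data_by_range(url, start=data_start, size=1024)
    print(f"downloaded {len(chunk)} bytes, header has {len(metadata)} entries")
```

Each such call now issues one `requests.get` with an inclusive `Range` header; with the multi-threaded chunking removed, large reads arrive as a single response body.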