|
| 1 | +# |
| 2 | +# Copyright IBM Corp. 2024 - 2024 |
| 3 | +# SPDX-License-Identifier: MIT |
| 4 | +# |
| 5 | + |
| 6 | +"""File-related utilities.""" |
| 7 | + |
| 8 | +import tempfile |
| 9 | +from pathlib import Path |
| 10 | +from typing import Union |
| 11 | + |
| 12 | +import requests |
| 13 | +from pydantic import AnyHttpUrl, TypeAdapter, ValidationError |
| 14 | + |
| 15 | + |
def resolve_file_source(source: Union[Path, AnyHttpUrl, str]) -> Path:
    """Resolve the source (URL, path) of a file to a local file path.

    If a URL is provided, the content is first downloaded to a file inside a
    freshly created temporary directory. The directory is not removed by this
    function; cleaning it up is left to the caller.

    Args:
        source (Path | AnyHttpUrl | str): The file input source. Can be a path or URL.

    Raises:
        ValueError: If source is neither a valid HTTP(S) URL nor a valid path.
        requests.RequestException: If the download fails (connection error,
            timeout, or non-success HTTP status).

    Returns:
        Path: The local file path.
    """
    # Only the validation can raise ValidationError, so keep the try minimal.
    try:
        http_url: AnyHttpUrl = TypeAdapter(AnyHttpUrl).validate_python(source)
    except ValidationError:
        # Not a URL: interpret the source as a local path instead.
        try:
            return TypeAdapter(Path).validate_python(source)
        except ValidationError as err:
            raise ValueError(
                f"Unexpected source type encountered: {type(source)}"
            ) from err

    # The context manager releases the streamed connection even on errors;
    # the timeout bounds connect/read waits so a dead server cannot hang us.
    with requests.get(str(http_url), stream=True, timeout=30) as res:
        res.raise_for_status()

        fname = None
        # try to get filename from response header
        if cont_disp := res.headers.get("Content-Disposition"):
            for par in cont_disp.split(";"):
                # currently only handling directive "filename" (not "filename*")
                key, sep, value = par.partition("=")
                if sep and key.strip().lower() == "filename":
                    fname = _safe_basename(value.strip().strip("'\""))
                    break
        # otherwise, use name from URL (falling back to a fixed default):
        if fname is None:
            fname = _safe_basename(http_url.path or "") or "file"

        local_path = Path(tempfile.mkdtemp()) / fname
        with open(local_path, "wb") as f:
            for chunk in res.iter_content(chunk_size=8192):  # using 8-KB chunks
                f.write(chunk)
    return local_path


def _safe_basename(candidate: str) -> Union[str, None]:
    """Return the final path component of *candidate*, or None if unusable.

    The candidate may come from a server-controlled header, so directory
    parts (with either slash style) are dropped to keep the download inside
    the temporary directory.
    """
    name = Path(candidate.replace("\\", "/")).name
    # "", "." and ".." do not name a file inside the target directory.
    if name in ("", ".", ".."):
        return None
    return name
0 commit comments