diff --git a/databusclient/api/download.py b/databusclient/api/download.py
index ac55faa..60b845b 100644
--- a/databusclient/api/download.py
+++ b/databusclient/api/download.py
@@ -1,5 +1,6 @@
 import json
 import os
+import re
 from typing import List
 from urllib.parse import urlparse
 
@@ -12,6 +13,52 @@
     get_databus_id_parts_from_file_url,
 )
 
+from databusclient.api.utils import compute_sha256_and_length
+
+# compiled regex for SHA-256 hex strings
+_SHA256_RE = re.compile(r"^[0-9a-fA-F]{64}$")
+
+def _extract_checksum_from_node(node) -> str | None:
+    """
+    Try to extract a 64-char hex checksum from a JSON-LD file node.
+    Handles these common shapes:
+    - checksum or sha256sum fields as plain string
+    - checksum fields as dict with '@value'
+    - nested values under the allowed keys (lists or '@value' objects)
+    """
+    def find_in_value(v):
+        if isinstance(v, str):
+            s = v.strip()
+            if _SHA256_RE.match(s):
+                return s
+        if isinstance(v, dict):
+            # common JSON-LD value object
+            if "@value" in v and isinstance(v["@value"], str):
+                res = find_in_value(v["@value"])
+                if res:
+                    return res
+            # try all nested dict values
+            for vv in v.values():
+                res = find_in_value(vv)
+                if res:
+                    return res
+        if isinstance(v, list):
+            for item in v:
+                res = find_in_value(item)
+                if res:
+                    return res
+        return None
+
+    # Only inspect the explicitly allowed keys to avoid false positives.
+    for key in ("checksum", "sha256sum", "sha256", "databus:checksum"):
+        if key in node:
+            res = find_in_value(node[key])
+            if res:
+                return res
+
+    return None
+
+
 # Hosts that require Vault token based authentication. Central source of truth.
 VAULT_REQUIRED_HOSTS = {
@@ -25,6 +72,73 @@ class DownloadAuthError(Exception):
 
+def _extract_checksums_from_jsonld(json_str: str) -> dict:
+    """
+    Parse a JSON-LD string and return a mapping of file URI (and @id) -> checksum.
+
+    Uses the existing _extract_checksum_from_node logic to extract checksums
+    from `Part` nodes. Both the node's `file` and `@id` (if present and a
+    string) are mapped to the checksum to preserve existing lookup behavior.
+    """
+    try:
+        jd = json.loads(json_str)
+    except Exception:
+        return {}
+    if isinstance(jd, dict):
+        graph = jd.get("@graph", [])
+    elif isinstance(jd, list):
+        graph = jd
+    else:
+        return {}
+
+    checksums: dict = {}
+    for node in graph:
+        if node.get("@type") == "Part":
+            expected = _extract_checksum_from_node(node)
+            if not expected:
+                continue
+            file_uri = node.get("file")
+            if isinstance(file_uri, str):
+                checksums[file_uri] = expected
+            node_id = node.get("@id")
+            if isinstance(node_id, str):
+                checksums[node_id] = expected
+    return checksums
+
+
+def _resolve_checksums_for_urls(file_urls: List[str], databus_key: str | None) -> dict:
+    """
+    Group file URLs by their Version URI, fetch each Version JSON-LD once,
+    and return a combined url->checksum mapping for the provided URLs.
+
+    Best-effort: failures to fetch or parse individual versions are skipped.
+    """
+    versions_map: dict = {}
+    for file_url in file_urls:
+        try:
+            host, accountId, groupId, artifactId, versionId, fileId = get_databus_id_parts_from_file_url(file_url)
+        except Exception:
+            continue
+        if versionId is None:
+            continue
+        if host is None or accountId is None or groupId is None or artifactId is None:
+            continue
+        version_uri = f"https://{host}/{accountId}/{groupId}/{artifactId}/{versionId}"
+        versions_map.setdefault(version_uri, []).append(file_url)
+
+    checksums: dict = {}
+    for version_uri, urls_in_version in versions_map.items():
+        try:
+            json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key)
+            extracted_checksums = _extract_checksums_from_jsonld(json_str)
+            for url in urls_in_version:
+                if url in extracted_checksums:
+                    checksums[url] = extracted_checksums[url]
+        except Exception:
+            # Best-effort: skip versions we cannot fetch or parse
+            continue
+    return checksums
+
 
 def _download_file(
     url,
     localDir,
@@ -32,6 +146,8 @@ def _download_file(
     databus_key=None,
     auth_url=None,
     client_id=None,
+    validate_checksum: bool = False,
+    expected_checksum: str | None = None,
 ) -> None:
     """
     Download a file from the internet with a progress bar using tqdm.
@@ -138,7 +254,7 @@ def _download_file(
             # for user-friendly CLI output.
             vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id)
             headers["Authorization"] = f"Bearer {vault_token}"
-            headers.pop("Accept-Encoding", None)
+            headers["Accept-Encoding"] = "identity"
 
             # Retry with token
             response = requests.get(url, headers=headers, stream=True, timeout=30)
@@ -183,6 +299,29 @@ def _download_file(
     if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
         raise IOError("Downloaded size does not match Content-Length header")
 
+    # --- 6. Optional checksum validation ---
+    if validate_checksum:
+        # reuse the shared compute_sha256_and_length helper from api.utils
+        try:
+            actual, _ = compute_sha256_and_length(filename)
+        except OSError as e:
+            print(f"WARNING: error computing checksum for {filename}: {e}")
+            actual = None
+
+        if expected_checksum is None:
+            print(f"WARNING: no expected checksum available for {filename}; skipping validation")
+        elif actual is None:
+            print(f"WARNING: could not compute checksum for {filename}; skipping validation")
+        else:
+            if actual.lower() != expected_checksum.lower():
+                try:
+                    os.remove(filename)  # delete corrupted file
+                except OSError:
+                    pass
+                raise IOError(
+                    f"Checksum mismatch for {filename}: expected {expected_checksum}, got {actual}"
+                )
+
 
 def _download_files(
     urls: List[str],
@@ -191,6 +330,8 @@ def _download_files(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
+    checksums: dict | None = None,
 ) -> None:
     """
     Download multiple files from the databus.
@@ -204,6 +345,9 @@ def _download_files(
     - client_id: Client ID for token exchange
     """
     for url in urls:
+        expected = None
+        if checksums:
+            expected = checksums.get(url)
         _download_file(
             url=url,
             localDir=localDir,
@@ -211,6 +355,8 @@ def _download_files(
             databus_key=databus_key,
             auth_url=auth_url,
             client_id=client_id,
+            validate_checksum=validate_checksum,
+            expected_checksum=expected,
         )
 
 
@@ -358,6 +504,7 @@ def _download_collection(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download all files in a databus collection.
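A minimal sanity check of the extraction helpers above (a sketch, not part of the patch): it assumes a Version JSON-LD in the `@graph`/`Part` layout the parser expects, and every URI and value below is made up.

    import json

    from databusclient.api.download import _extract_checksums_from_jsonld

    # Hypothetical Version document with a single Part node; the checksum is
    # wrapped in a JSON-LD value object, one of the shapes handled by
    # _extract_checksum_from_node.
    sample = json.dumps({
        "@graph": [
            {
                "@type": "Part",
                "@id": "https://databus.example.org/acc/grp/art/1.0#file.ttl",
                "file": "https://databus.example.org/acc/grp/art/1.0/file.ttl",
                "checksum": {"@value": "ab" * 32},  # 64 hex chars
            }
        ]
    })

    checksums = _extract_checksums_from_jsonld(sample)
    # Both the `file` URI and the `@id` map to the same digest:
    assert checksums["https://databus.example.org/acc/grp/art/1.0/file.ttl"] == "ab" * 32
    assert checksums["https://databus.example.org/acc/grp/art/1.0#file.ttl"] == "ab" * 32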
@@ -375,6 +522,12 @@ def _download_collection(
     file_urls = _get_file_download_urls_from_sparql_query(
         endpoint, query, databus_key=databus_key
     )
+
+    # If checksum validation requested, attempt to build url->checksum mapping
+    checksums: dict = {}
+    if validate_checksum:
+        checksums = _resolve_checksums_for_urls(list(file_urls), databus_key)
+
     _download_files(
         list(file_urls),
         localDir,
@@ -382,6 +535,8 @@ def _download_collection(
         databus_key=databus_key,
         auth_url=auth_url,
         client_id=client_id,
+        validate_checksum=validate_checksum,
+        checksums=checksums if checksums else None,
     )
 
 
@@ -392,6 +547,7 @@ def _download_version(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download all files in a databus artifact version.
@@ -406,6 +562,13 @@ def _download_version(
     """
     json_str = fetch_databus_jsonld(uri, databus_key=databus_key)
     file_urls = _get_file_download_urls_from_artifact_jsonld(json_str)
+    # build url -> checksum mapping from JSON-LD when available
+    checksums: dict = {}
+    try:
+        checksums = _extract_checksums_from_jsonld(json_str)
+    except Exception:
+        checksums = {}
+
    _download_files(
         file_urls,
         localDir,
@@ -413,6 +576,8 @@ def _download_version(
         databus_key=databus_key,
         auth_url=auth_url,
         client_id=client_id,
+        validate_checksum=validate_checksum,
+        checksums=checksums,
     )
 
 
@@ -424,6 +589,7 @@ def _download_artifact(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download files in a databus artifact.
@@ -445,6 +611,13 @@ def _download_artifact(
         print(f"Downloading version: {version_uri}")
         json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key)
         file_urls = _get_file_download_urls_from_artifact_jsonld(json_str)
+        # extract checksums for this version
+        checksums: dict = {}
+        try:
+            checksums = _extract_checksums_from_jsonld(json_str)
+        except Exception:
+            checksums = {}
+
         _download_files(
             file_urls,
             localDir,
@@ -452,6 +625,8 @@ def _download_artifact(
             databus_key=databus_key,
             auth_url=auth_url,
             client_id=client_id,
+            validate_checksum=validate_checksum,
+            checksums=checksums,
         )
 
 
@@ -527,6 +702,7 @@ def _download_group(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download files in a databus group.
@@ -552,6 +728,7 @@ def _download_group(
             databus_key=databus_key,
             auth_url=auth_url,
             client_id=client_id,
+            validate_checksum=validate_checksum,
         )
 
 
@@ -598,6 +775,7 @@ def download(
     all_versions=None,
     auth_url="https://auth.dbpedia.org/realms/dbpedia/protocol/openid-connect/token",
     client_id="vault-token-exchange",
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download datasets from databus.
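End to end, `validate_checksum` threads from `download()` through the per-scope helpers into `_download_file`, which re-hashes the finished file and raises on mismatch. A usage sketch at the API level (the URI and directory are placeholders; the two leading positional parameters are assumed from the function body, since the full signature sits outside these hunks):

    from databusclient.api.download import download

    # Hypothetical call: fetch every file of a version and verify each download
    # against the checksum published in the Version JSON-LD. A mismatch removes
    # the corrupted file and raises IOError; files without a published checksum
    # are only warned about.
    download(
        "https://databus.dbpedia.org/some-account/some-group/some-artifact/2024.01.01",
        "./data",  # local target directory
        validate_checksum=True,
    )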
@@ -638,9 +816,25 @@ def download(
                     databus_key,
                     auth_url,
                     client_id,
+                    validate_checksum=validate_checksum,
                 )
             elif file is not None:
                 print(f"Downloading file: {databusURI}")
+                # Try to fetch expected checksum from the parent Version metadata
+                expected = None
+                if validate_checksum:
+                    try:
+                        if version is not None:
+                            version_uri = f"https://{host}/{account}/{group}/{artifact}/{version}"
+                            json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key)
+                            checks = _extract_checksums_from_jsonld(json_str)
+                            expected = checks.get(databusURI) or checks.get(
+                                "https://" + databusURI.removeprefix("http://").removeprefix("https://")
+                            )
+                    except Exception as e:
+                        print(f"WARNING: Could not fetch checksum for single file: {e}")
+
+                # Call the worker to download the single file (passes expected checksum)
                 _download_file(
                     databusURI,
                     localDir,
@@ -648,6 +842,8 @@ def download(
                     databus_key=databus_key,
                     auth_url=auth_url,
                     client_id=client_id,
+                    validate_checksum=validate_checksum,
+                    expected_checksum=expected,
                 )
             elif version is not None:
                 print(f"Downloading version: {databusURI}")
@@ -658,6 +854,7 @@ def download(
                     databus_key=databus_key,
                     auth_url=auth_url,
                     client_id=client_id,
+                    validate_checksum=validate_checksum,
                 )
             elif artifact is not None:
                 print(
@@ -671,6 +868,7 @@ def download(
                     databus_key=databus_key,
                     auth_url=auth_url,
                     client_id=client_id,
+                    validate_checksum=validate_checksum,
                 )
             elif group is not None and group != "collections":
                 print(
@@ -684,6 +882,7 @@ def download(
                     databus_key=databus_key,
                     auth_url=auth_url,
                     client_id=client_id,
+                    validate_checksum=validate_checksum,
                 )
             elif account is not None:
                 print("accountId not supported yet")  # TODO
@@ -702,6 +901,14 @@ def download(
             res = _get_file_download_urls_from_sparql_query(
                 uri_endpoint, databusURI, databus_key=databus_key
             )
+
+            # If checksum validation requested, try to build url->checksum mapping
+            checksums: dict = {}
+            if validate_checksum:
+                checksums = _resolve_checksums_for_urls(res, databus_key)
+                if not checksums:
+                    print("WARNING: Checksum validation enabled but no checksums found for query results.")
+
             _download_files(
                 res,
                 localDir,
@@ -709,4 +916,6 @@ def download(
                 databus_key=databus_key,
                 auth_url=auth_url,
                 client_id=client_id,
+                validate_checksum=validate_checksum,
+                checksums=checksums if checksums else None,
             )
diff --git a/databusclient/api/utils.py b/databusclient/api/utils.py
index 7e27ff3..a1a1063 100644
--- a/databusclient/api/utils.py
+++ b/databusclient/api/utils.py
@@ -1,5 +1,6 @@
 from typing import Optional, Tuple
+import hashlib
 
 import requests
 
@@ -48,3 +49,17 @@ def fetch_databus_jsonld(uri: str, databus_key: str | None = None) -> str:
     response.raise_for_status()
 
     return response.text
+
+
+def compute_sha256_and_length(filepath: str) -> Tuple[str, int]:
+    """Return the SHA-256 hex digest and byte length of the file at filepath."""
+    sha256 = hashlib.sha256()
+    total_length = 0
+    with open(filepath, "rb") as f:
+        while True:
+            chunk = f.read(4096)
+            if not chunk:
+                break
+            sha256.update(chunk)
+            total_length += len(chunk)
+    return sha256.hexdigest(), total_length
diff --git a/databusclient/cli.py b/databusclient/cli.py
index 069408e..420530d 100644
--- a/databusclient/cli.py
+++ b/databusclient/cli.py
@@ -158,6 +158,11 @@ def deploy(
     show_default=True,
     help="Client ID for token exchange",
 )
+@click.option(
+    "--validate-checksum",
+    is_flag=True,
+    help="Validate checksums of downloaded files",
+)
 def download(
     databusuris: List[str],
     localdir,
@@ -167,7 +172,8 @@ def download(
     all_versions,
     authurl,
     clientid,
-):
+    validate_checksum,
+):
     """
     Download datasets from databus, optionally using vault access if vault
     options are provided.
     """
@@ -181,7 +187,8 @@ def download(
             all_versions=all_versions,
             auth_url=authurl,
             client_id=clientid,
-        )
+            validate_checksum=validate_checksum,
+        )
     except DownloadAuthError as e:
         raise click.ClickException(str(e))
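As a self-contained check of the new `compute_sha256_and_length` helper (runnable as-is against a throwaway temp file; nothing here is Databus-specific):

    import hashlib
    import os
    import tempfile

    from databusclient.api.utils import compute_sha256_and_length

    # Write a small file, then confirm the chunked helper agrees with a
    # one-shot hashlib digest and with the actual byte length.
    payload = b"databus checksum smoke test\n" * 1000
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(payload)
        path = tmp.name

    digest, length = compute_sha256_and_length(path)
    assert digest == hashlib.sha256(payload).hexdigest()
    assert length == len(payload)
    os.remove(path)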