Skip to content
182 changes: 182 additions & 0 deletions databusclient/api/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,54 @@
get_databus_id_parts_from_file_url,
)

from databusclient.extensions.webdav import compute_sha256_and_length

def _extract_checksum_from_node(node) -> str | None:
"""
Try to extract a 64-char hex checksum from a JSON-LD file node.
Handles these common shapes:
- checksum or sha256sum fields as plain string
- checksum fields as dict with '@value'
- nested values (recursively search strings for a 64-char hex)
"""
def find_in_value(v):
if isinstance(v, str):
s = v.strip()
if len(s) == 64 and all(c in "0123456789abcdefABCDEF" for c in s):
return s
if isinstance(v, dict):
# common JSON-LD value object
if "@value" in v and isinstance(v["@value"], str):
res = find_in_value(v["@value"])
if res:
return res
# try all nested dict values
for vv in v.values():
res = find_in_value(vv)
if res:
return res
if isinstance(v, list):
for item in v:
res = find_in_value(item)
if res:
return res
return None

# direct keys to try first
for key in ("checksum", "sha256sum", "sha256", "databus:checksum"):
if key in node:
res = find_in_value(node[key])
if res:
return res

# fallback: search all values recursively for a 64-char hex string
for v in node.values():
res = find_in_value(v)
if res:
return res
return None



# Hosts that require Vault token based authentication. Central source of truth.
VAULT_REQUIRED_HOSTS = {
Expand All @@ -32,6 +80,8 @@ def _download_file(
databus_key=None,
auth_url=None,
client_id=None,
validate_checksum: bool = False,
expected_checksum: str | None = None,
) -> None:
"""
Download a file from the internet with a progress bar using tqdm.
Expand Down Expand Up @@ -183,6 +233,25 @@ def _download_file(
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
raise IOError("Downloaded size does not match Content-Length header")

# --- 6. Optional checksum validation ---
if validate_checksum:
# reuse compute_sha256_and_length from webdav extension
try:
actual, _ = compute_sha256_and_length(filename)
except (OSError, IOError) as e:
print(f"WARNING: error computing checksum for {filename}: {e}")
actual = None

if expected_checksum is None:
print(f"WARNING: no expected checksum available for {filename}; skipping validation")
elif actual is None:
print(f"WARNING: could not compute checksum for {filename}; skipping validation")
else:
if actual.lower() != expected_checksum.lower():
raise IOError(
f"Checksum mismatch for {filename}: expected {expected_checksum}, got {actual}"
)


def _download_files(
urls: List[str],
Expand All @@ -191,6 +260,8 @@ def _download_files(
databus_key: str = None,
auth_url: str = None,
client_id: str = None,
validate_checksum: bool = False,
checksums: dict | None = None,
) -> None:
"""
Download multiple files from the databus.
Expand All @@ -204,13 +275,18 @@ def _download_files(
- client_id: Client ID for token exchange
"""
for url in urls:
expected = None
if checksums and isinstance(checksums, dict):
expected = checksums.get(url)
_download_file(
url=url,
localDir=localDir,
vault_token_file=vault_token_file,
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
expected_checksum=expected,
)


Expand Down Expand Up @@ -358,6 +434,7 @@ def _download_collection(
databus_key: str = None,
auth_url: str = None,
client_id: str = None,
validate_checksum: bool = False
) -> None:
"""
Download all files in a databus collection.
Expand All @@ -375,13 +452,53 @@ def _download_collection(
file_urls = _get_file_download_urls_from_sparql_query(
endpoint, query, databus_key=databus_key
)

# If checksum validation requested, attempt to build url->checksum mapping
# by fetching the Version JSON-LD for each file's version. We group files
# by their version URI to avoid fetching the same metadata repeatedly.
checksums: dict = {}
if validate_checksum:
# Map version_uri -> list of file urls
versions_map: dict = {}
for fu in file_urls:
try:
h, acc, grp, art, ver, f = get_databus_id_parts_from_file_url(fu)
except Exception:
continue
if ver is None:
continue
if h is None or acc is None or grp is None or art is None:
continue
version_uri = f"https://{h}/{acc}/{grp}/{art}/{ver}"
versions_map.setdefault(version_uri, []).append(fu)

# Fetch each version's JSON-LD once and extract checksums for its files
for version_uri, urls_in_version in versions_map.items():
try:
json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key)
jd = json.loads(json_str)
graph = jd.get("@graph", [])
for node in graph:
if node.get("@type") == "Part":
file_uri = node.get("file")
if not isinstance(file_uri, str):
continue
expected = _extract_checksum_from_node(node)
if expected and file_uri in urls_in_version:
checksums[file_uri] = expected
except Exception:
# Best-effort: if fetching a version fails, skip it
continue

_download_files(
list(file_urls),
localDir,
vault_token_file=vault_token,
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
checksums=checksums if checksums else None,
)


Expand All @@ -392,6 +509,7 @@ def _download_version(
databus_key: str = None,
auth_url: str = None,
client_id: str = None,
validate_checksum: bool = False,
) -> None:
"""
Download all files in a databus artifact version.
Expand All @@ -406,13 +524,31 @@ def _download_version(
"""
json_str = fetch_databus_jsonld(uri, databus_key=databus_key)
file_urls = _get_file_download_urls_from_artifact_jsonld(json_str)
# build url -> checksum mapping from JSON-LD when available
checksums: dict = {}
try:
json_dict = json.loads(json_str)
graph = json_dict.get("@graph", [])
for node in graph:
if node.get("@type") == "Part":
file_uri = node.get("file")
if not isinstance(file_uri, str):
continue
expected = _extract_checksum_from_node(node)
if expected:
checksums[file_uri] = expected
except Exception:
checksums = {}

_download_files(
file_urls,
localDir,
vault_token_file=vault_token_file,
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
checksums=checksums,
)


Expand All @@ -424,6 +560,7 @@ def _download_artifact(
databus_key: str = None,
auth_url: str = None,
client_id: str = None,
validate_checksum: bool = False,
) -> None:
"""
Download files in a databus artifact.
Expand All @@ -445,13 +582,31 @@ def _download_artifact(
print(f"Downloading version: {version_uri}")
json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key)
file_urls = _get_file_download_urls_from_artifact_jsonld(json_str)
# extract checksums for this version
checksums: dict = {}
try:
jd = json.loads(json_str)
graph = jd.get("@graph", [])
for node in graph:
if node.get("@type") == "Part":
file_uri = node.get("file")
if not isinstance(file_uri, str):
continue
expected = _extract_checksum_from_node(node)
if expected:
checksums[file_uri] = expected
except Exception:
checksums = {}

_download_files(
file_urls,
localDir,
vault_token_file=vault_token_file,
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
checksums=checksums,
)


Expand Down Expand Up @@ -527,6 +682,7 @@ def _download_group(
databus_key: str = None,
auth_url: str = None,
client_id: str = None,
validate_checksum: bool = False,
) -> None:
"""
Download files in a databus group.
Expand All @@ -552,6 +708,7 @@ def _download_group(
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
)


Expand Down Expand Up @@ -598,6 +755,7 @@ def download(
all_versions=None,
auth_url="https://auth.dbpedia.org/realms/dbpedia/protocol/openid-connect/token",
client_id="vault-token-exchange",
validate_checksum: bool = False
) -> None:
"""
Download datasets from databus.
Expand Down Expand Up @@ -638,16 +796,36 @@ def download(
databus_key,
auth_url,
client_id,
validate_checksum=validate_checksum,
)
elif file is not None:
print(f"Downloading file: {databusURI}")
# Try to fetch expected checksum from the parent Version metadata
expected = None
if validate_checksum:
try:
version_uri = f"https://{host}/{account}/{group}/{artifact}/{version}"
json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key)
json_dict = json.loads(json_str)
graph = json_dict.get("@graph", [])
for node in graph:
if node.get("file") == databusURI or node.get("@id") == databusURI:
expected = _extract_checksum_from_node(node)
if expected:
break
except Exception as e:
print(f"WARNING: Could not fetch checksum for single file: {e}")

# Call the worker to download the single file (passes expected checksum)
_download_file(
databusURI,
localDir,
vault_token_file=token,
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
expected_checksum=expected,
)
elif version is not None:
print(f"Downloading version: {databusURI}")
Expand All @@ -658,6 +836,7 @@ def download(
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
)
elif artifact is not None:
print(
Expand All @@ -671,6 +850,7 @@ def download(
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
)
elif group is not None and group != "collections":
print(
Expand All @@ -684,6 +864,7 @@ def download(
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
)
elif account is not None:
print("accountId not supported yet") # TODO
Expand All @@ -709,4 +890,5 @@ def download(
databus_key=databus_key,
auth_url=auth_url,
client_id=client_id,
validate_checksum=validate_checksum,
)
11 changes: 9 additions & 2 deletions databusclient/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def deploy(
show_default=True,
help="Client ID for token exchange",
)
@click.option(
"--validate-checksum",
is_flag=True,
help="Validate checksums of downloaded files"
)
def download(
databusuris: List[str],
localdir,
Expand All @@ -167,7 +172,8 @@ def download(
all_versions,
authurl,
clientid,
):
validate_checksum,
):
"""
Download datasets from databus, optionally using vault access if vault options are provided.
"""
Expand All @@ -181,7 +187,8 @@ def download(
all_versions=all_versions,
auth_url=authurl,
client_id=clientid,
)
validate_checksum=validate_checksum
)
except DownloadAuthError as e:
raise click.ClickException(str(e))

Expand Down