Skip to content

Commit e124b51

Browse files
committed
feat: download of artifact versions and vault files
1 parent 058e5cf commit e124b51

File tree

2 files changed

+138
-24
lines changed

2 files changed

+138
-24
lines changed

databusclient/cli.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,19 @@ def deploy(
3838
def download(
    localDir: str = typer.Option(..., help="local databus folder"),
    databus: str = typer.Option(..., help="databus URL"),
    databusuris: List[str] = typer.Argument(..., help="any kind of these: databus identifier, databus collection identifier, query file"),
    vault_token_file: str = typer.Option(None, help="Path to Vault refresh token file"),
    auth_url: str = typer.Option(None, help="Keycloak token endpoint URL"),
    client_id: str = typer.Option(None, help="Client ID for token exchange")
):
    """
    Download datasets from databus, optionally using vault access if vault options are provided.
    """
    # NOTE(review): the docstring above is kept verbatim because it presumably
    # surfaces as the command's help text -- confirm before rewording it.

    # The three vault options only work together: collect the ones that were
    # actually supplied and reject a partial set (one or two out of three).
    supplied = [opt for opt in (vault_token_file, auth_url, client_id) if opt]
    if 0 < len(supplied) < 3:
        raise typer.BadParameter(
            "If one of --vault-token-file, --auth-url, or --client-id is specified, all three must be specified."
        )

    # Hand everything over to the library-level download routine.
    client.download(
        localDir=localDir,
        endpoint=databus,
        databusURIs=databusuris,
        vault_token_file=vault_token_file,
        auth_url=auth_url,
        client_id=client_id,
    )

databusclient/client.py

Lines changed: 123 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from SPARQLWrapper import SPARQLWrapper, JSON
88
from hashlib import sha256
99
import os
10+
import re
1011

1112
__debug = False
1213

@@ -392,32 +393,93 @@ def deploy(
392393
print(resp.text)
393394

394395

395-
def __download_file__(url, filename, vault_token_file=None, auth_url=None, client_id=None) -> None:
    """
    Download a file from the internet with a progress bar using tqdm.

    Parameters:
    - url: the URL of the file to download
    - filename: the local file path where the file should be saved
    - vault_token_file: Path to Vault refresh token file
    - auth_url: Keycloak token endpoint URL
    - client_id: Client ID for token exchange

    The Authorization header is only sent when all three vault options are
    given. An HTTP error status raises via ``response.raise_for_status()``.
    A mismatch between the announced content length and the number of bytes
    received is reported on stdout only; it does not raise.
    """

    print("download "+url)
    # os.makedirs("") raises FileNotFoundError, so only create directories
    # when the target path actually has a directory component.
    target_dir = os.path.dirname(filename)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    headers = {}
    if vault_token_file and auth_url and client_id:
        headers["Authorization"] = f"Bearer {__get_vault_access__(url, vault_token_file, auth_url, client_id)}"

    block_size = 1024  # 1 Kibibyte
    received = 0

    # The context managers release the connection, the progress bar and the
    # file handle even when the transfer fails half-way.
    with requests.get(url, headers=headers, stream=True) as response:
        response.raise_for_status()  # Raise an error for bad responses
        total_size_in_bytes = int(response.headers.get('content-length', 0))

        with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar, \
                open(filename, 'wb') as file:
            for data in response.iter_content(block_size):
                received += len(data)
                progress_bar.update(len(data))
                file.write(data)

    # A content length of 0 means the server did not announce a size, so
    # there is nothing to compare against.
    if total_size_in_bytes != 0 and received != total_size_in_bytes:
        print("ERROR, something went wrong")

419431

420-
def __query_sparql__(endpoint_url, query)-> dict:
432+
def __get_vault_access__(download_url: str,
                         token_file: str,
                         auth_url: str,
                         client_id: str) -> str:
    """
    Get Vault access token for a protected databus download.

    Parameters:
    - download_url: URL of the protected file; its host part is used as the
      audience of the token exchange
    - token_file: path to a file holding the refresh token; only read when
      the REFRESH_TOKEN environment variable is unset or empty
    - auth_url: token endpoint that both POST requests are sent to
    - client_id: client ID sent with both requests

    Returns the "access_token" value of the token-exchange response.

    Raises FileNotFoundError when the token file is needed but missing, and
    whatever ``raise_for_status()`` raises when either request fails.
    """
    from urllib.parse import urlsplit

    # 1. Load refresh token (the environment variable wins over the file)
    refresh_token = os.environ.get("REFRESH_TOKEN")
    token_source = "REFRESH_TOKEN environment variable"
    if not refresh_token:
        if not os.path.exists(token_file):
            raise FileNotFoundError(f"Vault token file not found: {token_file}")
        with open(token_file, "r") as f:
            refresh_token = f.read().strip()
        token_source = token_file
    if len(refresh_token) < 80:
        # Name the real origin of the token so the warning is not misleading
        # when the token came from the environment.
        print(f"Warning: token from {token_source} is short (<80 chars)")

    # 2. Refresh token -> access token
    resp = requests.post(auth_url, data={
        "client_id": client_id,
        "grant_type": "refresh_token",
        "refresh_token": refresh_token
    })
    resp.raise_for_status()
    access_token = resp.json()["access_token"]

    # 3. Extract host as audience
    # urlsplit copes with any scheme (in any letter case) and stops the host
    # at "?" or "#" as well as at "/". The network location keeps a port or
    # user-info part, exactly like the previous manual slicing did.
    audience = urlsplit(download_url).netloc
    if not audience:
        # Scheme-less input such as "host/path": host is before first "/"
        audience = download_url.split("/")[0]

    # 4. Access token -> Vault token
    resp = requests.post(auth_url, data={
        "client_id": client_id,
        "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
        "subject_token": access_token,
        "audience": audience
    })
    resp.raise_for_status()
    vault_token = resp.json()["access_token"]

    print(f"Using Vault access token for {download_url}")
    return vault_token
480+
481+
482+
def __query_sparql__(endpoint_url, query) -> dict:
421483
"""
422484
Query a SPARQL endpoint and return results in JSON format.
423485
@@ -436,8 +498,8 @@ def __query_sparql__(endpoint_url, query)-> dict:
436498
return results
437499

438500

439-
def __handle__databus_file_query__(endpoint_url, query) -> List[str]:
440-
result_dict = __query_sparql__(endpoint_url,query)
501+
def __handle_databus_file_query__(endpoint_url, query) -> List[str]:
502+
result_dict = __query_sparql__(endpoint_url, query)
441503
for binding in result_dict['results']['bindings']:
442504
if len(binding.keys()) > 1:
443505
print("Error multiple bindings in query response")
@@ -447,45 +509,84 @@ def __handle__databus_file_query__(endpoint_url, query) -> List[str]:
447509
yield value
448510

449511

512+
def __handle_databus_file_json__(json_str: str) -> List[str]:
    """
    Collect the download URLs of all "Part" nodes in a JSON-LD document.

    Parameters:
    - json_str: JSON text whose top-level object may carry an "@graph" entry

    Returns the non-empty "downloadURL" values of every graph node typed
    "Part", in document order. A document without "@graph" yields [].

    JSON-LD permits "@type" to be either a string or an array of strings and
    "@graph" to be a single node object instead of an array; both forms are
    accepted. Graph entries that are not objects are ignored.
    """
    json_dict = json.loads(json_str)
    graph = json_dict.get("@graph", [])
    if isinstance(graph, dict):
        # A lone node object is equivalent to a one-element graph.
        graph = [graph]

    downloadURLs = []
    for node in graph:
        if not isinstance(node, dict):
            continue  # e.g. a bare IRI string; has no properties to inspect

        node_type = node.get("@type")
        node_types = node_type if isinstance(node_type, list) else [node_type]
        if "Part" not in node_types:
            continue

        downloadURL = node.get("downloadURL")
        if downloadURL:
            downloadURLs.append(downloadURL)
    return downloadURLs
522+
523+
450524
def wsha256(raw: str):
    """Return the hexadecimal SHA-256 digest of *raw* encoded as UTF-8."""
    hasher = sha256()
    hasher.update(raw.encode('utf-8'))
    return hasher.hexdigest()
452526

453527

454-
def __handle_databus_collection__(uri: str) -> str:
    """
    Fetch a databus collection URI and return the response body as text.

    The request asks for the "text/sparql" representation; the caller treats
    the returned text as a SPARQL query.

    Raises whatever ``raise_for_status()`` raises on an HTTP error status, so
    that an error page is never handed on as if it were a query.
    """
    headers = {"Accept": "text/sparql"}
    response = requests.get(uri, headers=headers)
    response.raise_for_status()
    return response.text
457531

458532

459-
def __download_list__(urls: List[str], localDir: str):
533+
def __handle_databus_artifact_version__(uri: str) -> str:
    """
    Fetch a databus artifact version URI and return the response body as text.

    The request asks for the "application/ld+json" representation; the caller
    parses the returned text as JSON.

    Raises whatever ``raise_for_status()`` raises on an HTTP error status, so
    that an error body is never handed on to the JSON parser.
    """
    headers = {"Accept": "application/ld+json"}
    response = requests.get(uri, headers=headers)
    response.raise_for_status()
    return response.text
536+
537+
538+
def __download_list__(urls: List[str],
                      localDir: str,
                      vault_token_file: str = None,
                      auth_url: str = None,
                      client_id: str = None) -> None:
    """
    Download every URL in *urls* into *localDir*.

    Parameters:
    - urls: iterable of download URLs
    - localDir: directory the files are written to
    - vault_token_file: Path to Vault refresh token file
    - auth_url: Keycloak token endpoint URL
    - client_id: Client ID for token exchange

    Each file is named after the last segment of the URL's path. When that
    segment is empty (trailing slash, bare host) or is "." / "..", the
    SHA-256 hex digest of the whole URL is used instead.
    NOTE(review): two URLs with the same last path segment still map to the
    same local file and overwrite each other.
    """
    from urllib.parse import urlsplit

    for url in urls:
        # Look at the path component only, so a query string or fragment
        # never ends up in the local file name.
        name = urlsplit(url).path.split("/")[-1]
        if not name or name in (".", ".."):
            name = wsha256(url)
        filename = os.path.join(localDir, name)
        __download_file__(url=url, filename=filename, vault_token_file=vault_token_file, auth_url=auth_url, client_id=client_id)
462547

463548

464549
def download(
    localDir: str,
    endpoint: str,
    databusURIs: List[str],
    vault_token_file=None,
    auth_url=None,
    client_id=None
) -> None:
    """
    Download datasets to local storage from databus registry. If vault options are provided, vault access will be used for downloading protected files.
    ------
    localDir: the local directory
    endpoint: the databus endpoint URL
    databusURIs: identifiers to access databus registered datasets
    vault_token_file: Path to Vault refresh token file
    auth_url: Keycloak token endpoint URL
    client_id: Client ID for token exchange

    Each entry of databusURIs is handled as one of:
    - an http(s) URI containing "/collections/": the collection's query is
      fetched, run against the endpoint, and the resulting files downloaded
    - an http(s) URI matching the artifact version pattern below: the
      version's JSON-LD is fetched and its file parts downloaded
    - any other http(s) URI, or a "file://" URI: not supported yet, a
      message is printed
    - anything else: treated as a query and run against the endpoint
    """

    databusVersionPattern = re.compile(r"^https://(databus\.dbpedia\.org|databus\.dev\.dbpedia\.link)/[^/]+/[^/]+/[^/]+/[^/]+/?$")

    # Forward the vault options on every download path, so they are not
    # silently ignored for collections and plain queries.
    vault_options = {
        "vault_token_file": vault_token_file,
        "auth_url": auth_url,
        "client_id": client_id,
    }

    for databusURI in databusURIs:
        # dataID or databus collection
        if databusURI.startswith(("http://", "https://")):
            # databus collection
            if "/collections/" in databusURI:  # TODO "in" is not safe! there could be an artifact named collections, need to check for the correct part position in the URI
                query = __handle_databus_collection__(databusURI)
                res = __handle_databus_file_query__(endpoint, query)
                __download_list__(res, localDir, **vault_options)
            # databus artifact version // https://(databus.dbpedia.org|databus.dev.dbpedia.link)/$ACCOUNT/$GROUP/$ARTIFACT/$VERSION
            elif databusVersionPattern.match(databusURI):
                json_str = __handle_databus_artifact_version__(databusURI)
                res = __handle_databus_file_json__(json_str)
                __download_list__(res, localDir, **vault_options)
            else:
                print("dataId not supported yet")  # TODO add support for other DatabusIds here (artifact, group, etc.)
        # query in local file
        elif databusURI.startswith("file://"):
            print("query in file not supported yet")
        # query as argument
        else:
            # Show the query on a single line. The former call passed the
            # text as a second print argument, leaving a literal "{}" behind.
            flat_query = databusURI.replace("\n", " ")
            print(f"QUERY {flat_query}")
            res = __handle_databus_file_query__(endpoint, databusURI)
            __download_list__(res, localDir, **vault_options)

0 commit comments

Comments
 (0)