77from SPARQLWrapper import SPARQLWrapper , JSON
88from hashlib import sha256
99import os
10+ import re
1011
1112__debug = False
1213
@@ -392,32 +393,93 @@ def deploy(
392393 print (resp .text )
393394
394395
395- def __download_file__ (url , filename ) :
396+ def __download_file__ (url , filename , vault_token_file = None , auth_url = None , client_id = None ) -> None :
396397 """
397398 Download a file from the internet with a progress bar using tqdm.
398399
399400 Parameters:
400401 - url: the URL of the file to download
401402 - filename: the local file path where the file should be saved
403+ - vault_token_file: Path to Vault refresh token file
404+ - auth_url: Keycloak token endpoint URL
405+ - client_id: Client ID for token exchange
402406 """
403407
404- print ("download " + url )
405- os .makedirs (os .path .dirname (filename ), exist_ok = True ) # Create the necessary directories
406- response = requests .get (url , stream = True )
407- total_size_in_bytes = int (response .headers .get ('content-length' , 0 ))
408- block_size = 1024 # 1 Kibibyte
408+ print ("download " + url )
409+ os .makedirs (os .path .dirname (filename ), exist_ok = True ) # Create the necessary directories
410+
411+ headers = {}
412+ if vault_token_file and auth_url and client_id :
413+ headers ["Authorization" ] = f"Bearer { __get_vault_access__ (url , vault_token_file , auth_url , client_id )} "
414+
415+ response = requests .get (url , headers = headers , stream = True )
416+ response .raise_for_status () # Raise an error for bad responses
417+ total_size_in_bytes = int (response .headers .get ('content-length' , 0 ))
418+ block_size = 1024 # 1 Kibibyte
409419
410420 progress_bar = tqdm (total = total_size_in_bytes , unit = 'iB' , unit_scale = True )
411- with open (filename , 'wb' ) as file :
421+ with open (filename , 'wb' ) as file :
412422 for data in response .iter_content (block_size ):
413423 progress_bar .update (len (data ))
414424 file .write (data )
425+
415426 progress_bar .close ()
427+
416428 if total_size_in_bytes != 0 and progress_bar .n != total_size_in_bytes :
417429 print ("ERROR, something went wrong" )
418430
419431
420- def __query_sparql__ (endpoint_url , query )-> dict :
432+ def __get_vault_access__ (download_url : str ,
433+ token_file : str ,
434+ auth_url : str ,
435+ client_id : str ) -> str :
436+ """
437+ Get Vault access token for a protected databus download.
438+ """
439+ # 1. Load refresh token
440+ refresh_token = os .environ .get ("REFRESH_TOKEN" )
441+ if not refresh_token :
442+ if not os .path .exists (token_file ):
443+ raise FileNotFoundError (f"Vault token file not found: { token_file } " )
444+ with open (token_file , "r" ) as f :
445+ refresh_token = f .read ().strip ()
446+ if len (refresh_token ) < 80 :
447+ print (f"Warning: token from { token_file } is short (<80 chars)" )
448+
449+ # 2. Refresh token -> access token
450+ resp = requests .post (auth_url , data = {
451+ "client_id" : client_id ,
452+ "grant_type" : "refresh_token" ,
453+ "refresh_token" : refresh_token
454+ })
455+ resp .raise_for_status ()
456+ access_token = resp .json ()["access_token" ]
457+
458+ # 3. Extract host as audience
459+ # Remove protocol prefix
460+ if download_url .startswith ("https://" ):
461+ host_part = download_url [len ("https://" ):]
462+ elif download_url .startswith ("http://" ):
463+ host_part = download_url [len ("http://" ):]
464+ else :
465+ host_part = download_url
466+ audience = host_part .split ("/" )[0 ] # host is before first "/"
467+
468+ # 4. Access token -> Vault token
469+ resp = requests .post (auth_url , data = {
470+ "client_id" : client_id ,
471+ "grant_type" : "urn:ietf:params:oauth:grant-type:token-exchange" ,
472+ "subject_token" : access_token ,
473+ "audience" : audience
474+ })
475+ resp .raise_for_status ()
476+ vault_token = resp .json ()["access_token" ]
477+
478+ print (f"Using Vault access token for { download_url } " )
479+ return vault_token
480+
481+
482+ def __query_sparql__ (endpoint_url , query ) -> dict :
421483 """
422484 Query a SPARQL endpoint and return results in JSON format.
423485
@@ -436,8 +498,8 @@ def __query_sparql__(endpoint_url, query)-> dict:
436498 return results
437499
438500
439- def __handle__databus_file_query__ (endpoint_url , query ) -> List [str ]:
440- result_dict = __query_sparql__ (endpoint_url ,query )
501+ def __handle_databus_file_query__ (endpoint_url , query ) -> List [str ]:
502+ result_dict = __query_sparql__ (endpoint_url , query )
441503 for binding in result_dict ['results' ]['bindings' ]:
442504 if len (binding .keys ()) > 1 :
443505 print ("Error multiple bindings in query response" )
@@ -447,45 +509,84 @@ def __handle__databus_file_query__(endpoint_url, query) -> List[str]:
447509 yield value
448510
449511
512+ def __handle_databus_file_json__ (json_str : str ) -> List [str ]:
513+ downloadURLs = []
514+ json_dict = json .loads (json_str )
515+ graph = json_dict .get ("@graph" , [])
516+ for node in graph :
517+ if node .get ("@type" ) == "Part" :
518+ downloadURL = node .get ("downloadURL" )
519+ if downloadURL :
520+ downloadURLs .append (downloadURL )
521+ return downloadURLs
522+
523+
450524def wsha256 (raw : str ):
451525 return sha256 (raw .encode ('utf-8' )).hexdigest ()
452526
453527
454- def __handle_databus_collection__ (endpoint , uri : str )-> str :
528+ def __handle_databus_collection__ (uri : str ) -> str :
455529 headers = {"Accept" : "text/sparql" }
456530 return requests .get (uri , headers = headers ).text
457531
458532
459- def __download_list__ (urls : List [str ], localDir : str ):
533+ def __handle_databus_artifact_version__ (uri : str ) -> str :
534+ headers = {"Accept" : "application/ld+json" }
535+ return requests .get (uri , headers = headers ).text
536+
537+
538+ def __download_list__ (urls : List [str ],
539+ localDir : str ,
540+ vault_token_file : str = None ,
541+ auth_url : str = None ,
542+ client_id : str = None ) -> None :
460543 for url in urls :
461- __download_file__ (url = url ,filename = localDir + "/" + wsha256 (url ))
544+ file = url .split ("/" )[- 1 ]
545+ filename = os .path .join (localDir , file )
546+ __download_file__ (url = url , filename = filename , vault_token_file = vault_token_file , auth_url = auth_url , client_id = client_id )
462547
463548
464549def download (
465550 localDir : str ,
466551 endpoint : str ,
467- databusURIs : List [str ]
552+ databusURIs : List [str ],
553+ vault_token_file = None ,
554+ auth_url = None ,
555+ client_id = None
468556) -> None :
469557 """
470- Download datasets to local storage from databus registry
558+ Download datasets to local storage from databus registry. If vault options are provided, vault access will be used for downloading protected files.
471559 ------
472560 localDir: the local directory
561+ endpoint: the databus endpoint URL
473562 databusURIs: identifiers to access databus registered datasets
563+ vault_token_file: Path to Vault refresh token file
564+ auth_url: Keycloak token endpoint URL
565+ client_id: Client ID for token exchange
474566 """
567+
568+ databusVersionPattern = re .compile (r"^https://(databus\.dbpedia\.org|databus\.dev\.dbpedia\.link)/[^/]+/[^/]+/[^/]+/[^/]+/?$" )
569+
475570 for databusURI in databusURIs :
476571 # dataID or databus collection
477572 if databusURI .startswith ("http://" ) or databusURI .startswith ("https://" ):
478573 # databus collection
479- if "/collections/" in databusURI : #TODO "in" is not safe! there could be an artifact named collections, need to check for the correct part position in the URI
480- query = __handle_databus_collection__ (endpoint ,databusURI )
481- res = __handle__databus_file_query__ (endpoint , query )
574+ if "/collections/" in databusURI : # TODO "in" is not safe! there could be an artifact named collections, need to check for the correct part position in the URI
575+ query = __handle_databus_collection__ (databusURI )
576+ res = __handle_databus_file_query__ (endpoint , query )
577+ __download_list__ (res , localDir )
578+ # databus artifact version // https://(databus.dbpedia.org|databus.dev.dbpedia.link)/$ACCOUNT/$GROUP/$ARTIFACT/$VERSION
579+ elif databusVersionPattern .match (databusURI ):
580+ json_str = __handle_databus_artifact_version__ (databusURI )
581+ res = __handle_databus_file_json__ (json_str )
582+ __download_list__ (res , localDir , vault_token_file = vault_token_file , auth_url = auth_url , client_id = client_id )
482583 else :
483- print ("dataId not supported yet" ) # TODO add support for other DatabusIds here (artifact, group, etc.)
584+ print ("dataId not supported yet" ) # TODO add support for other DatabusIds here (artifact, group, etc.)
484585 # query in local file
485586 elif databusURI .startswith ("file://" ):
486587 print ("query in file not supported yet" )
487588 # query as argument
488589 else :
489- print ("QUERY {}" , databusURI .replace ("\n " ," " ))
490- res = __handle__databus_file_query__ (endpoint ,databusURI )
491- __download_list__ (res ,localDir )
590+ print ("QUERY {}" , databusURI .replace ("\n " , " " ))
591+ res = __handle_databus_file_query__ (endpoint , databusURI )
592+ __download_list__ (res , localDir )
0 commit comments