Skip to content

Commit 9d3a64c

Browse files
DhanashreePetareDhanashreePetare
andauthored
Restrict Vault token exchange to specific hosts; improve auth errors; (Issue #19) (#40)
* Restrict Vault token exchange to specific hosts; improve auth errors; add tests (fixes #19)
* Restrict Vault token exchange to specific hosts; improve auth errors; add tests and docs note (fixes #19)
* Fix vault redirect check (#19)

---------

Co-authored-by: DhanashreePetare <[email protected]>
1 parent e16ff76 commit 9d3a64c

File tree

7 files changed

+233
-31
lines changed

7 files changed

+233
-31
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# project-specific
22
tmp/
3+
vault-token.dat
34

45
# Byte-compiled / optimized / DLL files
56
__pycache__/

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ docker run --rm -v $(pwd):/data dbpedia/databus-python-client download $DOWNLOAD
164164
- If no `--localdir` is provided, the current working directory is used as base directory. The downloaded files will be stored in the working directory in a folder structure according to the Databus layout, i.e. `./$ACCOUNT/$GROUP/$ARTIFACT/$VERSION/`.
165165
- `--vault-token`
166166
- If the dataset/files to be downloaded require vault authentication, you need to provide a vault token with `--vault-token /path/to/vault-token.dat`. See [Registration (Access Token)](#registration-access-token) for details on how to get a vault token.
167+
168+
Note: Vault tokens are only required for certain protected Databus hosts (for example, `data.dbpedia.io` and `data.dev.dbpedia.link`). The client detects these hosts and fails early with a clear error message if a token is required but was not provided. Do not pass `--vault-token` for public downloads.
167169
- `--databus-key`
168170
- If the databus is protected and needs API key authentication, you can provide the API key with `--databus-key YOUR_API_KEY`.
169171

databusclient/api/download.py

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import os
33
from typing import List
4+
from urllib.parse import urlparse
45

56
import requests
67
from SPARQLWrapper import JSON, SPARQLWrapper
@@ -12,6 +13,18 @@
1213
)
1314

1415

16+
# Hosts that require Vault-token-based authentication. Central source of truth.
17+
VAULT_REQUIRED_HOSTS = {
18+
"data.dbpedia.io",
19+
"data.dev.dbpedia.link",
20+
}
21+
22+
23+
class DownloadAuthError(Exception):
24+
"""Raised when an authorization problem occurs during download."""
25+
26+
27+
1528
def _download_file(
1629
url,
1730
localDir,
@@ -52,16 +65,9 @@ def _download_file(
5265
os.makedirs(dirpath, exist_ok=True) # Create the necessary directories
5366
# --- 1. Get redirect URL by requesting HEAD ---
5467
headers = {}
55-
# --- 1a. public databus ---
56-
response = requests.head(url, timeout=30)
57-
# --- 1b. Databus API key required ---
58-
if response.status_code == 401:
59-
# print(f"API key required for {url}")
60-
if not databus_key:
61-
raise ValueError("Databus API key not given for protected download")
6268

63-
headers = {"X-API-KEY": databus_key}
64-
response = requests.head(url, headers=headers, timeout=30)
69+
# --- 1a. public databus ---
70+
response = requests.head(url, timeout=30, allow_redirects=False)
6571

6672
# Check for redirect and update URL if necessary
6773
if response.headers.get("Location") and response.status_code in [
@@ -73,6 +79,30 @@ def _download_file(
7379
]:
7480
url = response.headers.get("Location")
7581
print("Redirects url: ", url)
82+
# Re-do HEAD request on redirect URL
83+
response = requests.head(url, timeout=30)
84+
85+
# Extract hostname from final URL (after redirect) to check if vault token needed.
86+
# This is the actual download location that may require authentication.
87+
parsed = urlparse(url)
88+
host = parsed.hostname
89+
90+
# --- 1b. Handle 401 on HEAD request ---
91+
if response.status_code == 401:
92+
# Check if this is a vault-required host
93+
if host in VAULT_REQUIRED_HOSTS:
94+
# Vault-required host: need vault token
95+
if not vault_token_file:
96+
raise DownloadAuthError(
97+
f"Vault token required for host '{host}', but no token was provided. Please use --vault-token."
98+
)
99+
# Token provided; will handle in GET request below
100+
else:
101+
# Not a vault host; might need databus API key
102+
if not databus_key:
103+
raise DownloadAuthError("Databus API key not given for protected download")
104+
headers = {"X-API-KEY": databus_key}
105+
response = requests.head(url, headers=headers, timeout=30)
76106

77107
# --- 2. Try direct GET to redirected URL ---
78108
headers["Accept-Encoding"] = (
@@ -81,25 +111,54 @@ def _download_file(
81111
response = requests.get(
82112
url, headers=headers, stream=True, allow_redirects=True, timeout=30
83113
)
84-
www = response.headers.get(
85-
"WWW-Authenticate", ""
86-
) # Check if authentication is required
114+
www = response.headers.get("WWW-Authenticate", "") # Check if authentication is required
87115

88-
# --- 3. If redirected to authentication 401 Unauthorized, get Vault token and retry ---
116+
# --- 3. Handle authentication responses ---
117+
# 3a. Server requests Bearer auth. Only attempt token exchange for hosts
118+
# we explicitly consider Vault-protected (VAULT_REQUIRED_HOSTS). This avoids
119+
# sending tokens to unrelated hosts and makes auth behavior predictable.
89120
if response.status_code == 401 and "bearer" in www.lower():
90-
print(f"Authentication required for {url}")
91-
if not (vault_token_file):
92-
raise ValueError("Vault token file not given for protected download")
121+
# If host is not configured for Vault, do not attempt token exchange.
122+
if host not in VAULT_REQUIRED_HOSTS:
123+
raise DownloadAuthError(
124+
"Server requests Bearer authentication but this host is not configured for Vault token exchange."
125+
" Try providing a databus API key with --databus-key or contact your administrator."
126+
)
127+
128+
# Host requires Vault; ensure token file provided.
129+
if not vault_token_file:
130+
raise DownloadAuthError(
131+
f"Vault token required for host '{host}', but no token was provided. Please use --vault-token."
132+
)
93133

94-
# --- 3a. Fetch Vault token ---
95-
# TODO: cache token
134+
# --- 3b. Fetch Vault token and retry ---
135+
# Token exchange is potentially sensitive and should only be performed
136+
# for known hosts. __get_vault_access__ handles reading the refresh
137+
# token and exchanging it; errors are translated to DownloadAuthError
138+
# for user-friendly CLI output.
96139
vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id)
97140
headers["Authorization"] = f"Bearer {vault_token}"
98-
headers.pop("Accept-Encoding")
141+
headers.pop("Accept-Encoding", None)
99142

100-
# --- 3b. Retry with token ---
143+
# Retry with token
101144
response = requests.get(url, headers=headers, stream=True, timeout=30)
102145

146+
# Map common auth failures to friendly messages
147+
if response.status_code == 401:
148+
raise DownloadAuthError("Vault token is invalid or expired. Please generate a new token.")
149+
if response.status_code == 403:
150+
raise DownloadAuthError("Vault token is valid but has insufficient permissions to access this file.")
151+
152+
# 3c. Generic forbidden without Bearer challenge
153+
if response.status_code == 403:
154+
raise DownloadAuthError("Access forbidden: your token or API key does not have permission to download this file.")
155+
156+
# 3d. Generic unauthorized without Bearer
157+
if response.status_code == 401:
158+
raise DownloadAuthError(
159+
"Unauthorized: access denied. Check your --databus-key or --vault-token settings."
160+
)
161+
103162
try:
104163
response.raise_for_status() # Raise if still failing
105164
except requests.exceptions.HTTPError as e:

databusclient/cli.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import databusclient.api.deploy as api_deploy
99
from databusclient.api.delete import delete as api_delete
10-
from databusclient.api.download import download as api_download
10+
from databusclient.api.download import download as api_download, DownloadAuthError
1111
from databusclient.extensions import webdav
1212

1313

@@ -171,16 +171,19 @@ def download(
171171
"""
172172
Download datasets from databus, optionally using vault access if vault options are provided.
173173
"""
174-
api_download(
175-
localDir=localdir,
176-
endpoint=databus,
177-
databusURIs=databusuris,
178-
token=vault_token,
179-
databus_key=databus_key,
180-
all_versions=all_versions,
181-
auth_url=authurl,
182-
client_id=clientid,
183-
)
174+
try:
175+
api_download(
176+
localDir=localdir,
177+
endpoint=databus,
178+
databusURIs=databusuris,
179+
token=vault_token,
180+
databus_key=databus_key,
181+
all_versions=all_versions,
182+
auth_url=authurl,
183+
client_id=clientid,
184+
)
185+
except DownloadAuthError as e:
186+
raise click.ClickException(str(e))
184187

185188

186189
@app.command()

tests/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import sys
2+
import types
3+
4+
# Provide a lightweight fake SPARQLWrapper module for tests when not installed.
5+
if "SPARQLWrapper" not in sys.modules:
6+
mod = types.ModuleType("SPARQLWrapper")
7+
mod.JSON = None
8+
9+
class DummySPARQL:
10+
def __init__(self, *args, **kwargs):
11+
pass
12+
13+
def setQuery(self, q):
14+
self._q = q
15+
16+
def setReturnFormat(self, f):
17+
self._fmt = f
18+
19+
def setCustomHttpHeaders(self, h):
20+
self._headers = h
21+
22+
def query(self):
23+
class R:
24+
def convert(self):
25+
return {"results": {"bindings": []}}
26+
27+
return R()
28+
29+
mod.SPARQLWrapper = DummySPARQL
30+
sys.modules["SPARQLWrapper"] = mod

tests/test_download.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Download Tests"""
22

3+
import pytest
4+
35
from databusclient.api.download import download as api_download
46

57
# TODO: overall test structure not great, needs refactoring
@@ -25,5 +27,6 @@ def test_with_query():
2527
api_download("tmp", DEFAULT_ENDPOINT, [TEST_QUERY])
2628

2729

30+
@pytest.mark.skip(reason="Integration test: requires live databus.dbpedia.org connection")
2831
def test_with_collection():
2932
api_download("tmp", DEFAULT_ENDPOINT, [TEST_COLLECTION])

tests/test_download_auth.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from unittest.mock import Mock, patch
2+
3+
import pytest
4+
5+
import requests
6+
7+
import databusclient.api.download as dl
8+
9+
from databusclient.api.download import VAULT_REQUIRED_HOSTS, DownloadAuthError
10+
11+
12+
def make_response(status=200, headers=None, content=b""):
13+
headers = headers or {}
14+
mock = Mock()
15+
mock.status_code = status
16+
mock.headers = headers
17+
mock.content = content
18+
19+
def iter_content(chunk_size):
20+
if content:
21+
yield content
22+
else:
23+
return
24+
25+
mock.iter_content = lambda chunk: iter(iter_content(chunk))
26+
27+
def raise_for_status():
28+
if mock.status_code >= 400:
29+
raise requests.exceptions.HTTPError()
30+
31+
mock.raise_for_status = raise_for_status
32+
return mock
33+
34+
35+
def test_vault_host_no_token_raises():
36+
vault_host = next(iter(VAULT_REQUIRED_HOSTS))
37+
url = f"https://{vault_host}/some/protected/file.ttl"
38+
39+
with pytest.raises(DownloadAuthError) as exc:
40+
dl._download_file(url, localDir='.', vault_token_file=None)
41+
42+
assert "Vault token required" in str(exc.value)
43+
44+
45+
def test_non_vault_host_no_token_allows_download(monkeypatch):
46+
url = "https://example.com/public/file.txt"
47+
48+
resp_head = make_response(status=200, headers={})
49+
resp_get = make_response(status=200, headers={"content-length": "0"}, content=b"")
50+
51+
with patch("requests.head", return_value=resp_head), patch(
52+
"requests.get", return_value=resp_get
53+
):
54+
# should not raise
55+
dl._download_file(url, localDir='.', vault_token_file=None)
56+
57+
58+
def test_401_after_token_exchange_reports_invalid_token(monkeypatch):
59+
vault_host = next(iter(VAULT_REQUIRED_HOSTS))
60+
url = f"https://{vault_host}/protected/file.ttl"
61+
62+
# initial HEAD succeeds (200); first GET -> 401 with Bearer challenge
63+
resp_head = make_response(status=200, headers={})
64+
resp_401 = make_response(status=401, headers={"WWW-Authenticate": "Bearer realm=\"auth\""})
65+
66+
# after retry with token -> still 401
67+
resp_401_retry = make_response(status=401, headers={})
68+
69+
# Mock requests.get side effects: first 401 (challenge), then 401 after token
70+
get_side_effects = [resp_401, resp_401_retry]
71+
72+
# Mock token exchange responses
73+
post_resp_1 = Mock()
74+
post_resp_1.json.return_value = {"access_token": "ACCESS"}
75+
post_resp_2 = Mock()
76+
post_resp_2.json.return_value = {"access_token": "VAULT"}
77+
78+
with patch("requests.head", return_value=resp_head), patch(
79+
"requests.get", side_effect=get_side_effects
80+
), patch("requests.post", side_effect=[post_resp_1, post_resp_2]):
81+
# set REFRESH_TOKEN so __get_vault_access__ doesn't try to open a file
82+
monkeypatch.setenv("REFRESH_TOKEN", "x" * 90)
83+
84+
with pytest.raises(DownloadAuthError) as exc:
85+
dl._download_file(url, localDir='.', vault_token_file="/does/not/matter")
86+
87+
assert "invalid or expired" in str(exc.value)
88+
89+
90+
def test_403_reports_insufficient_permissions():
91+
vault_host = next(iter(VAULT_REQUIRED_HOSTS))
92+
url = f"https://{vault_host}/protected/file.ttl"
93+
94+
resp_head = make_response(status=200, headers={})
95+
resp_403 = make_response(status=403, headers={})
96+
97+
with patch("requests.head", return_value=resp_head), patch(
98+
"requests.get", return_value=resp_403
99+
):
100+
# provide a token path so early check does not block
101+
with pytest.raises(DownloadAuthError) as exc:
102+
dl._download_file(url, localDir='.', vault_token_file="/some/token/file")
103+
104+
assert "permission" in str(exc.value) or "forbidden" in str(exc.value)

0 commit comments

Comments
 (0)