Skip to content
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openeo/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
DEFAULT_JOB_STATUS_POLL_CONNECTION_RETRY_INTERVAL = 30
DEFAULT_JOB_STATUS_POLL_SOFT_ERROR_MAX = 10

CONFORMANCE_JWT_BEARER = "https://api.openeo.org/*/authentication/jwt"

class OpenEoClientException(BaseOpenEoException):
"""Base class for OpenEO client exceptions"""
pass
Expand Down
31 changes: 30 additions & 1 deletion openeo/rest/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Union,
)

from openeo.utils.version import ComparableVersion
from openeo import Connection, DataCube
from openeo.rest.vectorcube import VectorCube
from openeo.utils.http import HTTP_201_CREATED, HTTP_202_ACCEPTED, HTTP_204_NO_CONTENT
Expand Down Expand Up @@ -189,6 +190,14 @@ def setup_file_format(self, name: str, type: str = "output", gis_data_types: Ite
}
self._requests_mock.get(self.connection.build_url("/file_formats"), json=self.file_formats)
return self

def _get_conformance(self, request, context):
return {
"conformsTo": build_conformance(
api_version="1.3.0",
stac_version="1.0.0"
)
}

def _handle_post_result(self, request, context):
"""handler of `POST /result` (synchronous execute)"""
Expand Down Expand Up @@ -424,6 +433,20 @@ def get_status(job_id: str, current_status: str) -> str:

self.job_status_updater = get_status

def build_conformance(
*,
api_version: str = "1.0.0",
stac_version: str = "0.9.0",
) -> list[str]:
conformance = [
"https://api.openeo.org/{api_version}",
"https://api.stacspec.org/v{stac_version}/core",
"https://api.stacspec.org/v{stac_version}/collections"
]
if ComparableVersion(api_version) >= ComparableVersion("1.3.0"):
conformance.append(f"https://api.openeo.org/{api_version}/authentication/jwt")
return conformance


def build_capabilities(
*,
Expand Down Expand Up @@ -470,17 +493,23 @@ def build_capabilities(
endpoints.extend(
[
{"path": "/process_graphs", "methods": ["GET"]},
{"path": "/process_graphs/{process_graph_id", "methods": ["GET", "PUT", "DELETE"]},
{"path": "/process_graphs/{process_graph_id}", "methods": ["GET", "PUT", "DELETE"]},
]
)

conformance = build_conformance(
api_version=api_version,
stac_version=stac_version
)

capabilities = {
"api_version": api_version,
"stac_version": stac_version,
"id": "dummy",
"title": "Dummy openEO back-end",
"description": "Dummy openeEO back-end",
"endpoints": endpoints,
"conformsTo": conformance,
"links": [],
}
return capabilities
13 changes: 9 additions & 4 deletions openeo/rest/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,17 @@ def __call__(self, req: Request) -> Request:
class BasicBearerAuth(BearerAuth):
"""Bearer token for Basic Auth (openEO API 1.0.0 style)"""

def __init__(self, access_token: str):
super().__init__(bearer="basic//{t}".format(t=access_token))
def __init__(self, access_token: str, jwt_conformance: bool = False):
if not jwt_conformance:
access_token = "basic//{t}".format(t=access_token)
super().__init__(bearer=access_token)


class OidcBearerAuth(BearerAuth):
"""Bearer token for OIDC Auth (openEO API 1.0.0 style)"""

def __init__(self, provider_id: str, access_token: str):
super().__init__(bearer="oidc/{p}/{t}".format(p=provider_id, t=access_token))
def __init__(self, provider_id: str, access_token: str, jwt_conformance: bool = False):
if not jwt_conformance:
access_token = "oidc/{p}/{t}".format(p=provider_id, t=access_token)
super().__init__(bearer=access_token)

10 changes: 10 additions & 0 deletions openeo/rest/capabilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, List, Optional, Union
import re

from openeo.internal.jupyter import render_component
from openeo.rest.models import federation_extension
Expand Down Expand Up @@ -36,6 +37,15 @@ def api_version_check(self) -> ComparableVersion:
if not api_version:
raise ApiVersionException("No API version found")
return ComparableVersion(api_version)

def has_conformance(self, uri: str) -> bool:
"""Check if backend provides a given conformance string"""
uri = re.escape(uri).replace('\\*', '[^/]+')
for conformance_uri in self.capabilities.get("conformsTo", []):
if re.match(uri, conformance_uri):
return True
return False


def supports_endpoint(self, path: str, method="GET") -> bool:
"""Check if backend supports given endpoint"""
Expand Down
11 changes: 9 additions & 2 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from openeo.metadata import CollectionMetadata
from openeo.rest import (
DEFAULT_DOWNLOAD_CHUNK_SIZE,
CONFORMANCE_JWT_BEARER,
CapabilitiesException,
OpenEoApiError,
OpenEoClientException,
Expand Down Expand Up @@ -277,8 +278,12 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[
# /credentials/basic is the only endpoint that expects a Basic HTTP auth
auth=HTTPBasicAuth(username, password)
).json()

# check for JWT bearer token conformance
jwt_conformance = self.capabilities().has_conformance(CONFORMANCE_JWT_BEARER)

# Switch to bearer based authentication in further requests.
self.auth = BasicBearerAuth(access_token=resp["access_token"])
self.auth = BasicBearerAuth(access_token=resp["access_token"], jwt_conformance = jwt_conformance)
return self

def _get_oidc_provider(
Expand Down Expand Up @@ -416,7 +421,9 @@ def _authenticate_oidc(
)

token = tokens.access_token
self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token)
# check for JWT bearer token conformance
jwt_conformance = self.capabilities().has_conformance(CONFORMANCE_JWT_BEARER)
self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token, jwt_conformance=jwt_conformance)
self._oidc_auth_renewer = oidc_auth_renewer
return self

Expand Down
6 changes: 6 additions & 0 deletions tests/rest/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def con120(requests_mock, api_capabilities):
con = Connection(API_URL)
return con

@pytest.fixture
def con130(requests_mock, api_capabilities):
requests_mock.get(API_URL, json=build_capabilities(api_version="1.3.0", **api_capabilities))
con = Connection(API_URL)
return con


@pytest.fixture
def dummy_backend(requests_mock, con120) -> DummyBackend:
Expand Down
56 changes: 49 additions & 7 deletions tests/rest/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@
API_URL = "https://oeo.test/"

# TODO: eliminate this and replace with `build_capabilities` usage
BASIC_ENDPOINTS = [{"path": "/credentials/basic", "methods": ["GET"]}]
BASIC_ENDPOINTS = [
{"path": "/credentials/basic", "methods": ["GET"]}
]


GEOJSON_POINT_01 = {"type": "Point", "coordinates": [3, 52]}
Expand Down Expand Up @@ -407,8 +409,8 @@ def test_connect_with_session():
],
"https://oeo.test/openeo/1.1.0/",
"1.1.0",
),
(
),
(
[
{"api_version": "0.4.1", "url": "https://oeo.test/openeo/0.4.1/"},
{"api_version": "1.0.0", "url": "https://oeo.test/openeo/1.0.0/"},
Expand Down Expand Up @@ -462,8 +464,8 @@ def test_connect_with_session():
],
"https://oeo.test/openeo/1.1.0/",
"1.1.0",
),
(
),
(
[
{
"api_version": "0.1.0",
Expand Down Expand Up @@ -848,7 +850,6 @@ def test_authenticate_basic(requests_mock, api_version, basic_auth):
assert isinstance(conn.auth, BearerAuth)
assert conn.auth.bearer == "basic//6cc3570k3n"


def test_authenticate_basic_from_config(requests_mock, api_version, auth_config, basic_auth):
requests_mock.get(API_URL, json={"api_version": api_version, "endpoints": BASIC_ENDPOINTS})
auth_config.set_basic_auth(backend=API_URL, username=basic_auth.username, password=basic_auth.password)
Expand All @@ -859,6 +860,18 @@ def test_authenticate_basic_from_config(requests_mock, api_version, auth_config,
assert isinstance(conn.auth, BearerAuth)
assert conn.auth.bearer == "basic//6cc3570k3n"

def test_authenticate_basic_jwt_bearer(requests_mock, basic_auth):
requests_mock.get(API_URL, json=build_capabilities(api_version="1.3.0"))

conn = Connection(API_URL)

assert isinstance(conn.auth, NullAuth)
conn.authenticate_basic(username=basic_auth.username, password=basic_auth.password)
capabilities = conn.capabilities()
assert isinstance(conn.auth, BearerAuth)
assert capabilities.api_version() == "1.3.0"
assert capabilities.has_conformance("https://api.openeo.org/*/authentication/jwt") == True
assert conn.auth.bearer == "6cc3570k3n"

@pytest.mark.slow
def test_authenticate_oidc_authorization_code_100_single_implicit(requests_mock, caplog):
Expand All @@ -885,7 +898,6 @@ def test_authenticate_oidc_authorization_code_100_single_implicit(requests_mock,
assert conn.auth.bearer == 'oidc/fauth/' + oidc_mock.state["access_token"]
assert "No OIDC provider given, but only one available: 'fauth'. Using that one." in caplog.text


def test_authenticate_oidc_authorization_code_100_single_wrong_id(requests_mock):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
client_id = "myclient"
Expand Down Expand Up @@ -1049,6 +1061,36 @@ def test_authenticate_oidc_auth_code_pkce_flow_client_from_config(requests_mock,
assert conn.auth.bearer == 'oidc/oi/' + oidc_mock.state["access_token"]
assert refresh_token_store.mock_calls == []

@pytest.mark.slow
def test_authenticate_oidc_auth_code_pkce_flow_jwt_bearer(requests_mock, auth_config):
requests_mock.get(API_URL, json=build_capabilities(api_version="1.3.0"))
client_id = "myclient"
issuer = "https://oidc.test"
requests_mock.get(API_URL + 'credentials/oidc', json={
"providers": [{"id": "oi", "issuer": issuer, "title": "example", "scopes": ["openid"]}]
})
oidc_mock = OidcMock(
requests_mock=requests_mock,
expected_grant_type="authorization_code",
expected_client_id=client_id,
expected_fields={"scope": "openid"},
oidc_issuer=issuer,
scopes_supported=["openid"],
)
auth_config.set_oidc_client_config(backend=API_URL, provider_id="oi", client_id=client_id)

# With all this set up, kick off the openid connect flow
refresh_token_store = mock.Mock()
conn = Connection(API_URL, refresh_token_store=refresh_token_store)
assert isinstance(conn.auth, NullAuth)
conn.authenticate_oidc_authorization_code(webbrowser_open=oidc_mock.webbrowser_open)
capabilities = conn.capabilities()
assert isinstance(conn.auth, BearerAuth)
assert capabilities.api_version() == "1.3.0"
assert capabilities.has_conformance("https://api.openeo.org/*/authentication/jwt") == True
assert conn.auth.bearer == oidc_mock.state["access_token"]
# TODO: check issuer ("iss") value in parsed jwt. this will require the example jwt to be formatted accordingly
assert refresh_token_store.mock_calls == []

def test_authenticate_oidc_client_credentials(requests_mock):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
Expand Down
Loading