Skip to content
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openeo/rest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
DEFAULT_JOB_STATUS_POLL_CONNECTION_RETRY_INTERVAL = 30
DEFAULT_JOB_STATUS_POLL_SOFT_ERROR_MAX = 10

CONFORMANCE_JWT_BEARER = "https://api.openeo.org/*/authentication/jwt"

class OpenEoClientException(BaseOpenEoException):
"""Base class for OpenEO client exceptions"""
pass
Expand Down
33 changes: 32 additions & 1 deletion openeo/rest/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Union,
)

from openeo.utils.version import ComparableVersion
from openeo import Connection, DataCube
from openeo.rest.vectorcube import VectorCube
from openeo.utils.http import HTTP_201_CREATED, HTTP_202_ACCEPTED, HTTP_204_NO_CONTENT
Expand Down Expand Up @@ -189,6 +190,14 @@ def setup_file_format(self, name: str, type: str = "output", gis_data_types: Ite
}
self._requests_mock.get(self.connection.build_url("/file_formats"), json=self.file_formats)
return self

def _get_conformance(self, request, context):
return {
"conformsTo": build_conformance(
api_version="1.3.0",
stac_version="1.0.0"
)
}

def _handle_post_result(self, request, context):
"""handler of `POST /result` (synchronous execute)"""
Expand Down Expand Up @@ -424,6 +433,20 @@ def get_status(job_id: str, current_status: str) -> str:

self.job_status_updater = get_status

def build_conformance(
*,
api_version: str = "1.0.0",
stac_version: str = "0.9.0",
) -> list[str]:
conformance = [
"https://api.openeo.org/{api_version}",
"https://api.stacspec.org/v{stac_version}/core",
"https://api.stacspec.org/v{stac_version}/collections"
]
if ComparableVersion(api_version) >= ComparableVersion("1.3.0"):
conformance.append(f"https://api.openeo.org/{api_version}/authentication/jwt")
return conformance


def build_capabilities(
*,
Expand All @@ -441,6 +464,8 @@ def build_capabilities(
"""Build a dummy capabilities document for testing purposes."""

endpoints = []
if basic_auth:
endpoints.append({"path": "/conformance", "methods": ["GET"]})
if basic_auth:
endpoints.append({"path": "/credentials/basic", "methods": ["GET"]})
if oidc_auth:
Expand Down Expand Up @@ -470,17 +495,23 @@ def build_capabilities(
endpoints.extend(
[
{"path": "/process_graphs", "methods": ["GET"]},
{"path": "/process_graphs/{process_graph_id", "methods": ["GET", "PUT", "DELETE"]},
{"path": "/process_graphs/{process_graph_id}", "methods": ["GET", "PUT", "DELETE"]},
]
)

conformance = build_conformance(
api_version=api_version,
stac_version=stac_version
)

capabilities = {
"api_version": api_version,
"stac_version": stac_version,
"id": "dummy",
"title": "Dummy openEO back-end",
"description": "Dummy openeEO back-end",
"endpoints": endpoints,
"conformsTo": conformance,
"links": [],
}
return capabilities
14 changes: 10 additions & 4 deletions openeo/rest/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,18 @@ def __call__(self, req: Request) -> Request:
class BasicBearerAuth(BearerAuth):
"""Bearer token for Basic Auth (openEO API 1.0.0 style)"""

def __init__(self, access_token: str):
super().__init__(bearer="basic//{t}".format(t=access_token))
def __init__(self, access_token: str, jwt_conformance: bool = False):
if not jwt_conformance:
access_token = "basic//{t}".format(t=access_token)
super().__init__(bearer=access_token)


class OidcBearerAuth(BearerAuth):
"""Bearer token for OIDC Auth (openEO API 1.0.0 style)"""

def __init__(self, provider_id: str, access_token: str):
super().__init__(bearer="oidc/{p}/{t}".format(p=provider_id, t=access_token))
def __init__(self, provider_id: str, access_token: str, jwt_conformance: bool = False):
if jwt_conformance:
bearer="{t}"
else:
bearer="oidc/{p}/{t}".format(p=provider_id, t=access_token)
super().__init__(bearer=bearer)
9 changes: 9 additions & 0 deletions openeo/rest/capabilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, List, Optional, Union
from fnmatch import fnmatch

from openeo.internal.jupyter import render_component
from openeo.rest.models import federation_extension
Expand Down Expand Up @@ -36,6 +37,14 @@ def api_version_check(self) -> ComparableVersion:
if not api_version:
raise ApiVersionException("No API version found")
return ComparableVersion(api_version)

def has_conformance(self, conformance: str) -> bool:
"""Check if backend provides a given conformance string"""
for url in self.capabilities.get("conformsTo", []):
if fnmatch(url, conformance):
return True
return False


def supports_endpoint(self, path: str, method="GET") -> bool:
"""Check if backend supports given endpoint"""
Expand Down
11 changes: 9 additions & 2 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from openeo.metadata import CollectionMetadata
from openeo.rest import (
DEFAULT_DOWNLOAD_CHUNK_SIZE,
CONFORMANCE_JWT_BEARER,
CapabilitiesException,
OpenEoApiError,
OpenEoClientException,
Expand Down Expand Up @@ -277,8 +278,12 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[
# /credentials/basic is the only endpoint that expects a Basic HTTP auth
auth=HTTPBasicAuth(username, password)
).json()

# check for JWT bearer token conformance
jwt_conformance = self.capabilities().has_conformance(CONFORMANCE_JWT_BEARER)

# Switch to bearer based authentication in further requests.
self.auth = BasicBearerAuth(access_token=resp["access_token"])
self.auth = BasicBearerAuth(access_token=resp["access_token"], jwt_conformance = jwt_conformance)
return self

def _get_oidc_provider(
Expand Down Expand Up @@ -416,7 +421,9 @@ def _authenticate_oidc(
)

token = tokens.access_token
self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token)
# check for JWT bearer token conformance
jwt_conformance = self.capabilities().has_conformance(CONFORMANCE_JWT_BEARER)
self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token, jwt_conformance=jwt_conformance)
self._oidc_auth_renewer = oidc_auth_renewer
return self

Expand Down
6 changes: 6 additions & 0 deletions tests/rest/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ def con120(requests_mock, api_capabilities):
con = Connection(API_URL)
return con

@pytest.fixture
def con130(requests_mock, api_capabilities):
requests_mock.get(API_URL, json=build_capabilities(api_version="1.3.0", **api_capabilities))
con = Connection(API_URL)
return con


@pytest.fixture
def dummy_backend(requests_mock, con120) -> DummyBackend:
Expand Down
26 changes: 19 additions & 7 deletions tests/rest/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@
API_URL = "https://oeo.test/"

# TODO: eliminate this and replace with `build_capabilities` usage
BASIC_ENDPOINTS = [{"path": "/credentials/basic", "methods": ["GET"]}]
BASIC_ENDPOINTS = [
{"path": "/credentials/basic", "methods": ["GET"]}
]


GEOJSON_POINT_01 = {"type": "Point", "coordinates": [3, 52]}
Expand Down Expand Up @@ -407,8 +409,8 @@ def test_connect_with_session():
],
"https://oeo.test/openeo/1.1.0/",
"1.1.0",
),
(
),
(
[
{"api_version": "0.4.1", "url": "https://oeo.test/openeo/0.4.1/"},
{"api_version": "1.0.0", "url": "https://oeo.test/openeo/1.0.0/"},
Expand Down Expand Up @@ -462,8 +464,8 @@ def test_connect_with_session():
],
"https://oeo.test/openeo/1.1.0/",
"1.1.0",
),
(
),
(
[
{
"api_version": "0.1.0",
Expand Down Expand Up @@ -848,7 +850,6 @@ def test_authenticate_basic(requests_mock, api_version, basic_auth):
assert isinstance(conn.auth, BearerAuth)
assert conn.auth.bearer == "basic//6cc3570k3n"


def test_authenticate_basic_from_config(requests_mock, api_version, auth_config, basic_auth):
requests_mock.get(API_URL, json={"api_version": api_version, "endpoints": BASIC_ENDPOINTS})
auth_config.set_basic_auth(backend=API_URL, username=basic_auth.username, password=basic_auth.password)
Expand All @@ -859,6 +860,18 @@ def test_authenticate_basic_from_config(requests_mock, api_version, auth_config,
assert isinstance(conn.auth, BearerAuth)
assert conn.auth.bearer == "basic//6cc3570k3n"

def test_authenticate_basic_jwt_bearer(requests_mock, basic_auth):
requests_mock.get(API_URL, json=build_capabilities(api_version="1.3.0"))

conn = Connection(API_URL)

assert isinstance(conn.auth, NullAuth)
conn.authenticate_basic(username=basic_auth.username, password=basic_auth.password)
capabilities = conn.capabilities()
assert isinstance(conn.auth, BearerAuth)
assert capabilities.api_version() == "1.3.0"
assert capabilities.has_conformance("https://api.openeo.org/*/authentication/jwt") == True
assert conn.auth.bearer == "6cc3570k3n"

@pytest.mark.slow
def test_authenticate_oidc_authorization_code_100_single_implicit(requests_mock, caplog):
Expand All @@ -885,7 +898,6 @@ def test_authenticate_oidc_authorization_code_100_single_implicit(requests_mock,
assert conn.auth.bearer == 'oidc/fauth/' + oidc_mock.state["access_token"]
assert "No OIDC provider given, but only one available: 'fauth'. Using that one." in caplog.text


def test_authenticate_oidc_authorization_code_100_single_wrong_id(requests_mock):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
client_id = "myclient"
Expand Down
69 changes: 43 additions & 26 deletions tests/rest/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,54 +7,57 @@


@pytest.fixture
def dummy_backend(requests_mock, con120):
def dummy_backend120(requests_mock, con120):
return DummyBackend(requests_mock=requests_mock, connection=con120)

@pytest.fixture
def dummy_backend130(requests_mock, con130):
return DummyBackend(requests_mock=requests_mock, connection=con130)

DUMMY_PG_ADD35 = {
"add35": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True},
}


class TestDummyBackend:
def test_create_job(self, dummy_backend, con120):
assert dummy_backend.batch_jobs == {}
def test_create_job(self, dummy_backend120, con120):
assert dummy_backend120.batch_jobs == {}
_ = con120.create_job(DUMMY_PG_ADD35)
assert dummy_backend.batch_jobs == {
assert dummy_backend120.batch_jobs == {
"job-000": {
"job_id": "job-000",
"pg": {"add35": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}},
"status": "created",
}
}

def test_start_job(self, dummy_backend, con120):
def test_start_job(self, dummy_backend120, con120):
job = con120.create_job(DUMMY_PG_ADD35)
assert dummy_backend.batch_jobs == {
assert dummy_backend120.batch_jobs == {
"job-000": {"job_id": "job-000", "pg": DUMMY_PG_ADD35, "status": "created"},
}
job.start()
assert dummy_backend.batch_jobs == {
assert dummy_backend120.batch_jobs == {
"job-000": {"job_id": "job-000", "pg": DUMMY_PG_ADD35, "status": "finished"},
}

def test_job_status_updater_error(self, dummy_backend, con120):
dummy_backend.job_status_updater = lambda job_id, current_status: "error"
def test_job_status_updater_error(self, dummy_backend120, con120):
dummy_backend120.job_status_updater = lambda job_id, current_status: "error"

job = con120.create_job(DUMMY_PG_ADD35)
assert dummy_backend.batch_jobs["job-000"]["status"] == "created"
assert dummy_backend120.batch_jobs["job-000"]["status"] == "created"
job.start()
assert dummy_backend.batch_jobs["job-000"]["status"] == "error"
assert dummy_backend120.batch_jobs["job-000"]["status"] == "error"

@pytest.mark.parametrize("final", ["finished", "error"])
def test_setup_simple_job_status_flow(self, dummy_backend, con120, final):
dummy_backend.setup_simple_job_status_flow(queued=2, running=3, final=final)
def test_setup_simple_job_status_flow(self, dummy_backend120, con120, final):
dummy_backend120.setup_simple_job_status_flow(queued=2, running=3, final=final)
job = con120.create_job(DUMMY_PG_ADD35)
assert dummy_backend.batch_jobs["job-000"]["status"] == "created"
assert dummy_backend120.batch_jobs["job-000"]["status"] == "created"

# Note that first status update (to "queued" here) is triggered from `start()`, not `status()` like below
job.start()
assert dummy_backend.batch_jobs["job-000"]["status"] == "queued"
assert dummy_backend120.batch_jobs["job-000"]["status"] == "queued"

# Now go through rest of status flow, through `status()` calls
assert job.status() == "queued"
Expand All @@ -66,25 +69,25 @@ def test_setup_simple_job_status_flow(self, dummy_backend, con120, final):
assert job.status() == final
assert job.status() == final

def test_setup_simple_job_status_flow_final_per_job(self, dummy_backend, con120):
def test_setup_simple_job_status_flow_final_per_job(self, dummy_backend120, con120):
"""Test per-job specific final status"""
dummy_backend.setup_simple_job_status_flow(
dummy_backend120.setup_simple_job_status_flow(
queued=2, running=3, final="finished", final_per_job={"job-001": "error"}
)
job0 = con120.create_job(DUMMY_PG_ADD35)
job1 = con120.create_job(DUMMY_PG_ADD35)
job2 = con120.create_job(DUMMY_PG_ADD35)
assert dummy_backend.batch_jobs["job-000"]["status"] == "created"
assert dummy_backend.batch_jobs["job-001"]["status"] == "created"
assert dummy_backend.batch_jobs["job-002"]["status"] == "created"
assert dummy_backend120.batch_jobs["job-000"]["status"] == "created"
assert dummy_backend120.batch_jobs["job-001"]["status"] == "created"
assert dummy_backend120.batch_jobs["job-002"]["status"] == "created"

# Note that first status update (to "queued" here) is triggered from `start()`, not `status()` like below
job0.start()
job1.start()
job2.start()
assert dummy_backend.batch_jobs["job-000"]["status"] == "queued"
assert dummy_backend.batch_jobs["job-001"]["status"] == "queued"
assert dummy_backend.batch_jobs["job-002"]["status"] == "queued"
assert dummy_backend120.batch_jobs["job-000"]["status"] == "queued"
assert dummy_backend120.batch_jobs["job-001"]["status"] == "queued"
assert dummy_backend120.batch_jobs["job-002"]["status"] == "queued"

# Now go through rest of status flow, through `status()` calls
for expected_status in ["queued", "running", "running", "running"]:
Expand All @@ -98,9 +101,23 @@ def test_setup_simple_job_status_flow_final_per_job(self, dummy_backend, con120)
assert job1.status() == "error"
assert job2.status() == "finished"

def test_setup_job_start_failure(self, dummy_backend):
job = dummy_backend.connection.create_job(process_graph={})
dummy_backend.setup_job_start_failure()
def test_setup_job_start_failure(self, dummy_backend120):
job = dummy_backend120.connection.create_job(process_graph={})
dummy_backend120.setup_job_start_failure()
with pytest.raises(OpenEoApiError, match=re.escape("[500] Internal: No job starting for you, buddy")):
job.start()
assert job.status() == "error"

def test_version(self, dummy_backend120, dummy_backend130):
capabilities120 = dummy_backend120.connection.capabilities()
capabilities130 = dummy_backend130.connection.capabilities()

assert capabilities120.api_version() == "1.2.0"
assert capabilities130.api_version() == "1.3.0"

def test_jwt_conformance(self, dummy_backend120, dummy_backend130):
capabilities120 = dummy_backend120.connection.capabilities()
capabilities130 = dummy_backend130.connection.capabilities()

assert capabilities120.has_conformance("https://api.openeo.org/*/authentication/jwt") == False
assert capabilities130.has_conformance("https://api.openeo.org/*/authentication/jwt") == True