diff --git a/openeo/rest/__init__.py b/openeo/rest/__init__.py index 37b3a8170..dac500c00 100644 --- a/openeo/rest/__init__.py +++ b/openeo/rest/__init__.py @@ -10,6 +10,8 @@ DEFAULT_JOB_STATUS_POLL_CONNECTION_RETRY_INTERVAL = 30 DEFAULT_JOB_STATUS_POLL_SOFT_ERROR_MAX = 10 +CONFORMANCE_JWT_BEARER = "https://api.openeo.org/*/authentication/jwt" + class OpenEoClientException(BaseOpenEoException): """Base class for OpenEO client exceptions""" pass diff --git a/openeo/rest/_testing.py b/openeo/rest/_testing.py index 998874551..f67179c81 100644 --- a/openeo/rest/_testing.py +++ b/openeo/rest/_testing.py @@ -14,6 +14,7 @@ Union, ) +from openeo.utils.version import ComparableVersion from openeo import Connection, DataCube from openeo.rest.vectorcube import VectorCube from openeo.utils.http import HTTP_201_CREATED, HTTP_202_ACCEPTED, HTTP_204_NO_CONTENT @@ -189,6 +190,14 @@ def setup_file_format(self, name: str, type: str = "output", gis_data_types: Ite } self._requests_mock.get(self.connection.build_url("/file_formats"), json=self.file_formats) return self + + def _get_conformance(self, request, context): + return { + "conformsTo": build_conformance( + api_version="1.3.0", + stac_version="1.0.0" + ) + } def _handle_post_result(self, request, context): """handler of `POST /result` (synchronous execute)""" @@ -424,6 +433,20 @@ def get_status(job_id: str, current_status: str) -> str: self.job_status_updater = get_status +def build_conformance( + *, + api_version: str = "1.0.0", + stac_version: str = "0.9.0", +) -> list[str]: + conformance = [ + "https://api.openeo.org/{api_version}", + "https://api.stacspec.org/v{stac_version}/core", + "https://api.stacspec.org/v{stac_version}/collections" + ] + if ComparableVersion(api_version) >= ComparableVersion("1.3.0"): + conformance.append(f"https://api.openeo.org/{api_version}/authentication/jwt") + return conformance + def build_capabilities( *, @@ -470,10 +493,15 @@ def build_capabilities( endpoints.extend( [ {"path": "/process_graphs", "methods": ["GET"]}, - {"path": "/process_graphs/{process_graph_id", "methods": ["GET", "PUT", "DELETE"]}, + {"path": "/process_graphs/{process_graph_id}", "methods": ["GET", "PUT", "DELETE"]}, ] ) + conformance = build_conformance( + api_version=api_version, + stac_version=stac_version + ) + capabilities = { "api_version": api_version, "stac_version": stac_version, @@ -481,6 +509,7 @@ def build_capabilities( "title": "Dummy openEO back-end", "description": "Dummy openeEO back-end", "endpoints": endpoints, + "conformsTo": conformance, "links": [], } return capabilities diff --git a/openeo/rest/auth/auth.py b/openeo/rest/auth/auth.py index 378fbdbc2..7a4684d2e 100644 --- a/openeo/rest/auth/auth.py +++ b/openeo/rest/auth/auth.py @@ -41,12 +41,17 @@ def __call__(self, req: Request) -> Request: class BasicBearerAuth(BearerAuth): """Bearer token for Basic Auth (openEO API 1.0.0 style)""" - def __init__(self, access_token: str): - super().__init__(bearer="basic//{t}".format(t=access_token)) + def __init__(self, access_token: str, jwt_conformance: bool = False): + if not jwt_conformance: + access_token = "basic//{t}".format(t=access_token) + super().__init__(bearer=access_token) class OidcBearerAuth(BearerAuth): """Bearer token for OIDC Auth (openEO API 1.0.0 style)""" - def __init__(self, provider_id: str, access_token: str): - super().__init__(bearer="oidc/{p}/{t}".format(p=provider_id, t=access_token)) + def __init__(self, provider_id: str, access_token: str, jwt_conformance: bool = False): + if not jwt_conformance: + access_token = "oidc/{p}/{t}".format(p=provider_id, t=access_token) + super().__init__(bearer=access_token) + diff --git a/openeo/rest/capabilities.py b/openeo/rest/capabilities.py index 768093f6f..96062bd56 100644 --- a/openeo/rest/capabilities.py +++ b/openeo/rest/capabilities.py @@ -1,4 +1,5 @@ from typing import Dict, List, Optional, Union +import re from openeo.internal.jupyter import render_component from openeo.rest.models import federation_extension @@ -36,6 +37,15 @@ def api_version_check(self) -> ComparableVersion: if not api_version: raise ApiVersionException("No API version found") return ComparableVersion(api_version) + + def has_conformance(self, uri: str) -> bool: + """Check if backend provides a given conformance string""" + uri = re.escape(uri).replace('\\*', '[^/]+') + for conformance_uri in self.capabilities.get("conformsTo", []): + if re.match(uri, conformance_uri): + return True + return False + def supports_endpoint(self, path: str, method="GET") -> bool: """Check if backend supports given endpoint""" diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index d4d4d5995..768841c9b 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -47,6 +47,7 @@ from openeo.metadata import CollectionMetadata from openeo.rest import ( DEFAULT_DOWNLOAD_CHUNK_SIZE, + CONFORMANCE_JWT_BEARER, CapabilitiesException, OpenEoApiError, OpenEoClientException, @@ -277,8 +278,12 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[ # /credentials/basic is the only endpoint that expects a Basic HTTP auth auth=HTTPBasicAuth(username, password) ).json() + + # check for JWT bearer token conformance + jwt_conformance = self.capabilities().has_conformance(CONFORMANCE_JWT_BEARER) + # Switch to bearer based authentication in further requests. - self.auth = BasicBearerAuth(access_token=resp["access_token"]) + self.auth = BasicBearerAuth(access_token=resp["access_token"], jwt_conformance = jwt_conformance) return self def _get_oidc_provider( @@ -416,7 +421,9 @@ def _authenticate_oidc( ) token = tokens.access_token - self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token) + # check for JWT bearer token conformance + jwt_conformance = self.capabilities().has_conformance(CONFORMANCE_JWT_BEARER) + self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token, jwt_conformance=jwt_conformance) self._oidc_auth_renewer = oidc_auth_renewer return self diff --git a/tests/rest/conftest.py b/tests/rest/conftest.py index 2255cca85..411c82e1e 100644 --- a/tests/rest/conftest.py +++ b/tests/rest/conftest.py @@ -99,6 +99,12 @@ def con120(requests_mock, api_capabilities): con = Connection(API_URL) return con +@pytest.fixture +def con130(requests_mock, api_capabilities): + requests_mock.get(API_URL, json=build_capabilities(api_version="1.3.0", **api_capabilities)) + con = Connection(API_URL) + return con + @pytest.fixture def dummy_backend(requests_mock, con120) -> DummyBackend: diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 6da731f71..d000dafa4 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -59,7 +59,9 @@ API_URL = "https://oeo.test/" # TODO: eliminate this and replace with `build_capabilities` usage -BASIC_ENDPOINTS = [{"path": "/credentials/basic", "methods": ["GET"]}] +BASIC_ENDPOINTS = [ + {"path": "/credentials/basic", "methods": ["GET"]} +] GEOJSON_POINT_01 = {"type": "Point", "coordinates": [3, 52]} @@ -407,8 +409,8 @@ def test_connect_with_session(): ], "https://oeo.test/openeo/1.1.0/", "1.1.0", - ), - ( + ), + ( [ {"api_version": "0.4.1", "url": "https://oeo.test/openeo/0.4.1/"}, {"api_version": "1.0.0", "url": "https://oeo.test/openeo/1.0.0/"}, @@ -462,8 +464,8 @@ def test_connect_with_session(): ], "https://oeo.test/openeo/1.1.0/", "1.1.0", - ), - ( + ), + ( [ { "api_version": "0.1.0", @@ -860,6 +862,19 @@ def test_authenticate_basic_from_config(requests_mock, api_version, auth_config, assert conn.auth.bearer == "basic//6cc3570k3n" +def test_authenticate_basic_jwt_bearer(requests_mock, basic_auth): + requests_mock.get(API_URL, json=build_capabilities(api_version="1.3.0")) + + conn = Connection(API_URL) + + assert isinstance(conn.auth, NullAuth) + conn.authenticate_basic(username=basic_auth.username, password=basic_auth.password) + capabilities = conn.capabilities() + assert isinstance(conn.auth, BearerAuth) + assert capabilities.api_version() == "1.3.0" + assert capabilities.has_conformance("https://api.openeo.org/*/authentication/jwt") == True + assert conn.auth.bearer == "6cc3570k3n" + @pytest.mark.slow def test_authenticate_oidc_authorization_code_100_single_implicit(requests_mock, caplog): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) @@ -1049,6 +1064,36 @@ def test_authenticate_oidc_auth_code_pkce_flow_client_from_config(requests_mock, assert conn.auth.bearer == 'oidc/oi/' + oidc_mock.state["access_token"] assert refresh_token_store.mock_calls == [] +@pytest.mark.slow +def test_authenticate_oidc_auth_code_pkce_flow_jwt_bearer(requests_mock, auth_config): + requests_mock.get(API_URL, json=build_capabilities(api_version="1.3.0")) + client_id = "myclient" + issuer = "https://oidc.test" + requests_mock.get(API_URL + 'credentials/oidc', json={ + "providers": [{"id": "oi", "issuer": issuer, "title": "example", "scopes": ["openid"]}] + }) + oidc_mock = OidcMock( + requests_mock=requests_mock, + expected_grant_type="authorization_code", + expected_client_id=client_id, + expected_fields={"scope": "openid"}, + oidc_issuer=issuer, + scopes_supported=["openid"], + ) + auth_config.set_oidc_client_config(backend=API_URL, provider_id="oi", client_id=client_id) + + # With all this set up, kick off the openid connect flow + refresh_token_store = mock.Mock() + conn = Connection(API_URL, refresh_token_store=refresh_token_store) + assert isinstance(conn.auth, NullAuth) + conn.authenticate_oidc_authorization_code(webbrowser_open=oidc_mock.webbrowser_open) + capabilities = conn.capabilities() + assert isinstance(conn.auth, BearerAuth) + assert capabilities.api_version() == "1.3.0" + assert capabilities.has_conformance("https://api.openeo.org/*/authentication/jwt") == True + assert conn.auth.bearer == oidc_mock.state["access_token"] + # TODO: check issuer ("iss") value in parsed jwt. this will require the example jwt to be formatted accordingly + assert refresh_token_store.mock_calls == [] def test_authenticate_oidc_client_credentials(requests_mock): requests_mock.get(API_URL, json={"api_version": "1.0.0"}) diff --git a/tests/rest/test_testing.py b/tests/rest/test_testing.py index 589dda3dc..7b11ecd22 100644 --- a/tests/rest/test_testing.py +++ b/tests/rest/test_testing.py @@ -7,9 +7,12 @@ @pytest.fixture -def dummy_backend(requests_mock, con120): +def dummy_backend120(requests_mock, con120): return DummyBackend(requests_mock=requests_mock, connection=con120) +@pytest.fixture +def dummy_backend130(requests_mock, con130): + return DummyBackend(requests_mock=requests_mock, connection=con130) DUMMY_PG_ADD35 = { "add35": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}, @@ -17,10 +20,10 @@ def dummy_backend(requests_mock, con120): class TestDummyBackend: - def test_create_job(self, dummy_backend, con120): - assert dummy_backend.batch_jobs == {} + def test_create_job(self, dummy_backend120, con120): + assert dummy_backend120.batch_jobs == {} _ = con120.create_job(DUMMY_PG_ADD35) - assert dummy_backend.batch_jobs == { + assert dummy_backend120.batch_jobs == { "job-000": { "job_id": "job-000", "pg": {"add35": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}, @@ -28,33 +31,33 @@ def test_create_job(self, dummy_backend, con120): } } - def test_start_job(self, dummy_backend, con120): + def test_start_job(self, dummy_backend120, con120): job = con120.create_job(DUMMY_PG_ADD35) - assert dummy_backend.batch_jobs == { + assert dummy_backend120.batch_jobs == { "job-000": {"job_id": "job-000", "pg": DUMMY_PG_ADD35, "status": "created"}, } job.start() - assert dummy_backend.batch_jobs == { + assert dummy_backend120.batch_jobs == { "job-000": {"job_id": "job-000", "pg": DUMMY_PG_ADD35, "status": "finished"}, } - def test_job_status_updater_error(self, dummy_backend, con120): - dummy_backend.job_status_updater = lambda job_id, current_status: "error" + def test_job_status_updater_error(self, dummy_backend120, con120): + dummy_backend120.job_status_updater = lambda job_id, current_status: "error" job = con120.create_job(DUMMY_PG_ADD35) - assert dummy_backend.batch_jobs["job-000"]["status"] == "created" + assert dummy_backend120.batch_jobs["job-000"]["status"] == "created" job.start() - assert dummy_backend.batch_jobs["job-000"]["status"] == "error" + assert dummy_backend120.batch_jobs["job-000"]["status"] == "error" @pytest.mark.parametrize("final", ["finished", "error"]) - def test_setup_simple_job_status_flow(self, dummy_backend, con120, final): - dummy_backend.setup_simple_job_status_flow(queued=2, running=3, final=final) + def test_setup_simple_job_status_flow(self, dummy_backend120, con120, final): + dummy_backend120.setup_simple_job_status_flow(queued=2, running=3, final=final) job = con120.create_job(DUMMY_PG_ADD35) - assert dummy_backend.batch_jobs["job-000"]["status"] == "created" + assert dummy_backend120.batch_jobs["job-000"]["status"] == "created" # Note that first status update (to "queued" here) is triggered from `start()`, not `status()` like below job.start() - assert dummy_backend.batch_jobs["job-000"]["status"] == "queued" + assert dummy_backend120.batch_jobs["job-000"]["status"] == "queued" # Now go through rest of status flow, through `status()` calls assert job.status() == "queued" @@ -66,25 +69,25 @@ def test_setup_simple_job_status_flow(self, dummy_backend, con120, final): assert job.status() == final assert job.status() == final - def test_setup_simple_job_status_flow_final_per_job(self, dummy_backend, con120): + def test_setup_simple_job_status_flow_final_per_job(self, dummy_backend120, con120): """Test per-job specific final status""" - dummy_backend.setup_simple_job_status_flow( + dummy_backend120.setup_simple_job_status_flow( queued=2, running=3, final="finished", final_per_job={"job-001": "error"} ) job0 = con120.create_job(DUMMY_PG_ADD35) job1 = con120.create_job(DUMMY_PG_ADD35) job2 = con120.create_job(DUMMY_PG_ADD35) - assert dummy_backend.batch_jobs["job-000"]["status"] == "created" - assert dummy_backend.batch_jobs["job-001"]["status"] == "created" - assert dummy_backend.batch_jobs["job-002"]["status"] == "created" + assert dummy_backend120.batch_jobs["job-000"]["status"] == "created" + assert dummy_backend120.batch_jobs["job-001"]["status"] == "created" + assert dummy_backend120.batch_jobs["job-002"]["status"] == "created" # Note that first status update (to "queued" here) is triggered from `start()`, not `status()` like below job0.start() job1.start() job2.start() - assert dummy_backend.batch_jobs["job-000"]["status"] == "queued" - assert dummy_backend.batch_jobs["job-001"]["status"] == "queued" - assert dummy_backend.batch_jobs["job-002"]["status"] == "queued" + assert dummy_backend120.batch_jobs["job-000"]["status"] == "queued" + assert dummy_backend120.batch_jobs["job-001"]["status"] == "queued" + assert dummy_backend120.batch_jobs["job-002"]["status"] == "queued" # Now go through rest of status flow, through `status()` calls for expected_status in ["queued", "running", "running", "running"]: @@ -98,9 +101,23 @@ def test_setup_simple_job_status_flow_final_per_job(self, dummy_backend, con120) assert job1.status() == "error" assert job2.status() == "finished" - def test_setup_job_start_failure(self, dummy_backend): - job = dummy_backend.connection.create_job(process_graph={}) - dummy_backend.setup_job_start_failure() + def test_setup_job_start_failure(self, dummy_backend120): + job = dummy_backend120.connection.create_job(process_graph={}) + dummy_backend120.setup_job_start_failure() with pytest.raises(OpenEoApiError, match=re.escape("[500] Internal: No job starting for you, buddy")): job.start() assert job.status() == "error" + + def test_version(self, dummy_backend120, dummy_backend130): + capabilities120 = dummy_backend120.connection.capabilities() + capabilities130 = dummy_backend130.connection.capabilities() + + assert capabilities120.api_version() == "1.2.0" + assert capabilities130.api_version() == "1.3.0" + + def test_jwt_conformance(self, dummy_backend120, dummy_backend130): + capabilities120 = dummy_backend120.connection.capabilities() + capabilities130 = dummy_backend130.connection.capabilities() + + assert capabilities120.has_conformance("https://api.openeo.org/*/authentication/jwt") == False + assert capabilities130.has_conformance("https://api.openeo.org/*/authentication/jwt") == True \ No newline at end of file