Skip to content

Commit 8edb776

Browse files
committed
Merge branch 'issue254-Clear-capabilities-cache-on-log-in'
2 parents ca1bde1 + 315ce7a commit 8edb776

File tree

4 files changed

+122
-1
lines changed

4 files changed

+122
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626

2727
### Fixed
2828

29+
- Clear capabilities cache on login ([#254](https://github.com/Open-EO/openeo-python-client/issues/254))
30+
2931

3032
## [0.36.0] - 2024-12-10
3133

openeo/rest/connection.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(
113113
slow_response_threshold: Optional[float] = None,
114114
):
115115
self._root_url = root_url
116+
self._auth = None
116117
self.auth = auth or NullAuth()
117118
self.session = session or requests.Session()
118119
self.default_timeout = default_timeout or DEFAULT_TIMEOUT
@@ -129,6 +130,18 @@ def __init__(
129130
def root_url(self):
130131
return self._root_url
131132

133+
@property
134+
def auth(self) -> Union[AuthBase, None]:
135+
return self._auth
136+
137+
@auth.setter
138+
def auth(self, auth: Union[AuthBase, None]):
139+
self._auth = auth
140+
self._on_auth_update()
141+
142+
def _on_auth_update(self):
143+
pass
144+
132145
def build_url(self, path: str):
133146
return url_join(self._root_url, path)
134147

@@ -340,12 +353,12 @@ def __init__(
340353
if "://" not in url:
341354
url = "https://" + url
342355
self._orig_url = url
356+
self._capabilities_cache = LazyLoadCache()
343357
super().__init__(
344358
root_url=self.version_discovery(url, session=session, timeout=default_timeout),
345359
auth=auth, session=session, default_timeout=default_timeout,
346360
slow_response_threshold=slow_response_threshold,
347361
)
348-
self._capabilities_cache = LazyLoadCache()
349362

350363
# Initial API version check.
351364
self._api_version.require_at_least(self._MINIMUM_API_VERSION)
@@ -380,6 +393,10 @@ def version_discovery(
380393
# Be very lenient about failing on the well-known URI strategy.
381394
return url
382395

396+
def _on_auth_update(self):
397+
super()._on_auth_update()
398+
self._capabilities_cache.clear()
399+
383400
def _get_auth_config(self) -> AuthConfig:
384401
if self._auth_config is None:
385402
self._auth_config = AuthConfig()

openeo/util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,9 @@ def get(self, key: Union[str, tuple], load: Callable[[], Any]):
476476
self._cache[key] = load()
477477
return self._cache[key]
478478

479+
def clear(self):
480+
self._cache = {}
481+
479482

480483
def str_truncate(text: str, width: int = 64, ellipsis: str = "...") -> str:
481484
"""Shorten a string (with an ellipsis) if it is longer than certain length."""

tests/rest/test_connection.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
API_URL = "https://oeo.test/"
5151

52+
# TODO: eliminate this and replace with `build_capabilities` usage
5253
BASIC_ENDPOINTS = [{"path": "/credentials/basic", "methods": ["GET"]}]
5354

5455

@@ -552,6 +553,104 @@ def test_capabilities_caching(requests_mock):
552553
assert m.call_count == 1
553554

554555

556+
def _get_capabilities_auth_dependent(request, context):
557+
capabilities = build_capabilities()
558+
capabilities["endpoints"] = [
559+
{"methods": ["GET"], "path": "/credentials/basic"},
560+
{"methods": ["GET"], "path": "/credentials/oidc"},
561+
]
562+
if "Authorization" in request.headers:
563+
capabilities["endpoints"].append({"methods": ["GET"], "path": "/me"})
564+
return capabilities
565+
566+
567+
def test_capabilities_caching_after_authenticate_basic(requests_mock):
568+
user, pwd = "john262", "J0hndo3"
569+
get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent)
570+
requests_mock.get(API_URL + 'credentials/basic', text=_credentials_basic_handler(user, pwd))
571+
572+
con = Connection(API_URL)
573+
assert con.capabilities().capabilities["endpoints"] == [
574+
{"methods": ["GET"], "path": "/credentials/basic"},
575+
{"methods": ["GET"], "path": "/credentials/oidc"},
576+
]
577+
assert get_capabilities_mock.call_count == 1
578+
con.capabilities()
579+
assert get_capabilities_mock.call_count == 1
580+
581+
con.authenticate_basic(username=user, password=pwd)
582+
assert get_capabilities_mock.call_count == 1
583+
assert con.capabilities().capabilities["endpoints"] == [
584+
{"methods": ["GET"], "path": "/credentials/basic"},
585+
{"methods": ["GET"], "path": "/credentials/oidc"},
586+
{"methods": ["GET"], "path": "/me"},
587+
]
588+
589+
assert get_capabilities_mock.call_count == 2
590+
591+
592+
def test_capabilities_caching_after_authenticate_oidc_refresh_token(requests_mock):
593+
client_id = "myclient"
594+
refresh_token = "fr65h!"
595+
get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent)
596+
requests_mock.get(
597+
API_URL + "credentials/oidc",
598+
json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]},
599+
)
600+
oidc_mock = OidcMock(
601+
requests_mock=requests_mock,
602+
expected_grant_type="refresh_token",
603+
expected_client_id=client_id,
604+
expected_fields={"refresh_token": refresh_token},
605+
)
606+
607+
conn = Connection(API_URL)
608+
assert conn.capabilities().capabilities["endpoints"] == [
609+
{"methods": ["GET"], "path": "/credentials/basic"},
610+
{"methods": ["GET"], "path": "/credentials/oidc"},
611+
]
612+
613+
assert get_capabilities_mock.call_count == 1
614+
conn.capabilities()
615+
assert get_capabilities_mock.call_count == 1
616+
617+
conn.authenticate_oidc_refresh_token(client_id=client_id, refresh_token=refresh_token)
618+
assert get_capabilities_mock.call_count == 1
619+
assert conn.capabilities().capabilities["endpoints"] == [
620+
{"methods": ["GET"], "path": "/credentials/basic"},
621+
{"methods": ["GET"], "path": "/credentials/oidc"},
622+
{"methods": ["GET"], "path": "/me"},
623+
]
624+
assert get_capabilities_mock.call_count == 2
625+
626+
627+
def test_capabilities_caching_after_authenticate_oidc_access_token(requests_mock):
628+
get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent)
629+
requests_mock.get(
630+
API_URL + "credentials/oidc",
631+
json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]},
632+
)
633+
634+
conn = Connection(API_URL)
635+
assert conn.capabilities().capabilities["endpoints"] == [
636+
{"methods": ["GET"], "path": "/credentials/basic"},
637+
{"methods": ["GET"], "path": "/credentials/oidc"},
638+
]
639+
640+
assert get_capabilities_mock.call_count == 1
641+
conn.capabilities()
642+
assert get_capabilities_mock.call_count == 1
643+
644+
conn.authenticate_oidc_access_token(access_token="6cc355!")
645+
assert get_capabilities_mock.call_count == 1
646+
assert conn.capabilities().capabilities["endpoints"] == [
647+
{"methods": ["GET"], "path": "/credentials/basic"},
648+
{"methods": ["GET"], "path": "/credentials/oidc"},
649+
{"methods": ["GET"], "path": "/me"},
650+
]
651+
assert get_capabilities_mock.call_count == 2
652+
653+
555654
def test_file_formats(requests_mock):
556655
requests_mock.get("https://oeo.test/", json={"api_version": "1.0.0"})
557656
m = requests_mock.get("https://oeo.test/file_formats", json={"output": {"GTiff": {"gis_data_types": ["raster"]}}})

0 commit comments

Comments
 (0)