|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import logging |
| 4 | +import sys |
| 5 | +from typing import Iterable, Optional, Union |
| 6 | + |
| 7 | +import requests |
| 8 | +from requests import Response |
| 9 | +from requests.auth import AuthBase |
| 10 | + |
| 11 | +import openeo |
| 12 | +from openeo.rest import OpenEoApiError, OpenEoApiPlainError, OpenEoRestError |
| 13 | +from openeo.rest.auth.auth import NullAuth |
| 14 | +from openeo.util import ContextTimer, ensure_list, str_truncate, url_join |
| 15 | + |
| 16 | +_log = logging.getLogger(__name__) |
| 17 | + |
| 18 | +# Default timeouts for requests |
| 19 | +# TODO: get default_timeout from config? |
| 20 | +DEFAULT_TIMEOUT = 20 * 60 |
| 21 | + |
| 22 | + |
| 23 | +class RestApiConnection: |
| 24 | + """Base connection class implementing generic REST API request functionality""" |
| 25 | + |
| 26 | + def __init__( |
| 27 | + self, |
| 28 | + root_url: str, |
| 29 | + auth: Optional[AuthBase] = None, |
| 30 | + session: Optional[requests.Session] = None, |
| 31 | + default_timeout: Optional[int] = None, |
| 32 | + slow_response_threshold: Optional[float] = None, |
| 33 | + ): |
| 34 | + self._root_url = root_url |
| 35 | + self._auth = None |
| 36 | + self.auth = auth or NullAuth() |
| 37 | + self.session = session or requests.Session() |
| 38 | + self.default_timeout = default_timeout or DEFAULT_TIMEOUT |
| 39 | + self.default_headers = { |
| 40 | + "User-Agent": "openeo-python-client/{cv} {py}/{pv} {pl}".format( |
| 41 | + cv=openeo.client_version(), |
| 42 | + py=sys.implementation.name, |
| 43 | + pv=".".join(map(str, sys.version_info[:3])), |
| 44 | + pl=sys.platform, |
| 45 | + ) |
| 46 | + } |
| 47 | + self.slow_response_threshold = slow_response_threshold |
| 48 | + |
| 49 | + @property |
| 50 | + def root_url(self): |
| 51 | + return self._root_url |
| 52 | + |
| 53 | + @property |
| 54 | + def auth(self) -> Union[AuthBase, None]: |
| 55 | + return self._auth |
| 56 | + |
| 57 | + @auth.setter |
| 58 | + def auth(self, auth: Union[AuthBase, None]): |
| 59 | + self._auth = auth |
| 60 | + self._on_auth_update() |
| 61 | + |
| 62 | + def _on_auth_update(self): |
| 63 | + pass |
| 64 | + |
| 65 | + def build_url(self, path: str): |
| 66 | + return url_join(self._root_url, path) |
| 67 | + |
| 68 | + def _merged_headers(self, headers: dict) -> dict: |
| 69 | + """Merge default headers with given headers""" |
| 70 | + result = self.default_headers.copy() |
| 71 | + if headers: |
| 72 | + result.update(headers) |
| 73 | + return result |
| 74 | + |
| 75 | + def _is_external(self, url: str) -> bool: |
| 76 | + """Check if given url is external (not under root url)""" |
| 77 | + root = self.root_url.rstrip("/") |
| 78 | + return not (url == root or url.startswith(root + "/")) |
| 79 | + |
| 80 | + def request( |
| 81 | + self, |
| 82 | + method: str, |
| 83 | + path: str, |
| 84 | + *, |
| 85 | + params: Optional[dict] = None, |
| 86 | + headers: Optional[dict] = None, |
| 87 | + auth: Optional[AuthBase] = None, |
| 88 | + check_error: bool = True, |
| 89 | + expected_status: Optional[Union[int, Iterable[int]]] = None, |
| 90 | + **kwargs, |
| 91 | + ): |
| 92 | + """Generic request send""" |
| 93 | + url = self.build_url(path) |
| 94 | + # Don't send default auth headers to external domains. |
| 95 | + auth = auth or (self.auth if not self._is_external(url) else None) |
| 96 | + slow_response_threshold = kwargs.pop("slow_response_threshold", self.slow_response_threshold) |
| 97 | + if _log.isEnabledFor(logging.DEBUG): |
| 98 | + _log.debug( |
| 99 | + "Request `{m} {u}` with params {p}, headers {h}, auth {a}, kwargs {k}".format( |
| 100 | + m=method.upper(), |
| 101 | + u=url, |
| 102 | + p=params, |
| 103 | + h=headers and headers.keys(), |
| 104 | + a=type(auth).__name__, |
| 105 | + k=list(kwargs.keys()), |
| 106 | + ) |
| 107 | + ) |
| 108 | + with ContextTimer() as timer: |
| 109 | + resp = self.session.request( |
| 110 | + method=method, |
| 111 | + url=url, |
| 112 | + params=params, |
| 113 | + headers=self._merged_headers(headers), |
| 114 | + auth=auth, |
| 115 | + timeout=kwargs.pop("timeout", self.default_timeout), |
| 116 | + **kwargs, |
| 117 | + ) |
| 118 | + if slow_response_threshold and timer.elapsed() > slow_response_threshold: |
| 119 | + _log.warning( |
| 120 | + "Slow response: `{m} {u}` took {e:.2f}s (>{t:.2f}s)".format( |
| 121 | + m=method.upper(), u=str_truncate(url, width=64), e=timer.elapsed(), t=slow_response_threshold |
| 122 | + ) |
| 123 | + ) |
| 124 | + if _log.isEnabledFor(logging.DEBUG): |
| 125 | + _log.debug( |
| 126 | + f"openEO request `{resp.request.method} {resp.request.path_url}` -> response {resp.status_code} headers {resp.headers!r}" |
| 127 | + ) |
| 128 | + # Check for API errors and unexpected HTTP status codes as desired. |
| 129 | + status = resp.status_code |
| 130 | + expected_status = ensure_list(expected_status) if expected_status else [] |
| 131 | + if check_error and status >= 400 and status not in expected_status: |
| 132 | + self._raise_api_error(resp) |
| 133 | + if expected_status and status not in expected_status: |
| 134 | + raise OpenEoRestError( |
| 135 | + "Got status code {s!r} for `{m} {p}` (expected {e!r}) with body {body}".format( |
| 136 | + m=method.upper(), p=path, s=status, e=expected_status, body=resp.text |
| 137 | + ) |
| 138 | + ) |
| 139 | + return resp |
| 140 | + |
| 141 | + def _raise_api_error(self, response: requests.Response): |
| 142 | + """Convert API error response to Python exception""" |
| 143 | + status_code = response.status_code |
| 144 | + try: |
| 145 | + info = response.json() |
| 146 | + except Exception: |
| 147 | + info = None |
| 148 | + |
| 149 | + # Valid JSON object with "code" and "message" fields indicates a proper openEO API error. |
| 150 | + if isinstance(info, dict): |
| 151 | + error_code = info.get("code") |
| 152 | + error_message = info.get("message") |
| 153 | + if error_code and isinstance(error_code, str) and error_message and isinstance(error_message, str): |
| 154 | + raise OpenEoApiError( |
| 155 | + http_status_code=status_code, |
| 156 | + code=error_code, |
| 157 | + message=error_message, |
| 158 | + id=info.get("id"), |
| 159 | + url=info.get("url"), |
| 160 | + ) |
| 161 | + |
| 162 | + # Failed to parse it as a compliant openEO API error: show body as-is in the exception. |
| 163 | + text = response.text |
| 164 | + error_message = None |
| 165 | + _log.warning(f"Failed to parse API error response: [{status_code}] {text!r} (headers: {response.headers})") |
| 166 | + |
| 167 | + # TODO: eliminate this VITO-backend specific error massaging? |
| 168 | + if status_code == 502 and "Proxy Error" in text: |
| 169 | + error_message = ( |
| 170 | + "Received 502 Proxy Error." |
| 171 | + " This typically happens when a synchronous openEO processing request takes too long and is aborted." |
| 172 | + " Consider using a batch job instead." |
| 173 | + ) |
| 174 | + |
| 175 | + raise OpenEoApiPlainError(message=text, http_status_code=status_code, error_message=error_message) |
| 176 | + |
| 177 | + def get( |
| 178 | + self, |
| 179 | + path: str, |
| 180 | + *, |
| 181 | + params: Optional[dict] = None, |
| 182 | + stream: bool = False, |
| 183 | + auth: Optional[AuthBase] = None, |
| 184 | + **kwargs, |
| 185 | + ) -> Response: |
| 186 | + """ |
| 187 | + Do GET request to REST API. |
| 188 | +
|
| 189 | + :param path: API path (without root url) |
| 190 | + :param params: Additional query parameters |
| 191 | + :param stream: True if the get request should be streamed, else False |
| 192 | + :param auth: optional custom authentication to use instead of the default one |
| 193 | + :return: response: Response |
| 194 | + """ |
| 195 | + return self.request("get", path=path, params=params, stream=stream, auth=auth, **kwargs) |
| 196 | + |
| 197 | + def post(self, path: str, json: Optional[dict] = None, **kwargs) -> Response: |
| 198 | + """ |
| 199 | + Do POST request to REST API. |
| 200 | +
|
| 201 | + :param path: API path (without root url) |
| 202 | + :param json: Data (as dictionary) to be posted with JSON encoding) |
| 203 | + :return: response: Response |
| 204 | + """ |
| 205 | + return self.request("post", path=path, json=json, allow_redirects=False, **kwargs) |
| 206 | + |
| 207 | + def delete(self, path: str, **kwargs) -> Response: |
| 208 | + """ |
| 209 | + Do DELETE request to REST API. |
| 210 | +
|
| 211 | + :param path: API path (without root url) |
| 212 | + :return: response: Response |
| 213 | + """ |
| 214 | + return self.request("delete", path=path, allow_redirects=False, **kwargs) |
| 215 | + |
| 216 | + def patch(self, path: str, **kwargs) -> Response: |
| 217 | + """ |
| 218 | + Do PATCH request to REST API. |
| 219 | +
|
| 220 | + :param path: API path (without root url) |
| 221 | + :return: response: Response |
| 222 | + """ |
| 223 | + return self.request("patch", path=path, allow_redirects=False, **kwargs) |
| 224 | + |
| 225 | + def put(self, path: str, headers: Optional[dict] = None, data: Optional[dict] = None, **kwargs) -> Response: |
| 226 | + """ |
| 227 | + Do PUT request to REST API. |
| 228 | +
|
| 229 | + :param path: API path (without root url) |
| 230 | + :param headers: headers that gets added to the request. |
| 231 | + :param data: data that gets added to the request. |
| 232 | + :return: response: Response |
| 233 | + """ |
| 234 | + return self.request("put", path=path, data=data, headers=headers, allow_redirects=False, **kwargs) |
| 235 | + |
| 236 | + def __repr__(self): |
| 237 | + return "<{c} to {r!r} with {a}>".format(c=type(self).__name__, r=self._root_url, a=type(self.auth).__name__) |
0 commit comments