Skip to content

Commit 12209ea

Browse files
committed
Add request & response abstraction
1 parent cdf6e71 commit 12209ea

File tree

2 files changed

+206
-124
lines changed

2 files changed

+206
-124
lines changed

pulp-glue/pulp_glue/common/openapi.py

Lines changed: 123 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
# copyright (c) 2020, Matthias Dellweg
2-
# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt)
3-
41
import json
52
import logging
63
import os
@@ -13,13 +10,12 @@
1310

1411
import requests
1512
import urllib3
16-
from multidict import CIMultiDict, MutableMultiMapping
13+
from multidict import CIMultiDict, CIMultiDictProxy, MutableMultiMapping
1714

1815
from pulp_glue.common import __version__
1916
from pulp_glue.common.exceptions import (
2017
OpenAPIError,
2118
PulpAuthenticationFailed,
22-
PulpException,
2319
PulpHTTPError,
2420
PulpNotAutorized,
2521
UnsafeCallError,
@@ -37,10 +33,22 @@
3733
SAFE_METHODS = ["GET", "HEAD", "OPTIONS"]
3834

3935

36+
@dataclass
37+
class _Request:
38+
operation_id: str
39+
method: str
40+
url: str
41+
headers: MutableMultiMapping[str] | CIMultiDictProxy[str] | t.MutableMapping[str, str]
42+
params: dict[str, str] | None = None
43+
data: dict[str, t.Any] | str | None = None
44+
files: dict[str, tuple[str, UploadType, str]] | None = None
45+
security: list[dict[str, list[str]]] | None = None
46+
47+
4048
@dataclass
4149
class _Response:
4250
status_code: int
43-
headers: MutableMultiMapping[str] | t.MutableMapping[str, str]
51+
headers: MutableMultiMapping[str] | CIMultiDictProxy[str] | t.MutableMapping[str, str]
4452
body: bytes
4553

4654

@@ -149,7 +157,7 @@ class OpenAPI:
149157
auth_provider: Object that returns requests auth objects according to the api spec.
150158
cert: Client certificate used for auth.
151159
key: Matching key for `cert` if not already included.
152-
verify_ssl: Whether to check server TLS certificates agains a CA (requests semantic).
160+
verify_ssl: Whether to check server TLS certificates agains a CA.
153161
refresh_cache: Whether to fetch the api doc regardless.
154162
dry_run: Flag to disallow issuing POST, PUT, PATCH or DELETE calls.
155163
debug_callback: Callback that will be called with strings useful for logging or debugging.
@@ -163,7 +171,7 @@ def __init__(
163171
self,
164172
base_url: str,
165173
doc_path: str,
166-
headers: CIMultiDict[str] | None = None,
174+
headers: CIMultiDict[str] | CIMultiDictProxy[str] | None = None,
167175
auth_provider: AuthProviderBase | None = None,
168176
cert: str | None = None,
169177
key: str | None = None,
@@ -178,11 +186,15 @@ def __init__(
178186
):
179187
if validate_certs is not None:
180188
warnings.warn(
181-
"validate_certs is deprecated; use verify_ssl instead.", DeprecationWarning
189+
"validate_certs is deprecated; use verify_ssl instead.",
190+
DeprecationWarning,
182191
)
183192
verify_ssl = validate_certs
184193
if safe_calls_only is not None:
185-
warnings.warn("safe_calls_only is deprecated; use dry_run instead.", DeprecationWarning)
194+
warnings.warn(
195+
"safe_calls_only is deprecated; use dry_run instead.",
196+
DeprecationWarning,
197+
)
186198
dry_run = safe_calls_only
187199
if debug_callback is not None:
188200
warnings.warn(
@@ -281,7 +293,7 @@ def _parse_api(self, data: bytes) -> None:
281293
raise OpenAPIError(_("Unknown schema version"))
282294
self.operations: dict[str, t.Any] = {
283295
method_entry["operationId"]: (method, path)
284-
for path, path_entry in self.api_spec["paths"].items()
296+
for path, path_entry in self.api_spec.get("paths", {}).items()
285297
for method, method_entry in path_entry.items()
286298
if method in {"get", "put", "post", "delete", "options", "head", "patch", "trace"}
287299
}
@@ -398,7 +410,7 @@ def _render_request_body(
398410
) -> tuple[
399411
str | None,
400412
dict[str, t.Any] | str | None,
401-
list[tuple[str, tuple[str, UploadType, str]]] | None,
413+
dict[str, tuple[str, UploadType, str]] | None,
402414
]:
403415
content_types: list[str] = []
404416
try:
@@ -417,7 +429,7 @@ def _render_request_body(
417429

418430
content_type: str | None = None
419431
data: dict[str, t.Any] | str | None = None
420-
files: list[tuple[str, tuple[str, UploadType, str]]] | None = None
432+
files: dict[str, tuple[str, UploadType, str]] | None = None
421433

422434
candidate_content_types = [
423435
"multipart/form-data",
@@ -455,21 +467,20 @@ def _render_request_body(
455467
elif content_type.startswith("application/x-www-form-urlencoded"):
456468
data = body
457469
elif content_type.startswith("multipart/form-data"):
458-
uploads: dict[str, tuple[str, UploadType, str]] = {}
459470
data = {}
471+
files = {}
460472
# Extract and prepare the files to upload
461473
if body:
462474
for key, value in body.items():
463475
if isinstance(value, (bytes, BufferedReader)):
464-
uploads[key] = (
465-
getattr(value, "name", key),
476+
# If available, use the filename.
477+
files[key] = (
478+
getattr(value, "name", key).split("/")[-1],
466479
value,
467480
"application/octet-stream",
468481
)
469482
else:
470483
data[key] = value
471-
if uploads:
472-
files = [(key, upload_data) for key, upload_data in uploads.items()]
473484
break
474485
else:
475486
# No known content-type left
@@ -485,7 +496,7 @@ def _render_request_body(
485496

486497
return content_type, data, files
487498

488-
def _send_request(
499+
def _render_request(
489500
self,
490501
path_spec: dict[str, t.Any],
491502
method: str,
@@ -494,83 +505,120 @@ def _send_request(
494505
headers: dict[str, str],
495506
body: dict[str, t.Any] | None = None,
496507
validate_body: bool = True,
497-
) -> _Response:
508+
) -> _Request:
498509
method_spec = path_spec[method]
510+
_headers = CIMultiDict(self._headers)
511+
_headers.update(headers)
512+
513+
security: list[dict[str, list[str]]] | None
514+
if self._auth_provider and "Authorization" not in self._headers:
515+
security = method_spec.get("security", self.api_spec.get("security"))
516+
else:
517+
# No auth required? Don't provide it.
518+
# No auth_provider available? Hope for the best (should do the trick for cert auth).
519+
# Authorization header present? You wanted it that way...
520+
security = None
521+
499522
content_type, data, files = self._render_request_body(method_spec, body, validate_body)
500-
security: list[dict[str, list[str]]] | None = method_spec.get(
501-
"security", self.api_spec.get("security")
523+
# For we encode the json on our side.
524+
# Somehow this does not work properly for multipart...
525+
if content_type is not None and content_type.startswith("application/json"):
526+
_headers["Content-Type"] = content_type
527+
528+
return _Request(
529+
operation_id=method_spec["operationId"],
530+
method=method,
531+
url=url,
532+
headers=_headers,
533+
params=params,
534+
data=data,
535+
files=files,
536+
security=security,
502537
)
503-
if security and self._auth_provider:
538+
539+
def _log_request(self, request: _Request) -> None:
540+
if request.params:
541+
qs = urlencode(request.params)
542+
self._debug_callback(1, f"{request.operation_id} : {request.method} {request.url}?{qs}")
543+
self._debug_callback(
544+
2,
545+
"\n".join([f" {key}=={value}" for key, value in request.params.items()]),
546+
)
547+
else:
548+
self._debug_callback(1, f"{request.operation_id} : {request.method} {request.url}")
549+
for key, value in request.headers.items():
550+
self._debug_callback(2, f" {key}: {value}")
551+
if request.data is not None:
552+
self._debug_callback(3, f"{request.data!r}")
553+
if request.files is not None:
554+
for key, (name, _dummy, content_type) in request.files.items():
555+
self._debug_callback(3, f"{key} <- {name} [{content_type}]")
556+
557+
def _send_request(
558+
self,
559+
request: _Request,
560+
) -> _Response:
561+
# This function uses requests to translate the _Request into a _Response.
562+
if request.security and self._auth_provider:
504563
if "Authorization" in self._session.headers:
505564
# Bad idea, but you wanted it that way.
506565
auth = None
507566
else:
508-
auth = self._auth_provider(security, self.api_spec["components"]["securitySchemes"])
567+
auth = self._auth_provider(
568+
request.security, self.api_spec["components"]["securitySchemes"]
569+
)
509570
else:
510571
# No auth required? Don't provide it.
511572
# No auth_provider available? Hope for the best (should do the trick for cert auth).
512573
auth = None
513-
# For we encode the json on our side.
514-
# Somehow this does not work properly for multipart...
515-
if content_type is not None and content_type.startswith("application/json"):
516-
headers["content-type"] = content_type
517-
request = self._session.prepare_request(
518-
requests.Request(
519-
method,
520-
url,
574+
try:
575+
r = self._session.request(
576+
request.method,
577+
request.url,
578+
params=request.params,
579+
headers=request.headers,
580+
data=request.data,
581+
files=request.files,
521582
auth=auth,
522-
params=params,
523-
headers=headers,
524-
data=data,
525-
files=files,
526-
)
527-
)
528-
if content_type:
529-
assert request.headers["content-type"].startswith(content_type), (
530-
f"{request.headers['content-type']} != {content_type}"
583+
allow_redirects=False,
531584
)
532-
for key, value in request.headers.items():
533-
self._debug_callback(2, f" {key}: {value}")
534-
if request.body is not None:
535-
self._debug_callback(3, f"{request.body!r}")
536-
if self._dry_run and method.upper() not in SAFE_METHODS:
537-
raise UnsafeCallError(_("Call aborted due to safe mode"))
538-
try:
539-
response = self._session.send(request)
585+
response = _Response(status_code=r.status_code, headers=r.headers, body=r.content)
540586
except requests.TooManyRedirects as e:
587+
# We could handle that in the middleware...
541588
assert e.response is not None
542589
raise OpenAPIError(
543-
_("Received redirect to '{url}'. Please check your CLI configuration.").format(
544-
url=e.response.headers["location"]
590+
_(
591+
"Received redirect to '{new_url} from {old_url}'."
592+
" Please check your configuration."
593+
).format(
594+
new_url=e.response.headers["location"],
595+
old_url=request.url,
545596
)
546597
)
547598
except requests.RequestException as e:
548599
raise OpenAPIError(str(e))
600+
601+
return response
602+
603+
def _log_response(self, response: _Response) -> None:
549604
self._debug_callback(
550605
1, _("Response: {status_code}").format(status_code=response.status_code)
551606
)
552607
for key, value in response.headers.items():
553608
self._debug_callback(2, f" {key}: {value}")
554-
if response.text:
555-
self._debug_callback(3, f"{response.text}")
609+
if response.body:
610+
self._debug_callback(3, f"{response.body!r}")
611+
612+
def _parse_response(self, method_spec: dict[str, t.Any], response: _Response) -> t.Any:
556613
if "Correlation-Id" in response.headers:
557614
self._set_correlation_id(response.headers["Correlation-Id"])
558615
if response.status_code == 401:
559616
raise PulpAuthenticationFailed(method_spec["operationId"])
560-
if response.status_code == 403:
617+
elif response.status_code == 403:
561618
raise PulpNotAutorized(method_spec["operationId"])
562-
try:
563-
response.raise_for_status()
564-
except requests.HTTPError as e:
565-
if e.response is not None:
566-
raise PulpHTTPError(str(e.response.text), e.response.status_code)
567-
else:
568-
raise PulpException(str(e))
569-
return _Response(
570-
status_code=response.status_code, headers=response.headers, body=response.content
571-
)
619+
elif response.status_code >= 300:
620+
raise PulpHTTPError(response.body.decode(), response.status_code)
572621

573-
def _parse_response(self, method_spec: dict[str, t.Any], response: _Response) -> t.Any:
574622
if response.status_code == 204:
575623
return {}
576624

@@ -613,8 +661,8 @@ def call(
613661
The JSON decoded server response if any.
614662
615663
Raises:
616-
OpenAPIValidationError: on failed input validation (no request was sent to the server).
617-
requests.HTTPError: on failures related to the HTTP call made.
664+
ValidationError: on failed input validation (no request was sent to the server).
665+
OpenAPIError: on failures related to the HTTP call made.
618666
"""
619667
method, path = self.operations[operation_id]
620668
path_spec = self.api_spec["paths"][path]
@@ -643,17 +691,7 @@ def call(
643691
)
644692
url = urljoin(self._base_url, path)
645693

646-
if query_params:
647-
qs = urlencode(query_params)
648-
log_msg = f"{operation_id} : {method} {url}?{qs}"
649-
else:
650-
log_msg = f"{operation_id} : {method} {url}"
651-
self._debug_callback(1, log_msg)
652-
self._debug_callback(
653-
2, "\n".join([f" {key}=={value}" for key, value in query_params.items()])
654-
)
655-
656-
response = self._send_request(
694+
request = self._render_request(
657695
path_spec,
658696
method,
659697
url,
@@ -662,5 +700,12 @@ def call(
662700
body,
663701
validate_body=validate_body,
664702
)
703+
self._log_request(request)
704+
705+
if self._dry_run and request.method.upper() not in SAFE_METHODS:
706+
raise UnsafeCallError(_("Call aborted due to safe mode"))
707+
708+
response = self._send_request(request)
665709

710+
self._log_response(response)
666711
return self._parse_response(method_spec, response)

0 commit comments

Comments
 (0)