Skip to content

Commit b7a0bca

Browse files
committed
Refactor internals of openapi class
This will mask attributes of the class private and add a bit of abstraction from requests.
1 parent 76ede0c commit b7a0bca

File tree

12 files changed

+614
-331
lines changed

12 files changed

+614
-331
lines changed

lint_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mypy==1.19.0
77
shellcheck-py==0.11.0.1
88

99
# Type annotation stubs
10+
types-aiofiles
1011
types-pygments
1112
types-PyYAML
1213
types-requests

lower_bounds_constraints.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
aiofiles==25.1.0
2+
aiohttp==3.12.0
13
click==8.0.0
24
packaging==20.0
35
PyYAML==5.3

pulp-glue/pulp_glue/common/authentication.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,158 @@
44
import requests
55

66

7+
class AuthProviderBase:
8+
"""
9+
Base class for auth providers.
10+
11+
This abstract base class will analyze the authentication proposals of the openapi specs.
12+
Different authentication schemes should be implemented by subclasses.
13+
Returned auth objects need to be compatible with `requests.auth.AuthBase`.
14+
"""
15+
16+
def __init__(self) -> None:
17+
self._oauth2_token: str | None = None
18+
self._oauth2_expires: datetime = datetime.now()
19+
20+
def can_complete_http_basic(self) -> bool:
21+
return False
22+
23+
def can_complete_mutualTLS(self) -> bool:
24+
return False
25+
26+
def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool:
27+
return False
28+
29+
def can_complete_scheme(self, scheme: dict[str, t.Any], scopes: list[str]) -> bool:
30+
if scheme["type"] == "http":
31+
if scheme["scheme"] == "basic":
32+
return self.can_complete_http_basic()
33+
elif scheme["type"] == "mutualTLS":
34+
return self.can_complete_mutualTLS()
35+
elif scheme["type"] == "oauth2":
36+
for flow_name, flow in scheme["flows"].items():
37+
if (
38+
flow_name == "clientCredentials"
39+
and self.can_complete_oauth2_client_credentials(flow["scopes"])
40+
):
41+
return True
42+
return False
43+
44+
def can_complete(
45+
self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]]
46+
) -> bool:
47+
for name, scopes in proposal.items():
48+
scheme = security_schemes.get(name)
49+
if scheme is None or not self.can_complete_scheme(scheme, scopes):
50+
return False
51+
# This covers the case where `[]` allows for no auth at all.
52+
return True
53+
54+
async def auth_success_hook(
55+
self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]]
56+
) -> None:
57+
pass
58+
59+
async def auth_failure_hook(
60+
self, proposal: dict[str, list[str]], security_schemes: dict[str, dict[str, t.Any]]
61+
) -> bool:
62+
"""
63+
Hook called on failed authentication.
64+
Return True for retrying.
65+
"""
66+
return False
67+
68+
async def http_basic_credentials(self) -> tuple[bytes, bytes]:
69+
raise NotImplementedError()
70+
71+
async def oauth2_client_credentials(self) -> tuple[bytes, bytes]:
72+
raise NotImplementedError()
73+
74+
def tls_credentials(self) -> tuple[str, str | None]:
75+
raise NotImplementedError()
76+
77+
78+
class BasicAuthProvider(AuthProviderBase):
79+
"""
80+
AuthProvider providing basic auth with fixed `username`, `password`.
81+
"""
82+
83+
def __init__(self, username: t.AnyStr, password: t.AnyStr):
84+
super().__init__()
85+
self.username: bytes = username.encode("latin1") if isinstance(username, str) else username
86+
self.password: bytes = password.encode("latin1") if isinstance(password, str) else password
87+
88+
def can_complete_http_basic(self) -> bool:
89+
return True
90+
91+
async def http_basic_credentials(self) -> tuple[bytes, bytes]:
92+
return self.username, self.password
93+
94+
95+
class GlueAuthProvider(AuthProviderBase):
96+
"""
97+
AuthProvider allowing to be used with prepared credentials.
98+
"""
99+
100+
def __init__(
101+
self,
102+
*,
103+
username: t.AnyStr | None = None,
104+
password: t.AnyStr | None = None,
105+
client_id: t.AnyStr | None = None,
106+
client_secret: t.AnyStr | None = None,
107+
cert: str | None = None,
108+
key: str | None = None,
109+
):
110+
super().__init__()
111+
self.username: bytes | None = None
112+
self.password: bytes | None = None
113+
self.client_id: bytes | None = None
114+
self.client_secret: bytes | None = None
115+
self.cert: str | None = cert
116+
self.key: str | None = key
117+
118+
if username is not None:
119+
assert password is not None
120+
self.username = username.encode("latin1") if isinstance(username, str) else username
121+
self.password = password.encode("latin1") if isinstance(password, str) else password
122+
if client_id is not None:
123+
assert client_secret is not None
124+
self.client_id = client_id.encode("latin1") if isinstance(client_id, str) else client_id
125+
self.client_secret = (
126+
client_secret.encode("latin1") if isinstance(client_secret, str) else client_secret
127+
)
128+
129+
if cert is None and key is not None:
130+
raise RuntimeError("Key can only be used together with a cert.")
131+
132+
def can_complete_http_basic(self) -> bool:
133+
return self.username is not None
134+
135+
def can_complete_oauth2_client_credentials(self, scopes: list[str]) -> bool:
136+
return self.client_id is not None
137+
138+
def can_complete_mutualTLS(self) -> bool:
139+
return self.cert is not None
140+
141+
async def http_basic_credentials(self) -> tuple[bytes, bytes]:
142+
assert self.username is not None
143+
assert self.password is not None
144+
return self.username, self.password
145+
146+
async def oauth2_client_credentials(self) -> tuple[bytes, bytes]:
147+
assert self.client_id is not None
148+
assert self.client_secret is not None
149+
return self.client_id, self.client_secret
150+
151+
def tls_credentials(self) -> tuple[str, str | None]:
152+
assert self.cert is not None
153+
return (self.cert, self.key)
154+
155+
156+
# ----------------------8<----8<------------------------
157+
158+
7159
class OAuth2ClientCredentialsAuth(requests.auth.AuthBase):
8160
"""
9161
This implements the OAuth2 ClientCredentials Grant authentication flow.

pulp-glue/pulp_glue/common/context.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from packaging.specifiers import SpecifierSet
1111

12+
from pulp_glue.common.authentication import GlueAuthProvider
1213
from pulp_glue.common.exceptions import (
1314
NotImplementedFake,
1415
OpenAPIError,
@@ -19,7 +20,7 @@
1920
UnsafeCallError,
2021
)
2122
from pulp_glue.common.i18n import get_translation
22-
from pulp_glue.common.openapi import BasicAuthProvider, OpenAPI
23+
from pulp_glue.common.openapi import OpenAPI
2324

2425
if sys.version_info >= (3, 11):
2526
import tomllib
@@ -335,8 +336,13 @@ def from_config(cls, config: dict[str, t.Any]) -> "t.Self":
335336
api_kwargs: dict[str, t.Any] = {
336337
"base_url": config["base_url"],
337338
}
338-
if "username" in config:
339-
api_kwargs["auth_provider"] = BasicAuthProvider(config["username"], config["password"])
339+
api_kwargs["auth_provider"] = GlueAuthProvider(
340+
**{
341+
k: v
342+
for k, v in config.items()
343+
if k in {"username", "password", "client_id", "client_secret", "cert", "key"}
344+
}
345+
)
340346
if "headers" in config:
341347
api_kwargs["headers"] = dict(
342348
(header.split(":", maxsplit=1) for header in config["headers"])
@@ -385,7 +391,9 @@ def api(self) -> OpenAPI:
385391
# Deprecated for 'auth'.
386392
if not password:
387393
password = self.prompt("password", hide_input=True)
388-
self._api_kwargs["auth_provider"] = BasicAuthProvider(username, password)
394+
self._api_kwargs["auth_provider"] = GlueAuthProvider(
395+
username=username, password=password
396+
)
389397
warnings.warn(
390398
"Using 'username' and 'password' with 'PulpContext' is deprecated. "
391399
"Use an auth provider with the 'auth_provider' argument instead.",
@@ -399,10 +407,10 @@ def api(self) -> OpenAPI:
399407
)
400408
except OpenAPIError as e:
401409
raise PulpException(str(e))
410+
self._patch_api_spec()
402411
# Rerun scheduled version checks
403412
for plugin_requirement in self._needed_plugins:
404413
self.needs_plugin(plugin_requirement)
405-
self._patch_api_spec()
406414
return self._api
407415

408416
@property

0 commit comments

Comments
 (0)