-
Notifications
You must be signed in to change notification settings - Fork 10
feat: allow multiple credentials for input connector #938
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3cda828
0e42e7e
ab3ca57
497c044
506f308
d10a7d9
74a664a
fa5af82
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,3 +32,4 @@ requirements.* | |
| /scripts/ | ||
| .dmypy.json | ||
| tmp | ||
| .envrc | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -163,7 +163,7 @@ def from_target(cls, target_url: str) -> "Credentials | None": | |
| return credentials | ||
|
|
||
| @classmethod | ||
| def from_endpoint(cls, target_endpoint: str) -> "Credentials | None": | ||
| def from_endpoint(cls, target_endpoint: str) -> "list[Credentials] | Credentials | None": | ||
| """Factory method to create a credentials object based on the credentials stored in the | ||
| environment variable :code:`LOGPREP_CREDENTIALS_FILE`. | ||
| Based on these credentials the expected authentication method is chosen and represented | ||
|
|
@@ -187,8 +187,14 @@ def from_endpoint(cls, target_endpoint: str) -> "Credentials | None": | |
| endpoint_credentials = credentials_file.input.get("endpoints") | ||
| if endpoint_credentials is None: | ||
| return None | ||
| credential_mapping: dict | None = endpoint_credentials.get(target_endpoint) | ||
| credentials = cls.from_dict(credential_mapping) | ||
| credential_mapping: list | dict | None = endpoint_credentials.get(target_endpoint) | ||
|
|
||
| credentials: list[Credentials] | Credentials | None = None | ||
| if isinstance(credential_mapping, dict): | ||
| credentials = cls.from_dict(credential_mapping) | ||
| elif isinstance(credential_mapping, list): | ||
| credentials = cls.from_list(credential_mapping) | ||
|
|
||
| return credentials | ||
|
|
||
| @staticmethod | ||
|
|
@@ -250,6 +256,16 @@ def _resolve_secret_content(credential_mapping: dict): | |
| credential_mapping.pop(f"{credential_type}_file") | ||
| credential_mapping.update(secret_content) | ||
|
|
||
| @classmethod | ||
| def from_list(cls, credential_mapping: list[dict | None]) -> "list[Credentials] | None": | ||
| creds: list[Credentials] = [] | ||
| for credential in credential_mapping: | ||
| cred = cls.from_dict(credential) | ||
| if isinstance(cred, Credentials): | ||
| creds.append(cred) | ||
|
|
||
| return creds if len(creds) > 0 else None | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, credential_mapping: dict | None) -> "Credentials | None": | ||
| """matches the given credentials of the credentials mapping | ||
|
|
@@ -392,10 +408,10 @@ class AccessToken: | |
| token: str = field(validator=validators.instance_of(str), repr=False) | ||
| """token used for authentication against the target""" | ||
| expiry_time: datetime = field( | ||
| validator=validators.instance_of((datetime, type(None))), init=False | ||
| validator=validators.instance_of(datetime), init=False, default=datetime.now() | ||
| ) | ||
| """time when token is expired""" | ||
| refresh_token: str = field( | ||
| refresh_token: str | None = field( | ||
| validator=validators.instance_of((str, type(None))), default=None, repr=False | ||
| ) | ||
| """is used incase the token is expired""" | ||
|
|
@@ -426,7 +442,9 @@ class Credentials: | |
|
|
||
| _logger = logging.getLogger("Credentials") | ||
|
|
||
| _session: Session = field(validator=validators.instance_of((Session, type(None))), default=None) | ||
| _session: Session | None = field( | ||
| validator=validators.instance_of((Session, type(None))), default=None | ||
| ) | ||
|
|
||
| def get_session(self): | ||
| """returns session with retry configuration""" | ||
|
|
@@ -548,14 +566,14 @@ class OAuth2PasswordFlowCredentials(Credentials): | |
| """the username for the token request""" | ||
| timeout: int = field(validator=validators.instance_of(int), default=1) | ||
| """The timeout for the token request. Defaults to 1 second.""" | ||
| client_id: str = field(validator=validators.instance_of((str, type(None))), default=None) | ||
| client_id: str | None = field(validator=validators.instance_of((str, type(None))), default=None) | ||
| """The client id for the token request. This is used to identify the client. (Optional)""" | ||
| client_secret: str = field( | ||
| client_secret: str | None = field( | ||
| validator=validators.instance_of((str, type(None))), default=None, repr=False | ||
| ) | ||
| """The client secret for the token request. | ||
| This is used to authenticate the client. (Optional)""" | ||
| _token: AccessToken = field( | ||
| _token: AccessToken | None = field( | ||
| validator=validators.instance_of((AccessToken, type(None))), | ||
| init=False, | ||
| repr=False, | ||
|
|
@@ -572,7 +590,7 @@ def get_session(self) -> Session: | |
| } | ||
| session.headers["Authorization"] = f"Bearer {self._get_token(payload)}" | ||
|
|
||
| if self._token.is_expired and self._token.refresh_token is not None: | ||
| if self._token and self._token.is_expired and self._token.refresh_token is not None: | ||
| session = Session() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This creates a fresh session without the I propose to include setting |
||
| payload = { | ||
| "grant_type": "refresh_token", | ||
|
|
@@ -638,7 +656,7 @@ class OAuth2ClientFlowCredentials(Credentials): | |
| """The client secret for the token request. This is used to authenticate the client.""" | ||
| timeout: int = field(validator=validators.instance_of(int), default=1) | ||
| """The timeout for the token request. Defaults to 1 second.""" | ||
| _token: AccessToken = field( | ||
| _token: AccessToken | None = field( | ||
| validator=validators.instance_of((AccessToken, type(None))), init=False, repr=False | ||
| ) | ||
|
|
||
|
|
@@ -658,7 +676,7 @@ def get_session(self) -> Session: | |
|
|
||
| """ | ||
| session = super().get_session() | ||
| if "Authorization" in session.headers and self._token.is_expired: | ||
| if "Authorization" in session.headers and (not self._token or self._token.is_expired): | ||
| session.close() | ||
| session = Session() | ||
| if self._no_authorization_header(session): | ||
|
|
@@ -709,7 +727,7 @@ class MTLSCredentials(Credentials): | |
| """path to the client key""" | ||
| cert: str = field(validator=validators.instance_of(str)) | ||
| """path to the client certificate""" | ||
| ca_cert: str = field(validator=validators.instance_of((str, type(None))), default=None) | ||
| ca_cert: str | None = field(validator=validators.instance_of((str, type(None))), default=None) | ||
| """path to a certification authority certificate""" | ||
|
|
||
| def get_session(self): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.