Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ requirements.*
/scripts/
.dmypy.json
tmp
.envrc
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
### Breaking

### Features
* make it possible to assign multiple credentials to a single endpoint

### Improvements
* improve http endpoint security by fully checking basic auth hashes, and doing that in a time constant manner to not expose secrets

### Bugfix

Expand Down
75 changes: 54 additions & 21 deletions logprep/connector/http/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
/firstendpoint: json
/second*: plaintext
/(third|fourth)/endpoint: jsonl
/endpoint_with_multiple_credentials: json

The endpoint config supports regex and wildcard patterns:
* :code:`/second*`: matches everything after asterisk
Expand Down Expand Up @@ -56,6 +57,11 @@
/second*:
username: user
password: secret_password
/endpoint_with_multiple_credentials:
- username: user
password: secret_password
- username: user2
password_file: examples/exampledata/config/user_password.txt

You can choose between a plain secret with the key :code:`password` or a filebased secret
with the key :code:`password_file`.
Expand Down Expand Up @@ -86,6 +92,7 @@
* Responds with 405
"""

import hmac
import logging
import multiprocessing as mp
import queue
Expand All @@ -109,13 +116,18 @@
HTTPMethodNotAllowed,
HTTPTooManyRequests,
HTTPUnauthorized,
Request,
)

from logprep.abc.input import FatalInputError, Input
from logprep.factory_error import InvalidConfigurationError
from logprep.metrics.metrics import CounterMetric, GaugeMetric
from logprep.util import http, rstr
from logprep.util.credentials import Credentials, CredentialsFactory
from logprep.util.credentials import (
BasicAuthCredentials,
Credentials,
CredentialsFactory,
)
from logprep.util.helper import add_fields_to

logger = logging.getLogger("HTTPInput")
Expand All @@ -126,15 +138,24 @@ def basic_auth(func: Callable):
Will raise 401 on wrong credentials or missing Authorization-Header"""

async def func_wrapper(*args, **kwargs):
endpoint = args[0]
req = args[1]
endpoint: HttpEndpoint = args[0]
req: Request = args[1]
if endpoint.credentials:
auth_request_header = req.get_header("Authorization")
if not auth_request_header:
if not req.auth or not isinstance(req.auth, str):
raise HTTPUnauthorized
basic_string = req.auth
if endpoint.basicauth_b64 not in basic_string:

auth_header_value_b64 = req.auth
lowered_auth_header = auth_header_value_b64.lower()
if lowered_auth_header.startswith("basic"):
auth_header_value_b64 = auth_header_value_b64[5:]
auth_header_value_b64 = auth_header_value_b64.lstrip()

for basicauth_b64 in endpoint.basicauth_b64:
if hmac.compare_digest(basicauth_b64, auth_header_value_b64):
break
else:
raise HTTPUnauthorized

func_wrapper = await func(*args, **kwargs)
return func_wrapper

Expand Down Expand Up @@ -249,7 +270,7 @@ def __init__(
original_event_field: dict[str, str] | None,
collect_meta: bool,
metafield_name: str,
credentials: Credentials | None,
credentials: list[Credentials] | Credentials | None,
metrics: "HttpInput.Metrics",
copy_headers_to_logs: set[str],
) -> None:
Expand All @@ -259,14 +280,23 @@ def __init__(

# Deprecated
self.collect_meta = collect_meta

self.metafield_name = metafield_name
self.credentials = credentials
self.metrics = metrics
if self.credentials:
# TODO what about other credential types?
self.basicauth_b64 = b64encode(
f"{self.credentials.username}:{self.credentials.password}".encode("utf-8") # type: ignore
).decode("utf-8")
self.basicauth_b64: list[str] = []

self.credentials = None

if credentials:
credentials = [credentials] if isinstance(credentials, Credentials) else credentials
for cred in credentials:
if isinstance(cred, BasicAuthCredentials):
self.basicauth_b64.append(
b64encode(f"{cred.username}:{cred.password}".encode("utf-8")).decode(
"utf-8"
)
)
self.credentials = credentials

def collect_metrics(self):
"""Increment number of requests"""
Expand Down Expand Up @@ -453,7 +483,7 @@ class Config(Input.Config):
"""

message_backlog_size: int = field(
validator=validators.instance_of((int, float)), default=15000
validator=validators.instance_of(int), default=15000, converter=int
)
"""Configures maximum size of input message queue for this connector. When limit is reached
the server will answer with 429 Too Many Requests. For reasonable throughput this shouldn't
Expand All @@ -463,9 +493,7 @@ class Config(Input.Config):
copy_headers_to_logs: set[str] = field(
validator=validators.deep_iterable(
member_validator=validators.instance_of(str),
iterable_validator=validators.or_(
validators.instance_of(set), validators.instance_of(list)
),
iterable_validator=validators.instance_of(set),
),
converter=set,
factory=lambda: set(DEFAULT_META_HEADERS),
Expand Down Expand Up @@ -549,6 +577,11 @@ def config(self) -> Config:
"""Provides the properly typed rule configuration object"""
return typing.cast(HttpInput.Config, self._config)

@property
def _typed_metrics(self) -> Metrics:
"""Returns metrics as typed HttpInput.Metrics"""
return typing.cast(HttpInput.Metrics, self.metrics)

def setup(self) -> None:
"""setup starts the actual functionality of this connector.
By checking against pipeline_index we're assuring this connector
Expand Down Expand Up @@ -587,7 +620,7 @@ def setup(self) -> None:
collect_meta,
metafield_name,
credentials,
self.metrics,
self._typed_metrics,
copy_headers_to_logs,
)

Expand All @@ -609,7 +642,7 @@ def _get_event(self, timeout: float) -> tuple:
"""Returns the first message from the queue"""
messages = typing.cast(Queue, self.messages)

self.metrics.message_backlog_size += messages.qsize()
self._typed_metrics.message_backlog_size += messages.qsize()
try:
message = messages.get(timeout=timeout)
raw_message = str(message).encode("utf8")
Expand Down Expand Up @@ -651,7 +684,7 @@ def health(self) -> bool:
).raise_for_status()
except (requests.exceptions.RequestException, requests.exceptions.Timeout) as error:
logger.error("Health check failed for endpoint: %s due to %s", endpoint, str(error))
self.metrics.number_of_errors += 1
self._typed_metrics.number_of_errors += 1
return False

return super().health()
6 changes: 6 additions & 0 deletions logprep/ng/connector/http/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
/firstendpoint: json
/second*: plaintext
/(third|fourth)/endpoint: jsonl
/endpoint_with_multiple_credentials: json

The endpoint config supports regex and wildcard patterns:
* :code:`/second*`: matches everything after asterisk
Expand Down Expand Up @@ -56,6 +57,11 @@
/second*:
username: user
password: secret_password
/endpoint_with_multiple_credentials:
- username: user
password: secret_password
- username: user2
password_file: examples/exampledata/config/user_password.txt

You can choose between a plain secret with the key :code:`password` or a filebased secret
with the key :code:`password_file`.
Expand Down
44 changes: 31 additions & 13 deletions logprep/util/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def from_target(cls, target_url: str) -> "Credentials | None":
return credentials

@classmethod
def from_endpoint(cls, target_endpoint: str) -> "Credentials | None":
def from_endpoint(cls, target_endpoint: str) -> "list[Credentials] | Credentials | None":
"""Factory method to create a credentials object based on the credentials stored in the
environment variable :code:`LOGPREP_CREDENTIALS_FILE`.
Based on these credentials the expected authentication method is chosen and represented
Expand All @@ -187,8 +187,14 @@ def from_endpoint(cls, target_endpoint: str) -> "Credentials | None":
endpoint_credentials = credentials_file.input.get("endpoints")
if endpoint_credentials is None:
return None
credential_mapping: dict | None = endpoint_credentials.get(target_endpoint)
credentials = cls.from_dict(credential_mapping)
credential_mapping: list | dict | None = endpoint_credentials.get(target_endpoint)

credentials: list[Credentials] | Credentials | None = None
if isinstance(credential_mapping, dict):
credentials = cls.from_dict(credential_mapping)
elif isinstance(credential_mapping, list):
credentials = cls.from_list(credential_mapping)

return credentials

@staticmethod
Expand Down Expand Up @@ -250,6 +256,16 @@ def _resolve_secret_content(credential_mapping: dict):
credential_mapping.pop(f"{credential_type}_file")
credential_mapping.update(secret_content)

@classmethod
def from_list(cls, credential_mapping: list[dict | None]) -> "list[Credentials] | None":
creds: list[Credentials] = []
for credential in credential_mapping:
cred = cls.from_dict(credential)
if isinstance(cred, Credentials):
creds.append(cred)

return creds if len(creds) > 0 else None

@classmethod
def from_dict(cls, credential_mapping: dict | None) -> "Credentials | None":
"""matches the given credentials of the credentials mapping
Expand Down Expand Up @@ -392,10 +408,10 @@ class AccessToken:
token: str = field(validator=validators.instance_of(str), repr=False)
"""token used for authentication against the target"""
expiry_time: datetime = field(
validator=validators.instance_of((datetime, type(None))), init=False
validator=validators.instance_of(datetime), init=False, default=datetime.now()
)
"""time when token is expired"""
refresh_token: str = field(
refresh_token: str | None = field(
validator=validators.instance_of((str, type(None))), default=None, repr=False
)
"""is used incase the token is expired"""
Expand Down Expand Up @@ -426,7 +442,9 @@ class Credentials:

_logger = logging.getLogger("Credentials")

_session: Session = field(validator=validators.instance_of((Session, type(None))), default=None)
_session: Session | None = field(
validator=validators.instance_of((Session, type(None))), default=None
)

def get_session(self):
"""returns session with retry configuration"""
Expand Down Expand Up @@ -548,14 +566,14 @@ class OAuth2PasswordFlowCredentials(Credentials):
"""the username for the token request"""
timeout: int = field(validator=validators.instance_of(int), default=1)
"""The timeout for the token request. Defaults to 1 second."""
client_id: str = field(validator=validators.instance_of((str, type(None))), default=None)
client_id: str | None = field(validator=validators.instance_of((str, type(None))), default=None)
"""The client id for the token request. This is used to identify the client. (Optional)"""
client_secret: str = field(
client_secret: str | None = field(
validator=validators.instance_of((str, type(None))), default=None, repr=False
)
"""The client secret for the token request.
This is used to authenticate the client. (Optional)"""
_token: AccessToken = field(
_token: AccessToken | None = field(
validator=validators.instance_of((AccessToken, type(None))),
init=False,
repr=False,
Expand All @@ -572,7 +590,7 @@ def get_session(self) -> Session:
}
session.headers["Authorization"] = f"Bearer {self._get_token(payload)}"

if self._token.is_expired and self._token.refresh_token is not None:
if self._token and self._token.is_expired and self._token.refresh_token is not None:
session = Session()
Copy link
Collaborator

@mhoff mhoff Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates a fresh session without the super().get_session() semantics which ensure a retry behavior is configured. I would propose to refactor Credentials.get_session() to call a new method self._create_and_save_session(). This new method can then be re-used in this line to ensure a retry configuration is being set properly.

I propose to include setting self._session in this new method as well (in contrast to a pure self._create_session() method) because the management of this attribute is an internal aspect of the super class.

payload = {
"grant_type": "refresh_token",
Expand Down Expand Up @@ -638,7 +656,7 @@ class OAuth2ClientFlowCredentials(Credentials):
"""The client secret for the token request. This is used to authenticate the client."""
timeout: int = field(validator=validators.instance_of(int), default=1)
"""The timeout for the token request. Defaults to 1 second."""
_token: AccessToken = field(
_token: AccessToken | None = field(
validator=validators.instance_of((AccessToken, type(None))), init=False, repr=False
)

Expand All @@ -658,7 +676,7 @@ def get_session(self) -> Session:

"""
session = super().get_session()
if "Authorization" in session.headers and self._token.is_expired:
if "Authorization" in session.headers and (not self._token or self._token.is_expired):
session.close()
session = Session()
if self._no_authorization_header(session):
Expand Down Expand Up @@ -709,7 +727,7 @@ class MTLSCredentials(Credentials):
"""path to the client key"""
cert: str = field(validator=validators.instance_of(str))
"""path to the client certificate"""
ca_cert: str = field(validator=validators.instance_of((str, type(None))), default=None)
ca_cert: str | None = field(validator=validators.instance_of((str, type(None))), default=None)
"""path to a certification authority certificate"""

def get_session(self):
Expand Down
Loading