Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ requirements.*
/scripts/
.dmypy.json
tmp
.envrc
71 changes: 51 additions & 20 deletions logprep/connector/http/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
* Responds with 405
"""

import hmac
import logging
import multiprocessing as mp
import queue
Expand All @@ -109,13 +110,18 @@
HTTPMethodNotAllowed,
HTTPTooManyRequests,
HTTPUnauthorized,
Request,
)

from logprep.abc.input import FatalInputError, Input
from logprep.factory_error import InvalidConfigurationError
from logprep.metrics.metrics import CounterMetric, GaugeMetric
from logprep.util import http, rstr
from logprep.util.credentials import Credentials, CredentialsFactory
from logprep.util.credentials import (
BasicAuthCredentials,
Credentials,
CredentialsFactory,
)
from logprep.util.helper import add_fields_to

logger = logging.getLogger("HTTPInput")
Expand All @@ -126,15 +132,24 @@ def basic_auth(func: Callable):
Will raise 401 on wrong credentials or missing Authorization-Header"""

async def func_wrapper(*args, **kwargs):
endpoint = args[0]
req = args[1]
endpoint: HttpEndpoint = args[0]
req: Request = args[1]
if endpoint.credentials:
auth_request_header = req.get_header("Authorization")
if not auth_request_header:
if not req.auth or not isinstance(req.auth, str):
raise HTTPUnauthorized
basic_string = req.auth
if endpoint.basicauth_b64 not in basic_string:

auth_header_value_b64 = req.auth
lowered_auth_header = auth_header_value_b64.lower()
if lowered_auth_header.startswith("basic"):
auth_header_value_b64 = auth_header_value_b64[5:]
auth_header_value_b64 = auth_header_value_b64.lstrip()

for basicauth_b64 in endpoint.basicauth_b64:
if hmac.compare_digest(basicauth_b64, auth_header_value_b64):
break
else:
raise HTTPUnauthorized

func_wrapper = await func(*args, **kwargs)
return func_wrapper

Expand Down Expand Up @@ -249,7 +264,7 @@ def __init__(
original_event_field: dict[str, str] | None,
collect_meta: bool,
metafield_name: str,
credentials: Credentials | None,
credentials: list[Credentials] | Credentials | None,
metrics: "HttpInput.Metrics",
copy_headers_to_logs: set[str],
) -> None:
Expand All @@ -259,14 +274,27 @@ def __init__(

# Deprecated
self.collect_meta = collect_meta

self.metafield_name = metafield_name
self.credentials = credentials
self.metrics = metrics
if self.credentials:
# TODO what about other credential types?
self.basicauth_b64 = b64encode(
f"{self.credentials.username}:{self.credentials.password}".encode("utf-8") # type: ignore
).decode("utf-8")
self.basicauth_b64: list[str] = []

if self.credentials and isinstance(self.credentials, list):
for cred in self.credentials:
if isinstance(cred, BasicAuthCredentials):
self.basicauth_b64.append(
b64encode(f"{cred.username}:{cred.password}".encode("utf-8")).decode(
"utf-8"
)
)
elif self.credentials:
if isinstance(self.credentials, BasicAuthCredentials):
self.basicauth_b64.append(
b64encode(
f"{self.credentials.username}:{self.credentials.password}".encode("utf-8")
).decode("utf-8")
)

def collect_metrics(self):
"""Increment number of requests"""
Expand Down Expand Up @@ -453,7 +481,7 @@ class Config(Input.Config):
"""

message_backlog_size: int = field(
validator=validators.instance_of((int, float)), default=15000
validator=validators.instance_of(int), default=15000, converter=lambda x: int(x)
)
"""Configures maximum size of input message queue for this connector. When limit is reached
the server will answer with 429 Too Many Requests. For reasonable throughput this shouldn't
Expand All @@ -463,9 +491,7 @@ class Config(Input.Config):
copy_headers_to_logs: set[str] = field(
validator=validators.deep_iterable(
member_validator=validators.instance_of(str),
iterable_validator=validators.or_(
validators.instance_of(set), validators.instance_of(list)
),
iterable_validator=validators.instance_of(set),
),
converter=set,
factory=lambda: set(DEFAULT_META_HEADERS),
Expand Down Expand Up @@ -549,6 +575,11 @@ def config(self) -> Config:
"""Provides the properly typed rule configuration object"""
return typing.cast(HttpInput.Config, self._config)

@property
def _typed_metrics(self) -> Metrics:
"""Returns metrics as typed HttpInput.Metrics"""
return typing.cast(HttpInput.Metrics, self.metrics)

def setup(self) -> None:
"""setup starts the actual functionality of this connector.
By checking against pipeline_index we're assuring this connector
Expand Down Expand Up @@ -587,7 +618,7 @@ def setup(self) -> None:
collect_meta,
metafield_name,
credentials,
self.metrics,
self._typed_metrics,
copy_headers_to_logs,
)

Expand All @@ -609,7 +640,7 @@ def _get_event(self, timeout: float) -> tuple:
"""Returns the first message from the queue"""
messages = typing.cast(Queue, self.messages)

self.metrics.message_backlog_size += messages.qsize()
self._typed_metrics.message_backlog_size += messages.qsize()
try:
message = messages.get(timeout=timeout)
raw_message = str(message).encode("utf8")
Expand Down Expand Up @@ -651,7 +682,7 @@ def health(self) -> bool:
).raise_for_status()
except (requests.exceptions.RequestException, requests.exceptions.Timeout) as error:
logger.error("Health check failed for endpoint: %s due to %s", endpoint, str(error))
self.metrics.number_of_errors += 1
self._typed_metrics.number_of_errors += 1
return False

return super().health()
46 changes: 32 additions & 14 deletions logprep/util/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def from_target(cls, target_url: str) -> "Credentials | None":
return credentials

@classmethod
def from_endpoint(cls, target_endpoint: str) -> "Credentials | None":
def from_endpoint(cls, target_endpoint: str) -> "list[Credentials] | Credentials | None":
"""Factory method to create a credentials object based on the credentials stored in the
environment variable :code:`LOGPREP_CREDENTIALS_FILE`.
Based on these credentials the expected authentication method is chosen and represented
Expand All @@ -187,8 +187,14 @@ def from_endpoint(cls, target_endpoint: str) -> "Credentials | None":
endpoint_credentials = credentials_file.input.get("endpoints")
if endpoint_credentials is None:
return None
credential_mapping: dict | None = endpoint_credentials.get(target_endpoint)
credentials = cls.from_dict(credential_mapping)
credential_mapping: list | dict | None = endpoint_credentials.get(target_endpoint)

credentials: list[Credentials] | Credentials | None = None
if isinstance(credential_mapping, dict):
credentials = cls.from_dict(credential_mapping)
elif isinstance(credential_mapping, list):
credentials = cls.from_list(credential_mapping)

return credentials

@staticmethod
Expand Down Expand Up @@ -250,6 +256,16 @@ def _resolve_secret_content(credential_mapping: dict):
credential_mapping.pop(f"{credential_type}_file")
credential_mapping.update(secret_content)

@classmethod
def from_list(cls, credential_mapping: list[dict | None]) -> "list[Credentials] | None":
creds: list[Credentials] = []
for credential in credential_mapping:
cred = cls.from_dict(credential)
if isinstance(cred, Credentials):
creds.append(cred)

return creds

@classmethod
def from_dict(cls, credential_mapping: dict | None) -> "Credentials | None":
"""matches the given credentials of the credentials mapping
Expand Down Expand Up @@ -391,11 +407,11 @@ class AccessToken:

token: str = field(validator=validators.instance_of(str), repr=False)
"""token used for authentication against the target"""
expiry_time: datetime = field(
expiry_time: datetime | None = field(
validator=validators.instance_of((datetime, type(None))), init=False
)
"""time when token is expired"""
refresh_token: str = field(
refresh_token: str | None = field(
validator=validators.instance_of((str, type(None))), default=None, repr=False
)
"""is used incase the token is expired"""
Expand All @@ -415,7 +431,7 @@ def __str__(self) -> str:
@property
def is_expired(self) -> bool:
"""Checks if the token is already expired."""
if self.expires_in == 0:
if self.expires_in == 0 or not self.expiry_time:
return False
return datetime.now() > self.expiry_time

Expand All @@ -426,7 +442,9 @@ class Credentials:

_logger = logging.getLogger("Credentials")

_session: Session = field(validator=validators.instance_of((Session, type(None))), default=None)
_session: Session | None = field(
validator=validators.instance_of((Session, type(None))), default=None
)

def get_session(self):
"""returns session with retry configuration"""
Expand Down Expand Up @@ -548,14 +566,14 @@ class OAuth2PasswordFlowCredentials(Credentials):
"""the username for the token request"""
timeout: int = field(validator=validators.instance_of(int), default=1)
"""The timeout for the token request. Defaults to 1 second."""
client_id: str = field(validator=validators.instance_of((str, type(None))), default=None)
client_id: str | None = field(validator=validators.instance_of((str, type(None))), default=None)
"""The client id for the token request. This is used to identify the client. (Optional)"""
client_secret: str = field(
client_secret: str | None = field(
validator=validators.instance_of((str, type(None))), default=None, repr=False
)
"""The client secret for the token request.
This is used to authenticate the client. (Optional)"""
_token: AccessToken = field(
_token: AccessToken | None = field(
validator=validators.instance_of((AccessToken, type(None))),
init=False,
repr=False,
Expand All @@ -572,7 +590,7 @@ def get_session(self) -> Session:
}
session.headers["Authorization"] = f"Bearer {self._get_token(payload)}"

if self._token.is_expired and self._token.refresh_token is not None:
if self._token and self._token.is_expired and self._token.refresh_token is not None:
session = Session()
Copy link
Collaborator

@mhoff mhoff Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This creates a fresh session without the super().get_session() semantics which ensure a retry behavior is configured. I would propose to refactor Credentials.get_session() to call a new method self._create_and_save_session(). This new method can then be re-used in this line to ensure a retry configuration is being set properly.

I propose to include setting self._session in this new method as well (in contrast to a pure self._create_session() method) because the management of this attribute is an internal aspect of the super class.

payload = {
"grant_type": "refresh_token",
Expand Down Expand Up @@ -638,7 +656,7 @@ class OAuth2ClientFlowCredentials(Credentials):
"""The client secret for the token request. This is used to authenticate the client."""
timeout: int = field(validator=validators.instance_of(int), default=1)
"""The timeout for the token request. Defaults to 1 second."""
_token: AccessToken = field(
_token: AccessToken | None = field(
validator=validators.instance_of((AccessToken, type(None))), init=False, repr=False
)

Expand All @@ -658,7 +676,7 @@ def get_session(self) -> Session:

"""
session = super().get_session()
if "Authorization" in session.headers and self._token.is_expired:
if "Authorization" in session.headers and (not self._token or self._token.is_expired):
session.close()
session = Session()
if self._no_authorization_header(session):
Expand Down Expand Up @@ -709,7 +727,7 @@ class MTLSCredentials(Credentials):
"""path to the client key"""
cert: str = field(validator=validators.instance_of(str))
"""path to the client certificate"""
ca_cert: str = field(validator=validators.instance_of((str, type(None))), default=None)
ca_cert: str | None = field(validator=validators.instance_of((str, type(None))), default=None)
"""path to a certification authority certificate"""

def get_session(self):
Expand Down
37 changes: 37 additions & 0 deletions tests/unit/connector/test_http_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def create_credentials(tmp_path):
/.*/[A-Z]{{2}}/json$:
username: user
password: password
/auth-json-two-creds:
- username: user
password: password
- username: user2
password: password2

""")

return str(credential_file_path)
Expand All @@ -65,6 +71,7 @@ class TestHttpConnector(BaseInputTestCase):
"/auth-json-secret": "json",
"/auth-json-file": "json",
"/[A-Za-z0-9]*/[A-Z]{2}/json$": "json",
"/auth-json-two-creds": "json",
},
}

Expand Down Expand Up @@ -503,9 +510,11 @@ def test_endpoint_returns_200_on_correct_authorization_with_password_from_file(
data = {"message": "my log message"}
with mock.patch.dict("os.environ", mock_env):
new_connector = Factory.create({"test connector": self.CONFIG})
assert isinstance(new_connector, HttpInput)
new_connector.pipeline_index = 1
new_connector.setup()
headers = {"Authorization": _basic_auth_str("user", "file_password")}
assert new_connector.app
client = testing.TestClient(new_connector.app, headers=headers)
resp = client.post("/auth-json-file", body=json.dumps(data))
assert resp.status_code == 200
Expand Down Expand Up @@ -536,6 +545,34 @@ def test_endpoint_returns_200_on_correct_authorization_for_subpath(self, credent
resp = client.post("/auth-json-secret/AB/json", body=json.dumps(data))
assert resp.status_code == 200

def test_endpoint_returns_200_on_correct_authorization_for_subpath_and_second_credential(
self, credentials_file_path
):
mock_env = {ENV_NAME_LOGPREP_CREDENTIALS_FILE: credentials_file_path}
data = {"message": "my log message"}
with mock.patch.dict("os.environ", mock_env):
new_connector = Factory.create({"test connector": self.CONFIG})
new_connector.pipeline_index = 1
new_connector.setup()
headers = {"Authorization": _basic_auth_str("user2", "password2")}
client = testing.TestClient(new_connector.app, headers=headers)
resp = client.post("/auth-json-two-creds", body=json.dumps(data))
assert resp.status_code == 200

def test_endpoint_returns_401_on_wrong_authorization_with_second_credential(
self, credentials_file_path
):
mock_env = {ENV_NAME_LOGPREP_CREDENTIALS_FILE: credentials_file_path}
data = {"message": "my log message"}
with mock.patch.dict("os.environ", mock_env):
new_connector = Factory.create({"test connector": self.CONFIG})
new_connector.pipeline_index = 1
new_connector.setup()
headers = {"Authorization": _basic_auth_str("wrong", "credentials")}
client = testing.TestClient(new_connector.app, headers=headers)
resp = client.post("/auth-json-two-creds", body=json.dumps(data))
assert resp.status_code == 401

def test_two_connector_instances_share_the_same_queue(self):
new_connector = Factory.create({"test connector": self.CONFIG})
assert self.object.messages is new_connector.messages
Expand Down