Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions src/configuration/argparse_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from argparse import ArgumentParser, Namespace
from gettext import gettext as _
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, override

if TYPE_CHECKING:
Expand Down Expand Up @@ -38,26 +39,56 @@ def _get_help_string(self, action: argparse.Action) -> str | None:
if action.option_strings or action.nargs in defaulting_nargs:
# append default value
_help += _("\n(default: %(default)s)")
# append environment variable
_help += f"\n(environment variable: {action.envvar})"
# append environment variables
if action.envvar:
_help += f"\n(environment variable: {action.envvar})"
if action.file_envvar:
_help += f"\n(file environment variable: {action.file_envvar})"
# whitespace from each line
return "\n".join([m.lstrip() for m in _help.split("\n")])


class EnvDefault(argparse.Action):
"""Argparse action that allows setting a default from environment variable or file."""

def __init__(
self,
envvar: str,
required: bool = True,
try_file: bool = False,
default: str | None = None,
**kwargs: dict[str, Any],
) -> None:
self.envvar = envvar
if os.environ.get(envvar):
default = os.environ[envvar]
self.file_envvar = f"{envvar}_FILE" if try_file else None

envvar_value = os.environ.get(self.envvar, None)
envvar_file_value = (
os.environ.get(self.file_envvar, None) if self.file_envvar else None
)

if envvar_value is not None:
# enviroment value takes precedence
default = envvar_value
elif envvar_file_value:
# if the environment variable is not set, check for a file specified by the environment variable with the _FILE suffix
default_from_file = self._get_default_from_file(envvar_file_value)
if default_from_file:
default = default_from_file

if required and default:
# If the default is set from environment, it should not be required from command line
required = False
super().__init__(default=default, required=required, **kwargs)
super().__init__(required=required, default=default, **kwargs)

def _get_default_from_file(self, file_path: str) -> str | None:
"""Get the default value from the file specified by the environment variable."""
try:
with Path(file_path).open(encoding="utf-8") as f:
return f.read().strip()
except OSError as e:
msg = f"Error reading file {file_path}, specified by environment variable {self.file_envvar}"
raise argparse.ArgumentTypeError(msg) from e

@override
def __call__(
Expand Down
6 changes: 6 additions & 0 deletions src/configuration/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __add_mqtt_argument_group(
required=False,
action=EnvDefault,
envvar="MQTT_USER",
try_file=True,
type=str,
)
mqtt.add_argument(
Expand All @@ -216,6 +217,7 @@ def __add_mqtt_argument_group(
required=False,
action=EnvDefault,
envvar="MQTT_PASSWORD",
try_file=True,
type=str,
)
mqtt.add_argument(
Expand Down Expand Up @@ -289,6 +291,7 @@ def __add_saic_api_argument_group(
required=True,
action=EnvDefault,
envvar="SAIC_USER",
try_file=True,
type=str,
)
saic_api.add_argument(
Expand All @@ -299,6 +302,7 @@ def __add_saic_api_argument_group(
required=True,
action=EnvDefault,
envvar="SAIC_PASSWORD",
try_file=True,
type=str,
)
saic_api.add_argument(
Expand Down Expand Up @@ -464,6 +468,7 @@ def __add_abrp_argument_group(
required=False,
action=EnvDefault,
envvar="ABRP_API_KEY",
try_file=True,
type=str,
)
abrp_integration.add_argument(
Expand All @@ -475,6 +480,7 @@ def __add_abrp_argument_group(
required=False,
action=EnvDefault,
envvar="ABRP_USER_TOKEN",
try_file=True,
type=str,
)
abrp_integration.add_argument(
Expand Down
77 changes: 77 additions & 0 deletions tests/test_argparse_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

import argparse

import pytest

from configuration.argparse_extensions import EnvDefault


class DummyParser(argparse.ArgumentParser):
"""Dummy ArgumentParser for testing purposes."""

def __init__(self) -> None:
super().__init__(add_help=False)


@pytest.fixture(name="mock_envdefault_file")
def setup_fixture_mock_envdefault_file(monkeypatch: pytest.MonkeyPatch) -> None:
"""Mock the _get_default_from_file method to return a fixed value."""
monkeypatch.setattr(
"configuration.argparse_extensions.EnvDefault._get_default_from_file",
lambda *_, **__: "file_env_value",
)


@pytest.fixture(name="mock_env")
def setup_fixture_mock_env(monkeypatch: pytest.MonkeyPatch) -> None: # pylint: disable=unused-argument
"""Mock the environment variable to return a fixed value."""


# pylint: disable-next=unused-argument
def test_envdefault_envvar(
monkeypatch: pytest.MonkeyPatch,
mock_envdefault_file: None, # noqa: ARG001 #pylint: disable=unused-argument
) -> None:
"""Retrieves the value from an environment variable."""
monkeypatch.setenv("TEST_ENV", "env_value")
parser = DummyParser()
parser.add_argument("--test", action=EnvDefault, envvar="TEST_ENV", required=False)
args = parser.parse_args([])
assert args.test == "env_value"


def test_envdefault_file_envvar(
monkeypatch: pytest.MonkeyPatch,
mock_envdefault_file: None, # noqa: ARG001 pylint: disable=unused-argument
) -> None:
"""Retrieve the value from a file specified by an environment variable."""
monkeypatch.setenv("TEST_ENV_FILE", "file_env_value")
parser = DummyParser()
parser.add_argument(
"--test",
action=EnvDefault,
try_file=True,
envvar="TEST_ENV",
required=False,
)
args = parser.parse_args([])
assert args.test == "file_env_value"


def test_envdefault_priority(
monkeypatch: pytest.MonkeyPatch,
mock_envdefault_file: None, # noqa: ARG001 pylint: disable=unused-argument
) -> None:
"""Prioritize environment variable over file."""
monkeypatch.setenv("TEST_ENV", "env_value")
monkeypatch.setenv("TEST_ENV_FILE", "file_env_value")

parser = DummyParser()
parser.add_argument(
"--test",
action=EnvDefault,
try_file=True,
envvar="TEST_ENV",
required=False,
)