diff --git a/infrahub_sdk/ctl/check.py b/infrahub_sdk/ctl/check.py index 0626d884..92d852e6 100644 --- a/infrahub_sdk/ctl/check.py +++ b/infrahub_sdk/ctl/check.py @@ -11,10 +11,9 @@ from rich.console import Console from rich.logging import RichHandler -from ..ctl import config from ..ctl.client import initialize_client from ..ctl.exceptions import QueryNotFoundError -from ..ctl.repository import get_repository_config +from ..ctl.repository import find_repository_config_file, get_repository_config from ..ctl.utils import catch_exception, execute_graphql_query from ..exceptions import ModuleImportError @@ -59,7 +58,7 @@ def run( FORMAT = "%(message)s" logging.basicConfig(level=log_level, format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]) - repository_config = get_repository_config(Path(config.INFRAHUB_REPO_CONFIG_FILE)) + repository_config = get_repository_config(find_repository_config_file()) if list_available: list_checks(repository_config=repository_config) diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 605743fa..13c1e671 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -20,7 +20,6 @@ from .. import __version__ as sdk_version from ..async_typer import AsyncTyper -from ..ctl import config from ..ctl.branch import app as branch_app from ..ctl.check import run as run_check from ..ctl.client import initialize_client, initialize_client_sync @@ -30,7 +29,7 @@ from ..ctl.object import app as object_app from ..ctl.render import list_jinja2_transforms, print_template_errors from ..ctl.repository import app as repository_app -from ..ctl.repository import get_repository_config +from ..ctl.repository import find_repository_config_file, get_repository_config from ..ctl.schema import app as schema_app from ..ctl.transform import list_transforms from ..ctl.utils import ( @@ -260,7 +259,7 @@ async def render( """Render a local Jinja2 Transform for debugging purpose.""" variables_dict = parse_cli_vars(variables) - repository_config = get_repository_config(Path(config.INFRAHUB_REPO_CONFIG_FILE)) + repository_config = get_repository_config(find_repository_config_file()) if list_available or not transform_name: list_jinja2_transforms(config=repository_config) @@ -270,7 +269,7 @@ async def render( try: transform_config = repository_config.get_jinja2_transform(name=transform_name) except KeyError as exc: - console.print(f'[red]Unable to find "{transform_name}" in {config.INFRAHUB_REPO_CONFIG_FILE}') + console.print(f'[red]Unable to find "{transform_name}" in repository config file') list_jinja2_transforms(config=repository_config) raise typer.Exit(1) from exc @@ -310,7 +309,7 @@ def transform( """Render a local transform (TransformPython) for debugging purpose.""" variables_dict = parse_cli_vars(variables) - repository_config = get_repository_config(Path(config.INFRAHUB_REPO_CONFIG_FILE)) + repository_config = get_repository_config(find_repository_config_file()) if list_available or not transform_name: list_transforms(config=repository_config) @@ -469,7 +468,7 @@ def info( # noqa: PLR0915 pretty_model = Pretty(client.config.model_dump(), expand_all=True) layout["client_info"].update(Panel(pretty_model, title="Client Info")) - # Infrahub information planel + # Infrahub information panel infrahub_info = Table(show_header=False, box=None) if info["user_info"]: infrahub_info.add_row("User:", info["user_info"]["AccountProfile"]["display_label"]) diff --git a/infrahub_sdk/ctl/config.py b/infrahub_sdk/ctl/config.py index 9d3b6488..37d0571b 100644 --- a/infrahub_sdk/ctl/config.py +++ b/infrahub_sdk/ctl/config.py @@ -12,6 +12,7 @@ DEFAULT_CONFIG_FILE = "infrahubctl.toml" ENVVAR_CONFIG_FILE = "INFRAHUBCTL_CONFIG" INFRAHUB_REPO_CONFIG_FILE = ".infrahub.yml" +INFRAHUB_REPO_CONFIG_FILE_ALT = ".infrahub.yaml" class Settings(BaseSettings): @@ -69,7 +70,7 @@ def load(self, config_file: str | Path = "infrahubctl.toml", config_data: dict | def load_and_exit(self, config_file: str | Path = "infrahubctl.toml", config_data: dict | None = None) -> None: """Calls load, but wraps it in a try except block. - This is done to handle a ValidationErorr which is raised when settings are specified but invalid. + This is done to handle a ValidationError which is raised when settings are specified but invalid. In such cases, a message is printed to the screen indicating the settings which don't pass validation. Args: diff --git a/infrahub_sdk/ctl/generator.py b/infrahub_sdk/ctl/generator.py index c75b5acb..c0f60c52 100644 --- a/infrahub_sdk/ctl/generator.py +++ b/infrahub_sdk/ctl/generator.py @@ -6,9 +6,8 @@ import typer from rich.console import Console -from ..ctl import config from ..ctl.client import initialize_client -from ..ctl.repository import get_repository_config +from ..ctl.repository import find_repository_config_file, get_repository_config from ..ctl.utils import execute_graphql_query, init_logging, parse_cli_vars from ..exceptions import ModuleImportError from ..node import InfrahubNode @@ -26,7 +25,7 @@ async def run( variables: Optional[list[str]] = None, ) -> None: init_logging(debug=debug) - repository_config = get_repository_config(Path(config.INFRAHUB_REPO_CONFIG_FILE)) + repository_config = get_repository_config(find_repository_config_file()) if list_available or not generator_name: list_generators(repository_config=repository_config) diff --git a/infrahub_sdk/ctl/repository.py b/infrahub_sdk/ctl/repository.py index d23f8484..5c9423d1 100644 --- a/infrahub_sdk/ctl/repository.py +++ b/infrahub_sdk/ctl/repository.py @@ -24,11 +24,49 @@ console = Console() +def find_repository_config_file(base_path: Path | None = None) -> Path: + """Find the repository config file, checking for both .yml and .yaml extensions. + + Args: + base_path: Base directory to search in. If None, uses current directory. + + Returns: + Path to the config file. + + Raises: + FileNotFoundError: If neither .infrahub.yml nor .infrahub.yaml exists. + """ + if base_path is None: + base_path = Path() + + yml_path = base_path / ".infrahub.yml" + yaml_path = base_path / ".infrahub.yaml" + + # Prefer .yml if both exist + if yml_path.exists(): + return yml_path + if yaml_path.exists(): + return yaml_path + # For backward compatibility, return .yml path for error messages + return yml_path + + def get_repository_config(repo_config_file: Path) -> InfrahubRepositoryConfig: + # If the file doesn't exist, try to find it with alternate extension + if not repo_config_file.exists(): + if repo_config_file.name == ".infrahub.yml": + alt_path = repo_config_file.parent / ".infrahub.yaml" + if alt_path.exists(): + repo_config_file = alt_path + elif repo_config_file.name == ".infrahub.yaml": + alt_path = repo_config_file.parent / ".infrahub.yml" + if alt_path.exists(): + repo_config_file = alt_path + try: config_file_data = load_repository_config_file(repo_config_file) except FileNotFoundError as exc: - console.print(f"[red]File not found {exc}") + console.print(f"[red]File not found {exc} (also checked for .infrahub.yml and .infrahub.yaml)") raise typer.Exit(1) from exc except FileNotValidError as exc: console.print(f"[red]{exc.message}") diff --git a/infrahub_sdk/pytest_plugin/plugin.py b/infrahub_sdk/pytest_plugin/plugin.py index 871ba45b..64c2080b 100644 --- a/infrahub_sdk/pytest_plugin/plugin.py +++ b/infrahub_sdk/pytest_plugin/plugin.py @@ -9,7 +9,7 @@ from .. import InfrahubClientSync from ..utils import is_valid_url from .loader import InfrahubYamlFile -from .utils import load_repository_config +from .utils import find_repository_config_file, load_repository_config def pytest_addoption(parser: Parser) -> None: @@ -18,9 +18,9 @@ def pytest_addoption(parser: Parser) -> None: "--infrahub-repo-config", action="store", dest="infrahub_repo_config", - default=".infrahub.yml", + default=None, metavar="INFRAHUB_REPO_CONFIG_FILE", - help="Infrahub configuration file for the repository (default: %(default)s)", + help="Infrahub configuration file for the repository (.infrahub.yml or .infrahub.yaml)", ) group.addoption( "--infrahub-address", @@ -63,7 +63,10 @@ def pytest_addoption(parser: Parser) -> None: def pytest_sessionstart(session: Session) -> None: - session.infrahub_config_path = Path(session.config.option.infrahub_repo_config) # type: ignore[attr-defined] + if session.config.option.infrahub_repo_config: + session.infrahub_config_path = Path(session.config.option.infrahub_repo_config) # type: ignore[attr-defined] + else: + session.infrahub_config_path = find_repository_config_file() # type: ignore[attr-defined] if session.infrahub_config_path.is_file(): # type: ignore[attr-defined] session.infrahub_repo_config = load_repository_config(repo_config_file=session.infrahub_config_path) # type: ignore[attr-defined] diff --git a/infrahub_sdk/pytest_plugin/utils.py b/infrahub_sdk/pytest_plugin/utils.py index 2875c23d..b82a2e34 100644 --- a/infrahub_sdk/pytest_plugin/utils.py +++ b/infrahub_sdk/pytest_plugin/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path import yaml @@ -6,7 +8,45 @@ from .exceptions import FileNotValidError +def find_repository_config_file(base_path: Path | None = None) -> Path: + """Find the repository config file, checking for both .yml and .yaml extensions. + + Args: + base_path: Base directory to search in. If None, uses current directory. + + Returns: + Path to the config file. + + Raises: + FileNotFoundError: If neither .infrahub.yml nor .infrahub.yaml exists. + """ + if base_path is None: + base_path = Path() + + yml_path = base_path / ".infrahub.yml" + yaml_path = base_path / ".infrahub.yaml" + + # Prefer .yml if both exist + if yml_path.exists(): + return yml_path + if yaml_path.exists(): + return yaml_path + # For backward compatibility, return .yml path for error messages + return yml_path + + def load_repository_config(repo_config_file: Path) -> InfrahubRepositoryConfig: + # If the file doesn't exist, try to find it with alternate extension + if not repo_config_file.exists(): + if repo_config_file.name == ".infrahub.yml": + alt_path = repo_config_file.parent / ".infrahub.yaml" + if alt_path.exists(): + repo_config_file = alt_path + elif repo_config_file.name == ".infrahub.yaml": + alt_path = repo_config_file.parent / ".infrahub.yml" + if alt_path.exists(): + repo_config_file = alt_path + if not repo_config_file.is_file(): raise FileNotFoundError(repo_config_file)