diff --git a/infrahub_sdk/ctl/check.py b/infrahub_sdk/ctl/check.py index 26373560..a5164173 100644 --- a/infrahub_sdk/ctl/check.py +++ b/infrahub_sdk/ctl/check.py @@ -1,10 +1,8 @@ -import importlib import logging import sys from asyncio import run as aiorun from dataclasses import dataclass from pathlib import Path -from types import ModuleType from typing import Optional, Type import typer @@ -18,6 +16,7 @@ from ..ctl.exceptions import QueryNotFoundError from ..ctl.repository import get_repository_config from ..ctl.utils import catch_exception, execute_graphql_query +from ..exceptions import ModuleImportError from ..schema import InfrahubCheckDefinitionConfig, InfrahubRepositoryConfig app = typer.Typer() @@ -27,12 +26,9 @@ @dataclass class CheckModule: name: str - module: ModuleType + check_class: Type[InfrahubCheck] definition: InfrahubCheckDefinitionConfig - def get_check(self) -> Type[InfrahubCheck]: - return getattr(self.module, self.definition.class_name) - @app.callback() def callback() -> None: @@ -67,11 +63,7 @@ def run( check_definitions = repository_config.check_definitions if name: - check_definitions = [check for check in repository_config.check_definitions if check.name == name] # pylint: disable=not-an-iterable - if not check_definitions: - console.print(f"[red]Unable to find requested transform: {name}") - list_checks(repository_config=repository_config) - return + check_definitions = [repository_config.get_check_definition(name=name)] check_modules = get_modules(check_definitions=check_definitions) aiorun( @@ -99,7 +91,7 @@ async def run_check( output = "stdout" if format_json else None log = logging.getLogger("infrahub") passed = True - check_class = check_module.get_check() + check_class = check_module.check_class check = check_class(client=client, params=params, output=output, root_directory=path, branch=branch) param_log = f" - {params}" if params else "" try: @@ -231,25 +223,19 @@ async def run_checks( def get_modules(check_definitions: list[InfrahubCheckDefinitionConfig]) -> list[CheckModule]: - log = logging.getLogger("infrahub") modules = [] for check_definition in check_definitions: - directory_name = str(check_definition.file_path.parent) module_name = check_definition.file_path.stem - if directory_name not in sys.path: - sys.path.append(directory_name) + relative_path = str(check_definition.file_path.parent) if check_definition.file_path.parent != Path() else None try: - module = importlib.import_module(module_name) - except ModuleNotFoundError: - log.error(f"Unable to load {check_definition.file_path}") - continue - - if check_definition.class_name not in dir(module): - log.error(f"{check_definition.class_name} class not found in {check_definition.file_path}") - continue - modules.append(CheckModule(name=module_name, module=module, definition=check_definition)) + check_class = check_definition.load_class(import_root=str(Path.cwd()), relative_path=relative_path) + except ModuleImportError as exc: + console.print(f"[red]{exc.message}") + raise typer.Exit(1) from exc + + modules.append(CheckModule(name=module_name, check_class=check_class, definition=check_definition)) return modules diff --git a/infrahub_sdk/schema.py b/infrahub_sdk/schema.py index 5d3a6bbc..e692285e 100644 --- a/infrahub_sdk/schema.py +++ b/infrahub_sdk/schema.py @@ -11,6 +11,7 @@ from typing_extensions import TypeAlias from ._importer import import_module +from .checks import InfrahubCheck from .exceptions import ( InvalidResponseError, ModuleImportError, @@ -89,6 +90,19 @@ class InfrahubCheckDefinitionConfig(InfrahubRepositoryConfigElement): ) class_name: str = Field(default="Check", description="The name of the check class to run.") + def load_class(self, import_root: Optional[str] = None, relative_path: Optional[str] = None) -> type[InfrahubCheck]: + module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) + + if self.class_name not in dir(module): + raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") + + check_class = getattr(module, self.class_name) + + if not issubclass(check_class, InfrahubCheck): + raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Check") + + return check_class + class InfrahubGeneratorDefinitionConfig(InfrahubRepositoryConfigElement): model_config = ConfigDict(extra="forbid")