|
1 | | -import importlib |
2 | 1 | import logging |
3 | 2 | import sys |
4 | 3 | from asyncio import run as aiorun |
5 | 4 | from dataclasses import dataclass |
6 | 5 | from pathlib import Path |
7 | | -from types import ModuleType |
8 | 6 | from typing import Optional, Type |
9 | 7 |
|
10 | 8 | import typer |
|
18 | 16 | from ..ctl.exceptions import QueryNotFoundError |
19 | 17 | from ..ctl.repository import get_repository_config |
20 | 18 | from ..ctl.utils import catch_exception, execute_graphql_query |
| 19 | +from ..exceptions import ModuleImportError |
21 | 20 | from ..schema import InfrahubCheckDefinitionConfig, InfrahubRepositoryConfig |
22 | 21 |
|
23 | 22 | app = typer.Typer() |
|
27 | 26 | @dataclass |
28 | 27 | class CheckModule: |
29 | 28 | name: str |
30 | | - module: ModuleType |
| 29 | + check_class: Type[InfrahubCheck] |
31 | 30 | definition: InfrahubCheckDefinitionConfig |
32 | 31 |
|
33 | | - def get_check(self) -> Type[InfrahubCheck]: |
34 | | - return getattr(self.module, self.definition.class_name) |
35 | | - |
36 | 32 |
|
37 | 33 | @app.callback() |
38 | 34 | def callback() -> None: |
@@ -67,11 +63,7 @@ def run( |
67 | 63 |
|
68 | 64 | check_definitions = repository_config.check_definitions |
69 | 65 | if name: |
70 | | - check_definitions = [check for check in repository_config.check_definitions if check.name == name] # pylint: disable=not-an-iterable |
71 | | - if not check_definitions: |
72 | | - console.print(f"[red]Unable to find requested transform: {name}") |
73 | | - list_checks(repository_config=repository_config) |
74 | | - return |
| 66 | + check_definitions = [repository_config.get_check_definition(name=name)] |
75 | 67 |
|
76 | 68 | check_modules = get_modules(check_definitions=check_definitions) |
77 | 69 | aiorun( |
@@ -99,7 +91,7 @@ async def run_check( |
99 | 91 | output = "stdout" if format_json else None |
100 | 92 | log = logging.getLogger("infrahub") |
101 | 93 | passed = True |
102 | | - check_class = check_module.get_check() |
| 94 | + check_class = check_module.check_class |
103 | 95 | check = check_class(client=client, params=params, output=output, root_directory=path, branch=branch) |
104 | 96 | param_log = f" - {params}" if params else "" |
105 | 97 | try: |
@@ -231,25 +223,19 @@ async def run_checks( |
231 | 223 |
|
232 | 224 |
|
233 | 225 | def get_modules(check_definitions: list[InfrahubCheckDefinitionConfig]) -> list[CheckModule]: |
234 | | - log = logging.getLogger("infrahub") |
235 | 226 | modules = [] |
236 | 227 | for check_definition in check_definitions: |
237 | | - directory_name = str(check_definition.file_path.parent) |
238 | 228 | module_name = check_definition.file_path.stem |
239 | 229 |
|
240 | | - if directory_name not in sys.path: |
241 | | - sys.path.append(directory_name) |
| 230 | + relative_path = str(check_definition.file_path.parent) if check_definition.file_path.parent != Path() else None |
242 | 231 |
|
243 | 232 | try: |
244 | | - module = importlib.import_module(module_name) |
245 | | - except ModuleNotFoundError: |
246 | | - log.error(f"Unable to load {check_definition.file_path}") |
247 | | - continue |
248 | | - |
249 | | - if check_definition.class_name not in dir(module): |
250 | | - log.error(f"{check_definition.class_name} class not found in {check_definition.file_path}") |
251 | | - continue |
252 | | - modules.append(CheckModule(name=module_name, module=module, definition=check_definition)) |
| 233 | + check_class = check_definition.load_class(import_root=str(Path.cwd()), relative_path=relative_path) |
| 234 | + except ModuleImportError as exc: |
| 235 | + console.print(f"[red]{exc.message}") |
| 236 | + raise typer.Exit(1) from exc |
| 237 | + |
| 238 | + modules.append(CheckModule(name=module_name, check_class=check_class, definition=check_definition)) |
253 | 239 |
|
254 | 240 | return modules |
255 | 241 |
|
|
0 commit comments