Skip to content

Commit 4a9a64f

Browse files
committed
Fix relative imports for Python checks
1 parent 386440d commit 4a9a64f

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

infrahub_sdk/ctl/check.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import importlib
21
import logging
32
import sys
43
from asyncio import run as aiorun
54
from dataclasses import dataclass
65
from pathlib import Path
7-
from types import ModuleType
86
from typing import Optional, Type
97

108
import typer
@@ -18,6 +16,7 @@
1816
from ..ctl.exceptions import QueryNotFoundError
1917
from ..ctl.repository import get_repository_config
2018
from ..ctl.utils import catch_exception, execute_graphql_query
19+
from ..exceptions import ModuleImportError
2120
from ..schema import InfrahubCheckDefinitionConfig, InfrahubRepositoryConfig
2221

2322
app = typer.Typer()
@@ -27,12 +26,9 @@
2726
@dataclass
2827
class CheckModule:
2928
name: str
30-
module: ModuleType
29+
check_class: Type[InfrahubCheck]
3130
definition: InfrahubCheckDefinitionConfig
3231

33-
def get_check(self) -> Type[InfrahubCheck]:
34-
return getattr(self.module, self.definition.class_name)
35-
3632

3733
@app.callback()
3834
def callback() -> None:
@@ -67,11 +63,7 @@ def run(
6763

6864
check_definitions = repository_config.check_definitions
6965
if name:
70-
check_definitions = [check for check in repository_config.check_definitions if check.name == name] # pylint: disable=not-an-iterable
71-
if not check_definitions:
72-
console.print(f"[red]Unable to find requested transform: {name}")
73-
list_checks(repository_config=repository_config)
74-
return
66+
check_definitions = [repository_config.get_check_definition(name=name)]
7567

7668
check_modules = get_modules(check_definitions=check_definitions)
7769
aiorun(
@@ -99,7 +91,7 @@ async def run_check(
9991
output = "stdout" if format_json else None
10092
log = logging.getLogger("infrahub")
10193
passed = True
102-
check_class = check_module.get_check()
94+
check_class = check_module.check_class
10395
check = check_class(client=client, params=params, output=output, root_directory=path, branch=branch)
10496
param_log = f" - {params}" if params else ""
10597
try:
@@ -231,25 +223,19 @@ async def run_checks(
231223

232224

233225
def get_modules(check_definitions: list[InfrahubCheckDefinitionConfig]) -> list[CheckModule]:
234-
log = logging.getLogger("infrahub")
235226
modules = []
236227
for check_definition in check_definitions:
237-
directory_name = str(check_definition.file_path.parent)
238228
module_name = check_definition.file_path.stem
239229

240-
if directory_name not in sys.path:
241-
sys.path.append(directory_name)
230+
relative_path = str(check_definition.file_path.parent) if check_definition.file_path.parent != Path() else None
242231

243232
try:
244-
module = importlib.import_module(module_name)
245-
except ModuleNotFoundError:
246-
log.error(f"Unable to load {check_definition.file_path}")
247-
continue
248-
249-
if check_definition.class_name not in dir(module):
250-
log.error(f"{check_definition.class_name} class not found in {check_definition.file_path}")
251-
continue
252-
modules.append(CheckModule(name=module_name, module=module, definition=check_definition))
233+
check_class = check_definition.load_class(import_root=str(Path.cwd()), relative_path=relative_path)
234+
except ModuleImportError as exc:
235+
console.print(f"[red]{exc.message}")
236+
raise typer.Exit(1) from exc
237+
238+
modules.append(CheckModule(name=module_name, check_class=check_class, definition=check_definition))
253239

254240
return modules
255241

infrahub_sdk/schema.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing_extensions import TypeAlias
1212

1313
from ._importer import import_module
14+
from .checks import InfrahubCheck
1415
from .exceptions import (
1516
InvalidResponseError,
1617
ModuleImportError,
@@ -89,6 +90,19 @@ class InfrahubCheckDefinitionConfig(InfrahubRepositoryConfigElement):
8990
)
9091
class_name: str = Field(default="Check", description="The name of the check class to run.")
9192

93+
def load_class(self, import_root: Optional[str] = None, relative_path: Optional[str] = None) -> type[InfrahubCheck]:
94+
module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path)
95+
96+
if self.class_name not in dir(module):
97+
raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module")
98+
99+
check_class = getattr(module, self.class_name)
100+
101+
if not issubclass(check_class, InfrahubCheck):
102+
raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Check")
103+
104+
return check_class
105+
92106

93107
class InfrahubGeneratorDefinitionConfig(InfrahubRepositoryConfigElement):
94108
model_config = ConfigDict(extra="forbid")

0 commit comments

Comments
 (0)