diff --git a/changelog/+479a6128.removed.md b/changelog/+479a6128.removed.md new file mode 100644 index 00000000..5f932544 --- /dev/null +++ b/changelog/+479a6128.removed.md @@ -0,0 +1 @@ +Removed previously deprecated InfrahubTransform.init() method diff --git a/changelog/81.fixed.md b/changelog/81.fixed.md new file mode 100644 index 00000000..ff47bc90 --- /dev/null +++ b/changelog/81.fixed.md @@ -0,0 +1 @@ +CTL: Fix support for relative imports for transforms and generators diff --git a/infrahub_sdk/_importer.py b/infrahub_sdk/_importer.py index 8d01adc7..071388d6 100644 --- a/infrahub_sdk/_importer.py +++ b/infrahub_sdk/_importer.py @@ -2,20 +2,26 @@ import importlib import sys +from pathlib import Path from typing import TYPE_CHECKING, Optional from .exceptions import ModuleImportError if TYPE_CHECKING: - from pathlib import Path from types import ModuleType +module_mtime_cache: dict[str, float] = {} + def import_module( module_path: Path, import_root: Optional[str] = None, relative_path: Optional[str] = None ) -> ModuleType: import_root = import_root or str(module_path.parent) + file_on_disk = module_path + if import_root and relative_path: + file_on_disk = Path(import_root, relative_path, module_path.name) + if import_root not in sys.path: sys.path.append(import_root) @@ -25,7 +31,21 @@ def import_module( module_name = relative_path.replace("/", ".") + f".{module_name}" try: - module = importlib.import_module(module_name) + if module_name in sys.modules: + module = sys.modules[module_name] + current_mtime = file_on_disk.stat().st_mtime + + if module_name in module_mtime_cache: + last_mtime = module_mtime_cache[module_name] + if current_mtime == last_mtime: + return module + + module_mtime_cache[module_name] = current_mtime + module = importlib.reload(module) + else: + module = importlib.import_module(module_name) + module_mtime_cache[module_name] = file_on_disk.stat().st_mtime + except ModuleNotFoundError as exc: raise ModuleImportError(message=f"{exc!s} ({module_path})") from exc except SyntaxError as exc: diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 324ba6ab..baf5231b 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -36,14 +36,13 @@ parse_cli_vars, ) from ..ctl.validate import app as validate_app -from ..exceptions import GraphQLError, InfrahubTransformNotFoundError +from ..exceptions import GraphQLError, ModuleImportError from ..jinja2 import identify_faulty_jinja_code from ..schema import ( InfrahubRepositoryConfig, MainSchemaTypes, SchemaRoot, ) -from ..transforms import get_transform_class_instance from ..utils import get_branch, write_to_file from ..yaml import SchemaFile from .exporter import dump @@ -322,32 +321,22 @@ def transform( list_transforms(config=repository_config) return - # Load transform config - try: - matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable - if not matched: - raise ValueError(f"{transform_name} does not exist") - except ValueError as exc: - console.print(f"[red]Unable to find requested transform: {transform_name}") - list_transforms(config=repository_config) - raise typer.Exit(1) from exc - - transform_config = matched[0] + transform_config = repository_config.get_python_transform(name=transform_name) # Get client client = initialize_client() # Get python transform class instance + + relative_path = str(transform_config.file_path.parent) if transform_config.file_path.parent != Path() else None + try: - transform = get_transform_class_instance( - transform_config=transform_config, - branch=branch, - client=client, - ) - except InfrahubTransformNotFoundError as exc: - console.print(f"Unable to load {transform_name} from python_transforms") + transform_class = transform_config.load_class(import_root=str(Path.cwd()), relative_path=relative_path) + except ModuleImportError as exc: + console.print(f"[red]{exc.message}") raise typer.Exit(1) from exc + transform = transform_class(client=client, branch=branch) # Get data query_str = repository_config.get_query(name=transform.query).load_query() data = asyncio.run( diff --git a/infrahub_sdk/ctl/generator.py b/infrahub_sdk/ctl/generator.py index c46844b5..5d7b0ebd 100644 --- a/infrahub_sdk/ctl/generator.py +++ b/infrahub_sdk/ctl/generator.py @@ -1,12 +1,14 @@ from pathlib import Path from typing import Optional +import typer from rich.console import Console from ..ctl import config from ..ctl.client import initialize_client from ..ctl.repository import get_repository_config from ..ctl.utils import execute_graphql_query, parse_cli_vars +from ..exceptions import ModuleImportError from ..node import InfrahubNode from ..schema import InfrahubRepositoryConfig @@ -25,17 +27,18 @@ async def run( list_generators(repository_config=repository_config) return - matched = [generator for generator in repository_config.generator_definitions if generator.name == generator_name] # pylint: disable=not-an-iterable + generator_config = repository_config.get_generator_definition(name=generator_name) console = Console() - if not matched: - console.print(f"[red]Unable to find requested generator: {generator_name}") - list_generators(repository_config=repository_config) - return + relative_path = str(generator_config.file_path.parent) if generator_config.file_path.parent != Path() else None + + try: + generator_class = generator_config.load_class(import_root=str(Path.cwd()), relative_path=relative_path) + except ModuleImportError as exc: + console.print(f"[red]{exc.message}") + raise typer.Exit(1) from exc - generator_config = matched[0] - generator_class = generator_config.load_class() variables_dict = parse_cli_vars(variables) param_key = list(generator_config.parameters.keys()) diff --git a/infrahub_sdk/ctl/utils.py b/infrahub_sdk/ctl/utils.py index 43ae4d0e..9221ee7e 100644 --- a/infrahub_sdk/ctl/utils.py +++ b/infrahub_sdk/ctl/utils.py @@ -20,6 +20,7 @@ Error, GraphQLError, NodeNotFoundError, + ResourceNotDefinedError, SchemaNotFoundError, ServerNotReachableError, ServerNotResponsiveError, @@ -59,7 +60,7 @@ def handle_exception(exc: Exception, console: Console, exit_code: int) -> NoRetu if isinstance(exc, GraphQLError): print_graphql_errors(console=console, errors=exc.errors) raise typer.Exit(code=exit_code) - if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError)): + if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError, ResourceNotDefinedError)): console.print(f"[red]Error: {exc!s}") raise typer.Exit(code=exit_code) diff --git a/infrahub_sdk/exceptions.py b/infrahub_sdk/exceptions.py index 1aa92e58..fc823e80 100644 --- a/infrahub_sdk/exceptions.py +++ b/infrahub_sdk/exceptions.py @@ -87,6 +87,14 @@ def __str__(self) -> str: """ +class ResourceNotDefinedError(Error): + """Raised when trying to access a resource that hasn't been defined.""" + + def __init__(self, message: Optional[str] = None): + self.message = message or "The requested resource was not found" + super().__init__(self.message) + + class InfrahubCheckNotFoundError(Error): def __init__(self, name: str, message: Optional[str] = None): self.message = message or f"The requested InfrahubCheck '{name}' was not found." diff --git a/infrahub_sdk/schema.py b/infrahub_sdk/schema.py index 051d5f86..5d3a6bbc 100644 --- a/infrahub_sdk/schema.py +++ b/infrahub_sdk/schema.py @@ -11,9 +11,16 @@ from typing_extensions import TypeAlias from ._importer import import_module -from .exceptions import InvalidResponseError, ModuleImportError, SchemaNotFoundError, ValidationError +from .exceptions import ( + InvalidResponseError, + ModuleImportError, + ResourceNotDefinedError, + SchemaNotFoundError, + ValidationError, +) from .generator import InfrahubGenerator from .graphql import Mutation +from .transforms import InfrahubTransform from .utils import duplicates if TYPE_CHECKING: @@ -120,6 +127,21 @@ class InfrahubPythonTransformConfig(InfrahubRepositoryConfigElement): file_path: Path = Field(..., description="The file within the repository with the transform code.") class_name: str = Field(default="Transform", description="The name of the transform class to run.") + def load_class( + self, import_root: Optional[str] = None, relative_path: Optional[str] = None + ) -> type[InfrahubTransform]: + module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) + + if self.class_name not in dir(module): + raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") + + transform_class = getattr(module, self.class_name) + + if not issubclass(transform_class, InfrahubTransform): + raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Transform") + + return transform_class + class InfrahubRepositoryGraphQLConfig(InfrahubRepositoryConfigElement): model_config = ConfigDict(extra="forbid") @@ -189,7 +211,7 @@ def _get_resource( for item in getattr(self, RESOURCE_MAP[resource_type]): if getattr(item, resource_field) == resource_id: return item - raise KeyError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}") + raise ResourceNotDefinedError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}") def has_jinja2_transform(self, name: str) -> bool: return self._has_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig) diff --git a/infrahub_sdk/transforms.py b/infrahub_sdk/transforms.py index e9fc43c5..21c2fb73 100644 --- a/infrahub_sdk/transforms.py +++ b/infrahub_sdk/transforms.py @@ -3,18 +3,17 @@ import asyncio import importlib import os -import warnings from abc import abstractmethod from typing import TYPE_CHECKING, Any, Optional from git import Repo -from . import InfrahubClient -from .exceptions import InfrahubTransformNotFoundError +from .exceptions import InfrahubTransformNotFoundError, UninitializedError if TYPE_CHECKING: from pathlib import Path + from . import InfrahubClient from .schema import InfrahubPythonTransformConfig INFRAHUB_TRANSFORM_VARIABLE_TO_IMPORT = "INFRAHUB_TRANSFORMS" @@ -48,25 +47,10 @@ def __init__( @property def client(self) -> InfrahubClient: - if not self._client: - self._client = InfrahubClient(address=self.server_url) + if self._client: + return self._client - return self._client - - @classmethod - async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubTransform: - """Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically.""" - warnings.warn( - f"{cls.__class__.__name__}.init has been deprecated and will be removed in the version after Infrahub SDK 1.0.0", - DeprecationWarning, - stacklevel=1, - ) - if client: - kwargs["client"] = client - - item = cls(*args, **kwargs) - - return item + raise UninitializedError("The client has not been initialized") @property def branch_name(self) -> str: diff --git a/tests/unit/ctl/test_transform_app.py b/tests/unit/ctl/test_transform_app.py index d4fd0316..a9fead5e 100644 --- a/tests/unit/ctl/test_transform_app.py +++ b/tests/unit/ctl/test_transform_app.py @@ -62,7 +62,7 @@ def test_transform_not_exist_in_infrahub_yml(tags_transform_dir: str) -> None: transform_name = "not_existing_transform" with change_directory(tags_transform_dir): output = runner.invoke(app, ["transform", transform_name, "tag=red"]) - assert f"Unable to find requested transform: {transform_name}" in output.stdout + assert f"Unable to find '{transform_name}'" in strip_color(output.stdout) assert output.exit_code == 1 @staticmethod @@ -77,7 +77,7 @@ def test_transform_python_file_not_defined(tags_transform_dir: str) -> None: transform_name = "tags_transform" with change_directory(tags_transform_dir): output = runner.invoke(app, ["transform", transform_name, "tag=red"]) - assert f"Unable to load {transform_name} from python_transforms" in output.stdout + assert "No module named 'tags_transform' (tags_transform.py)" in strip_color(output.stdout) assert output.exit_code == 1 @staticmethod @@ -97,7 +97,7 @@ def test_transform_python_class_not_defined(tags_transform_dir: str) -> None: transform_name = "tags_transform" with change_directory(tags_transform_dir): output = runner.invoke(app, ["transform", transform_name, "tag=red"]) - assert f"Unable to load {transform_name} from python_transforms" in output.stdout + assert "The specified class TagsTransform was not found within the module" in output.stdout assert output.exit_code == 1 @staticmethod