diff --git a/changelog/+479a6128.removed.md b/changelog/+479a6128.removed.md new file mode 100644 index 00000000..5f932544 --- /dev/null +++ b/changelog/+479a6128.removed.md @@ -0,0 +1 @@ +Removed previously deprecated InfrahubTransform.init() method diff --git a/changelog/+89d1d0b7.deprecated.md b/changelog/+89d1d0b7.deprecated.md new file mode 100644 index 00000000..6660624a --- /dev/null +++ b/changelog/+89d1d0b7.deprecated.md @@ -0,0 +1 @@ +Marked InfrahubCheck.init() as deprecated and scheduled to be removed in Infrahub SDK 2.0.0 diff --git a/changelog/27.fixed.md b/changelog/27.fixed.md new file mode 100644 index 00000000..b62a31ab --- /dev/null +++ b/changelog/27.fixed.md @@ -0,0 +1 @@ +Fix generated GraphQL query when having a relationship to a pool node \ No newline at end of file diff --git a/changelog/81.fixed.md b/changelog/81.fixed.md new file mode 100644 index 00000000..ff47bc90 --- /dev/null +++ b/changelog/81.fixed.md @@ -0,0 +1 @@ +CTL: Fix support for relative imports for transforms and generators diff --git a/infrahub_sdk/_importer.py b/infrahub_sdk/_importer.py index 8d01adc7..071388d6 100644 --- a/infrahub_sdk/_importer.py +++ b/infrahub_sdk/_importer.py @@ -2,20 +2,26 @@ import importlib import sys +from pathlib import Path from typing import TYPE_CHECKING, Optional from .exceptions import ModuleImportError if TYPE_CHECKING: - from pathlib import Path from types import ModuleType +module_mtime_cache: dict[str, float] = {} + def import_module( module_path: Path, import_root: Optional[str] = None, relative_path: Optional[str] = None ) -> ModuleType: import_root = import_root or str(module_path.parent) + file_on_disk = module_path + if import_root and relative_path: + file_on_disk = Path(import_root, relative_path, module_path.name) + if import_root not in sys.path: sys.path.append(import_root) @@ -25,7 +31,21 @@ def import_module( module_name = relative_path.replace("/", ".") + f".{module_name}" try: - module = importlib.import_module(module_name) + if module_name in sys.modules: + module = sys.modules[module_name] + current_mtime = file_on_disk.stat().st_mtime + + if module_name in module_mtime_cache: + last_mtime = module_mtime_cache[module_name] + if current_mtime == last_mtime: + return module + + module_mtime_cache[module_name] = current_mtime + module = importlib.reload(module) + else: + module = importlib.import_module(module_name) + module_mtime_cache[module_name] = file_on_disk.stat().st_mtime + except ModuleNotFoundError as exc: raise ModuleImportError(message=f"{exc!s} ({module_path})") from exc except SyntaxError as exc: diff --git a/infrahub_sdk/checks.py b/infrahub_sdk/checks.py index a873893c..85f9ca1e 100644 --- a/infrahub_sdk/checks.py +++ b/infrahub_sdk/checks.py @@ -3,6 +3,7 @@ import asyncio import importlib import os +import warnings from abc import abstractmethod from typing import TYPE_CHECKING, Any, Optional @@ -10,16 +11,18 @@ from git.repo import Repo from pydantic import BaseModel, Field -from . import InfrahubClient from .exceptions import InfrahubCheckNotFoundError, UninitializedError if TYPE_CHECKING: from pathlib import Path + from . import InfrahubClient from .schema import InfrahubCheckDefinitionConfig INFRAHUB_CHECK_VARIABLE_TO_IMPORT = "INFRAHUB_CHECKS" +_client_class = "InfrahubClient" + class InfrahubCheckInitializer(BaseModel): """Information about the originator of the check.""" @@ -81,11 +84,17 @@ def client(self, value: InfrahubClient) -> None: @classmethod async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubCheck: """Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically.""" - - instance = cls(*args, **kwargs) - instance.client = client or InfrahubClient() - - return instance + warnings.warn( + "InfrahubCheck.init has been deprecated and will be removed in the version in Infrahub SDK 2.0.0", + DeprecationWarning, + stacklevel=1, + ) + if not client: + client_module = importlib.import_module("infrahub_sdk.client") + client_class = getattr(client_module, _client_class) + client = client_class() + kwargs["client"] = client + return cls(*args, **kwargs) @property def errors(self) -> list[dict[str, Any]]: diff --git a/infrahub_sdk/ctl/check.py b/infrahub_sdk/ctl/check.py index 99c547fe..a5164173 100644 --- a/infrahub_sdk/ctl/check.py +++ b/infrahub_sdk/ctl/check.py @@ -1,11 +1,9 @@ -import importlib import logging import sys from asyncio import run as aiorun from dataclasses import dataclass from pathlib import Path -from types import ModuleType -from typing import Optional +from typing import Optional, Type import typer from rich.console import Console @@ -18,6 +16,7 @@ from ..ctl.exceptions import QueryNotFoundError from ..ctl.repository import get_repository_config from ..ctl.utils import catch_exception, execute_graphql_query +from ..exceptions import ModuleImportError from ..schema import InfrahubCheckDefinitionConfig, InfrahubRepositoryConfig app = typer.Typer() @@ -27,12 +26,9 @@ @dataclass class CheckModule: name: str - module: ModuleType + check_class: Type[InfrahubCheck] definition: InfrahubCheckDefinitionConfig - def get_check(self) -> InfrahubCheck: - return getattr(self.module, self.definition.class_name) - @app.callback() def callback() -> None: @@ -67,11 +63,7 @@ def run( check_definitions = repository_config.check_definitions if name: - check_definitions = [check for check in repository_config.check_definitions if check.name == name] # pylint: disable=not-an-iterable - if not check_definitions: - console.print(f"[red]Unable to find requested transform: {name}") - list_checks(repository_config=repository_config) - return + check_definitions = [repository_config.get_check_definition(name=name)] check_modules = get_modules(check_definitions=check_definitions) aiorun( @@ -99,8 +91,8 @@ async def run_check( output = "stdout" if format_json else None log = logging.getLogger("infrahub") passed = True - check_class = check_module.get_check() - check = await check_class.init(client=client, params=params, output=output, root_directory=path, branch=branch) + check_class = check_module.check_class + check = check_class(client=client, params=params, output=output, root_directory=path, branch=branch) param_log = f" - {params}" if params else "" try: data = execute_graphql_query( @@ -231,25 +223,19 @@ async def run_checks( def get_modules(check_definitions: list[InfrahubCheckDefinitionConfig]) -> list[CheckModule]: - log = logging.getLogger("infrahub") modules = [] for check_definition in check_definitions: - directory_name = str(check_definition.file_path.parent) module_name = check_definition.file_path.stem - if directory_name not in sys.path: - sys.path.append(directory_name) + relative_path = str(check_definition.file_path.parent) if check_definition.file_path.parent != Path() else None try: - module = importlib.import_module(module_name) - except ModuleNotFoundError: - log.error(f"Unable to load {check_definition.file_path}") - continue - - if check_definition.class_name not in dir(module): - log.error(f"{check_definition.class_name} class not found in {check_definition.file_path}") - continue - modules.append(CheckModule(name=module_name, module=module, definition=check_definition)) + check_class = check_definition.load_class(import_root=str(Path.cwd()), relative_path=relative_path) + except ModuleImportError as exc: + console.print(f"[red]{exc.message}") + raise typer.Exit(1) from exc + + modules.append(CheckModule(name=module_name, check_class=check_class, definition=check_definition)) return modules diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 324ba6ab..baf5231b 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -36,14 +36,13 @@ parse_cli_vars, ) from ..ctl.validate import app as validate_app -from ..exceptions import GraphQLError, InfrahubTransformNotFoundError +from ..exceptions import GraphQLError, ModuleImportError from ..jinja2 import identify_faulty_jinja_code from ..schema import ( InfrahubRepositoryConfig, MainSchemaTypes, SchemaRoot, ) -from ..transforms import get_transform_class_instance from ..utils import get_branch, write_to_file from ..yaml import SchemaFile from .exporter import dump @@ -322,32 +321,22 @@ def transform( list_transforms(config=repository_config) return - # Load transform config - try: - matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable - if not matched: - raise ValueError(f"{transform_name} does not exist") - except ValueError as exc: - console.print(f"[red]Unable to find requested transform: {transform_name}") - list_transforms(config=repository_config) - raise typer.Exit(1) from exc - - transform_config = matched[0] + transform_config = repository_config.get_python_transform(name=transform_name) # Get client client = initialize_client() # Get python transform class instance + + relative_path = str(transform_config.file_path.parent) if transform_config.file_path.parent != Path() else None + try: - transform = get_transform_class_instance( - transform_config=transform_config, - branch=branch, - client=client, - ) - except InfrahubTransformNotFoundError as exc: - console.print(f"Unable to load {transform_name} from python_transforms") + transform_class = transform_config.load_class(import_root=str(Path.cwd()), relative_path=relative_path) + except ModuleImportError as exc: + console.print(f"[red]{exc.message}") raise typer.Exit(1) from exc + transform = transform_class(client=client, branch=branch) # Get data query_str = repository_config.get_query(name=transform.query).load_query() data = asyncio.run( diff --git a/infrahub_sdk/ctl/generator.py b/infrahub_sdk/ctl/generator.py index 26fd6db0..414ab879 100644 --- a/infrahub_sdk/ctl/generator.py +++ b/infrahub_sdk/ctl/generator.py @@ -1,12 +1,14 @@ from pathlib import Path from typing import Optional +import typer from rich.console import Console from ..ctl import config from ..ctl.client import initialize_client from ..ctl.repository import get_repository_config from ..ctl.utils import execute_graphql_query, parse_cli_vars +from ..exceptions import ModuleImportError from ..node import InfrahubNode from ..schema import InfrahubRepositoryConfig @@ -18,24 +20,25 @@ async def run( list_available: bool, branch: Optional[str] = None, variables: Optional[list[str]] = None, -): # pylint: disable=unused-argument +) -> None: # pylint: disable=unused-argument repository_config = get_repository_config(Path(config.INFRAHUB_REPO_CONFIG_FILE)) if list_available: list_generators(repository_config=repository_config) return - matched = [generator for generator in repository_config.generator_definitions if generator.name == generator_name] # pylint: disable=not-an-iterable + generator_config = repository_config.get_generator_definition(name=generator_name) console = Console() - if not matched: - console.print(f"[red]Unable to find requested generator: {generator_name}") - list_generators(repository_config=repository_config) - return + relative_path = str(generator_config.file_path.parent) if generator_config.file_path.parent != Path() else None + + try: + generator_class = generator_config.load_class(import_root=str(Path.cwd()), relative_path=relative_path) + except ModuleImportError as exc: + console.print(f"[red]{exc.message}") + raise typer.Exit(1) from exc - generator_config = matched[0] - generator_class = generator_config.load_class() variables_dict = parse_cli_vars(variables) param_key = list(generator_config.parameters.keys()) @@ -69,6 +72,13 @@ async def run( kind="CoreGroup", branch=branch, include=["members"], name__value=generator_config.targets ) await targets.members.fetch() + + if not targets.members.peers: + console.print( + f"[red]No members found within '{generator_config.targets}', not running generator '{generator_name}'" + ) + return + for member in targets.members.peers: check_parameter = {} if identifier: diff --git a/infrahub_sdk/ctl/utils.py b/infrahub_sdk/ctl/utils.py index 43ae4d0e..9221ee7e 100644 --- a/infrahub_sdk/ctl/utils.py +++ b/infrahub_sdk/ctl/utils.py @@ -20,6 +20,7 @@ Error, GraphQLError, NodeNotFoundError, + ResourceNotDefinedError, SchemaNotFoundError, ServerNotReachableError, ServerNotResponsiveError, @@ -59,7 +60,7 @@ def handle_exception(exc: Exception, console: Console, exit_code: int) -> NoRetu if isinstance(exc, GraphQLError): print_graphql_errors(console=console, errors=exc.errors) raise typer.Exit(code=exit_code) - if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError)): + if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError, ResourceNotDefinedError)): console.print(f"[red]Error: {exc!s}") raise typer.Exit(code=exit_code) diff --git a/infrahub_sdk/exceptions.py b/infrahub_sdk/exceptions.py index 1aa92e58..fc823e80 100644 --- a/infrahub_sdk/exceptions.py +++ b/infrahub_sdk/exceptions.py @@ -87,6 +87,14 @@ def __str__(self) -> str: """ +class ResourceNotDefinedError(Error): + """Raised when trying to access a resource that hasn't been defined.""" + + def __init__(self, message: Optional[str] = None): + self.message = message or "The requested resource was not found" + super().__init__(self.message) + + class InfrahubCheckNotFoundError(Error): def __init__(self, name: str, message: Optional[str] = None): self.message = message or f"The requested InfrahubCheck '{name}' was not found." diff --git a/infrahub_sdk/groups.py b/infrahub_sdk/groups.py new file mode 100644 index 00000000..f3fa341b --- /dev/null +++ b/infrahub_sdk/groups.py @@ -0,0 +1,28 @@ +from typing import List + +from infrahub_sdk import InfrahubClient +from infrahub_sdk.node import InfrahubNode + + +async def group_add_subscriber( + client: InfrahubClient, group: InfrahubNode, subscribers: List[str], branch: str +) -> dict: + subscribers_str = ["{ id: " + f'"{subscriber}"' + " }" for subscriber in subscribers] + query = """ + mutation { + RelationshipAdd( + data: { + id: "%s", + name: "subscribers", + nodes: [ %s ] + } + ) { + ok + } + } + """ % ( + group.id, + ", ".join(subscribers_str), + ) + + return await client.execute_graphql(query=query, branch_name=branch, tracker="mutation-relationshipadd") diff --git a/infrahub_sdk/node.py b/infrahub_sdk/node.py index 38fb6742..bd495811 100644 --- a/infrahub_sdk/node.py +++ b/infrahub_sdk/node.py @@ -241,10 +241,10 @@ def typename(self) -> Optional[str]: return self._peer.typename return self._typename - def _generate_input_data(self) -> dict[str, Any]: + def _generate_input_data(self, allocate_from_pool: bool = False) -> dict[str, Any]: data: dict[str, Any] = {} - if self.is_resource_pool: + if self.is_resource_pool and allocate_from_pool: return {"from_pool": {"id": self.id}} if self.id is not None: @@ -424,8 +424,8 @@ def peer_hfids_str(self) -> list[str]: def has_update(self) -> bool: return self._has_update - def _generate_input_data(self) -> list[dict]: - return [peer._generate_input_data() for peer in self.peers] + def _generate_input_data(self, allocate_from_pool: bool = False) -> list[dict]: + return [peer._generate_input_data(allocate_from_pool=allocate_from_pool) for peer in self.peers] def _generate_mutation_query(self) -> dict[str, Any]: # Does nothing for now @@ -818,6 +818,7 @@ def _generate_input_data(self, exclude_unmodified: bool = False, exclude_hfid: b data[item_name] = attr_data for item_name in self._relationships: + allocate_from_pool = False rel_schema = self._schema.get_relationship(name=item_name) if not rel_schema or rel_schema.read_only: continue @@ -836,7 +837,12 @@ def _generate_input_data(self, exclude_unmodified: bool = False, exclude_hfid: b if rel is None or not rel.initialized: continue - rel_data = rel._generate_input_data() + if isinstance(rel, (RelatedNode, RelatedNodeSync)) and rel.is_resource_pool: + # If the relatiionship is a resource pool and the expected schema is different from the one of the pool, this means we expect to get + # a resource from the pool itself + allocate_from_pool = rel_schema.peer != rel.peer._schema.kind + + rel_data = rel._generate_input_data(allocate_from_pool=allocate_from_pool) if rel_data and isinstance(rel_data, dict): if variable_values := rel_data.get("data"): @@ -1426,7 +1432,7 @@ async def get_pool_allocated_resources(self, resource: InfrahubNode) -> list[Inf list[InfrahubNode]: The allocated nodes. """ if not self.is_resource_pool(): - raise ValueError("Allocate resources can only be fetched from resource pool nodes.") + raise ValueError("Allocated resources can only be fetched from resource pool nodes.") graphql_query_name = "InfrahubResourcePoolAllocated" node_ids_per_kind: dict[str, list[str]] = {} diff --git a/infrahub_sdk/schema.py b/infrahub_sdk/schema.py index 051d5f86..d8711421 100644 --- a/infrahub_sdk/schema.py +++ b/infrahub_sdk/schema.py @@ -11,9 +11,17 @@ from typing_extensions import TypeAlias from ._importer import import_module -from .exceptions import InvalidResponseError, ModuleImportError, SchemaNotFoundError, ValidationError +from .checks import InfrahubCheck +from .exceptions import ( + InvalidResponseError, + ModuleImportError, + ResourceNotDefinedError, + SchemaNotFoundError, + ValidationError, +) from .generator import InfrahubGenerator from .graphql import Mutation +from .transforms import InfrahubTransform from .utils import duplicates if TYPE_CHECKING: @@ -82,6 +90,19 @@ class InfrahubCheckDefinitionConfig(InfrahubRepositoryConfigElement): ) class_name: str = Field(default="Check", description="The name of the check class to run.") + def load_class(self, import_root: Optional[str] = None, relative_path: Optional[str] = None) -> type[InfrahubCheck]: + module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) + + if self.class_name not in dir(module): + raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") + + check_class = getattr(module, self.class_name) + + if not issubclass(check_class, InfrahubCheck): + raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Check") + + return check_class + class InfrahubGeneratorDefinitionConfig(InfrahubRepositoryConfigElement): model_config = ConfigDict(extra="forbid") @@ -120,6 +141,21 @@ class InfrahubPythonTransformConfig(InfrahubRepositoryConfigElement): file_path: Path = Field(..., description="The file within the repository with the transform code.") class_name: str = Field(default="Transform", description="The name of the transform class to run.") + def load_class( + self, import_root: Optional[str] = None, relative_path: Optional[str] = None + ) -> type[InfrahubTransform]: + module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) + + if self.class_name not in dir(module): + raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") + + transform_class = getattr(module, self.class_name) + + if not issubclass(transform_class, InfrahubTransform): + raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Transform") + + return transform_class + class InfrahubRepositoryGraphQLConfig(InfrahubRepositoryConfigElement): model_config = ConfigDict(extra="forbid") @@ -189,7 +225,7 @@ def _get_resource( for item in getattr(self, RESOURCE_MAP[resource_type]): if getattr(item, resource_field) == resource_id: return item - raise KeyError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}") + raise ResourceNotDefinedError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}") def has_jinja2_transform(self, name: str) -> bool: return self._has_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig) @@ -291,6 +327,7 @@ class AttributeSchema(BaseModel): max_length: Optional[int] = None min_length: Optional[int] = None regex: Optional[str] = None + order_weight: Optional[int] = None class RelationshipSchema(BaseModel): @@ -308,6 +345,7 @@ class RelationshipSchema(BaseModel): optional: bool = True read_only: bool = False filters: list[FilterSchema] = Field(default_factory=list) + order_weight: Optional[int] = None class BaseNodeSchema(BaseModel): diff --git a/infrahub_sdk/store.py b/infrahub_sdk/store.py index 4bf5e0cb..fe5ba48c 100644 --- a/infrahub_sdk/store.py +++ b/infrahub_sdk/store.py @@ -58,7 +58,7 @@ def _get(self, key: str, kind: Optional[Union[str, type[SchemaType]]] = None, ra if kind_name and kind_name in self._store and key in self._store[kind_name]: # type: ignore[attr-defined] return self._store[kind_name][key] # type: ignore[attr-defined] - for _, item in self._store.items(): # type: ignore[attr-defined] + for item in self._store.values(): # type: ignore[attr-defined] if key in item: return item[key] diff --git a/infrahub_sdk/transforms.py b/infrahub_sdk/transforms.py index e9fc43c5..21c2fb73 100644 --- a/infrahub_sdk/transforms.py +++ b/infrahub_sdk/transforms.py @@ -3,18 +3,17 @@ import asyncio import importlib import os -import warnings from abc import abstractmethod from typing import TYPE_CHECKING, Any, Optional from git import Repo -from . import InfrahubClient -from .exceptions import InfrahubTransformNotFoundError +from .exceptions import InfrahubTransformNotFoundError, UninitializedError if TYPE_CHECKING: from pathlib import Path + from . import InfrahubClient from .schema import InfrahubPythonTransformConfig INFRAHUB_TRANSFORM_VARIABLE_TO_IMPORT = "INFRAHUB_TRANSFORMS" @@ -48,25 +47,10 @@ def __init__( @property def client(self) -> InfrahubClient: - if not self._client: - self._client = InfrahubClient(address=self.server_url) + if self._client: + return self._client - return self._client - - @classmethod - async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubTransform: - """Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically.""" - warnings.warn( - f"{cls.__class__.__name__}.init has been deprecated and will be removed in the version after Infrahub SDK 1.0.0", - DeprecationWarning, - stacklevel=1, - ) - if client: - kwargs["client"] = client - - item = cls(*args, **kwargs) - - return item + raise UninitializedError("The client has not been initialized") @property def branch_name(self) -> str: diff --git a/pyproject.toml b/pyproject.toml index 46aad3ab..846f4645 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "infrahub-sdk" -version = "1.0.0" +version = "1.0.1" description = "Python Client to interact with Infrahub" authors = ["OpsMill "] readme = "README.md" @@ -153,19 +153,33 @@ disallow_untyped_defs = true [[tool.mypy.overrides]] module = "infrahub_sdk.ctl.check" -ignore_errors = true +disable_error_code = [ + "call-overload" +] [[tool.mypy.overrides]] module = "infrahub_sdk.ctl.generator" -ignore_errors = true +disable_error_code = [ + "attr-defined", +] [[tool.mypy.overrides]] module = "infrahub_sdk.ctl.schema" -ignore_errors = true +disable_error_code = [ + "arg-type", + "attr-defined", + "misc", + "union-attr", +] [[tool.mypy.overrides]] module = "infrahub_sdk.utils" -ignore_errors = true +disable_error_code = [ + "arg-type", + "attr-defined", + "return-value", + "union-attr", +] [tool.ruff] line-length = 120 @@ -234,7 +248,6 @@ ignore = [ "FURB177", # Prefer `Path.cwd()` over `Path().resolve()` for current-directory lookups "N802", # Function name should be lowercase "N806", # Variable in function should be lowercase - "PERF102", # When using only the values of a dict use the `values()` method "PERF203", # `try`-`except` within a loop incurs performance overhead "PERF401", # Use a list comprehension to create a transformed list "PLC0206", # Extracting value from dictionary without calling `.items()` diff --git a/tests/integration/utils.py b/tests/helpers/utils.py similarity index 100% rename from tests/integration/utils.py rename to tests/helpers/utils.py diff --git a/tests/integration/test_infrahubctl.py b/tests/unit/ctl/test_transform_app.py similarity index 80% rename from tests/integration/test_infrahubctl.py rename to tests/unit/ctl/test_transform_app.py index a63514f1..a9fead5e 100644 --- a/tests/integration/test_infrahubctl.py +++ b/tests/unit/ctl/test_transform_app.py @@ -3,25 +3,27 @@ import json import os import shutil +import sys import tempfile from pathlib import Path import pytest from git import Repo from pytest_httpx._httpx_mock import HTTPXMock -from typer.testing import Any, CliRunner +from typer.testing import CliRunner from infrahub_sdk.ctl.cli_commands import app - -from .utils import change_directory, strip_color +from tests.helpers.utils import change_directory, strip_color runner = CliRunner() -FIXTURE_BASE_DIR = Path(Path(os.path.abspath(__file__)).parent / ".." / "fixtures" / "integration" / "test_infrahubctl") +FIXTURE_BASE_DIR = Path( + Path(os.path.abspath(__file__)).parent / ".." / ".." / "fixtures" / "integration" / "test_infrahubctl" +) -def read_fixture(file_name: str, fixture_subdir: str = ".") -> Any: +def read_fixture(file_name: str, fixture_subdir: str = ".") -> str: """Read the contents of a fixture.""" with Path(FIXTURE_BASE_DIR / fixture_subdir / file_name).open("r", encoding="utf-8") as fhd: fixture_contents = fhd.read() @@ -38,7 +40,7 @@ def tags_transform_dir(): shutil.copytree(fixture_path, temp_dir, dirs_exist_ok=True) # Initialize fixture as git repo. This is necessary to run some infrahubctl commands. with change_directory(temp_dir): - Repo.init(".") + Repo.init(".", initial_branch="main") yield temp_dir @@ -60,10 +62,11 @@ def test_transform_not_exist_in_infrahub_yml(tags_transform_dir: str) -> None: transform_name = "not_existing_transform" with change_directory(tags_transform_dir): output = runner.invoke(app, ["transform", transform_name, "tag=red"]) - assert f"Unable to find requested transform: {transform_name}" in output.stdout + assert f"Unable to find '{transform_name}'" in strip_color(output.stdout) assert output.exit_code == 1 @staticmethod + @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_transform_python_file_not_defined(tags_transform_dir: str) -> None: """Case transform python file not defined.""" # Remove transform file @@ -74,10 +77,11 @@ def test_transform_python_file_not_defined(tags_transform_dir: str) -> None: transform_name = "tags_transform" with change_directory(tags_transform_dir): output = runner.invoke(app, ["transform", transform_name, "tag=red"]) - assert f"Unable to load {transform_name} from python_transforms" in output.stdout + assert "No module named 'tags_transform' (tags_transform.py)" in strip_color(output.stdout) assert output.exit_code == 1 @staticmethod + @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_transform_python_class_not_defined(tags_transform_dir: str) -> None: """Case transform python class not defined.""" # Rename transform inside of python file so the class name searched for no longer exists @@ -93,10 +97,11 @@ def test_transform_python_class_not_defined(tags_transform_dir: str) -> None: transform_name = "tags_transform" with change_directory(tags_transform_dir): output = runner.invoke(app, ["transform", transform_name, "tag=red"]) - assert f"Unable to load {transform_name} from python_transforms" in output.stdout + assert "The specified class TagsTransform was not found within the module" in output.stdout assert output.exit_code == 1 @staticmethod + @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_gql_query_not_defined(tags_transform_dir: str) -> None: """Case GraphQL Query is not defined""" # Remove GraphQL Query file @@ -110,6 +115,7 @@ def test_gql_query_not_defined(tags_transform_dir: str) -> None: assert output.exit_code == 1 @staticmethod + @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") def test_infrahubctl_transform_cmd_success(httpx_mock: HTTPXMock, tags_transform_dir: str) -> None: """Case infrahubctl transform command executes successfully""" httpx_mock.add_response( diff --git a/tests/unit/sdk/conftest.py b/tests/unit/sdk/conftest.py index d39be706..3428179f 100644 --- a/tests/unit/sdk/conftest.py +++ b/tests/unit/sdk/conftest.py @@ -821,7 +821,15 @@ async def simple_device_schema() -> NodeSchema: "optional": True, "cardinality": "one", "kind": "Attribute", - } + }, + { + "name": "ip_address_pool", + "peer": "CoreIPAddressPool", + "label": "Address allocator", + "optional": True, + "cardinality": "one", + "kind": "Attribute", + }, ], } return NodeSchema(**data) # type: ignore diff --git a/tests/unit/sdk/test_node.py b/tests/unit/sdk/test_node.py index 12ab305b..543c8992 100644 --- a/tests/unit/sdk/test_node.py +++ b/tests/unit/sdk/test_node.py @@ -1487,7 +1487,9 @@ async def test_create_input_data_with_resource_pool_relationship( }, ) device = InfrahubNode( - client=client, schema=simple_device_schema, data={"name": "device-01", "primary_address": ip_pool} + client=client, + schema=simple_device_schema, + data={"name": "device-01", "primary_address": ip_pool, "ip_address_pool": ip_pool}, ) else: ip_prefix = InfrahubNodeSync(client=client, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data) @@ -1504,13 +1506,16 @@ async def test_create_input_data_with_resource_pool_relationship( }, ) device = InfrahubNode( - client=client, schema=simple_device_schema, data={"name": "device-01", "primary_address": ip_pool} + client=client, + schema=simple_device_schema, + data={"name": "device-01", "primary_address": ip_pool, "ip_address_pool": ip_pool}, ) assert device._generate_input_data()["data"] == { "data": { "name": {"value": "device-01"}, "primary_address": {"from_pool": {"id": "pppppppp-pppp-pppp-pppp-pppppppppppp"}}, + "ip_address_pool": {"id": "pppppppp-pppp-pppp-pppp-pppppppppppp"}, }, } @@ -1534,7 +1539,9 @@ async def test_create_mutation_query_with_resource_pool_relationship( }, ) device = InfrahubNode( - client=client, schema=simple_device_schema, data={"name": "device-01", "primary_address": ip_pool} + client=client, + schema=simple_device_schema, + data={"name": "device-01", "primary_address": ip_pool, "ip_address_pool": ip_pool}, ) else: ip_prefix = InfrahubNodeSync(client=client, schema=ipam_ipprefix_schema, data=ipam_ipprefix_data) @@ -1551,11 +1558,17 @@ async def test_create_mutation_query_with_resource_pool_relationship( }, ) device = InfrahubNode( - client=client, schema=simple_device_schema, data={"name": "device-01", "primary_address": ip_pool} + client=client, + schema=simple_device_schema, + data={"name": "device-01", "primary_address": ip_pool, "ip_address_pool": ip_pool}, ) assert device._generate_mutation_query() == { - "object": {"id": None, "primary_address": {"node": {"__typename": None, "display_label": None, "id": None}}}, + "object": { + "id": None, + "primary_address": {"node": {"__typename": None, "display_label": None, "id": None}}, + "ip_address_pool": {"node": {"__typename": None, "display_label": None, "id": None}}, + }, "ok": None, }