Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/+479a6128.removed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Removed previously deprecated InfrahubTransform.init() method
1 change: 1 addition & 0 deletions changelog/81.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CTL: Fix support for relative imports for transforms and generators
24 changes: 22 additions & 2 deletions infrahub_sdk/_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@

import importlib
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from .exceptions import ModuleImportError

if TYPE_CHECKING:
from pathlib import Path
from types import ModuleType

module_mtime_cache: dict[str, float] = {}


def import_module(
module_path: Path, import_root: Optional[str] = None, relative_path: Optional[str] = None
) -> ModuleType:
import_root = import_root or str(module_path.parent)

file_on_disk = module_path
if import_root and relative_path:
file_on_disk = Path(import_root, relative_path, module_path.name)

Check warning on line 23 in infrahub_sdk/_importer.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/_importer.py#L23

Added line #L23 was not covered by tests

if import_root not in sys.path:
sys.path.append(import_root)

Expand All @@ -25,7 +31,21 @@
module_name = relative_path.replace("/", ".") + f".{module_name}"

try:
module = importlib.import_module(module_name)
if module_name in sys.modules:
module = sys.modules[module_name]
current_mtime = file_on_disk.stat().st_mtime

if module_name in module_mtime_cache:
last_mtime = module_mtime_cache[module_name]
if current_mtime == last_mtime:
return module

module_mtime_cache[module_name] = current_mtime
module = importlib.reload(module)
else:
module = importlib.import_module(module_name)
module_mtime_cache[module_name] = file_on_disk.stat().st_mtime

except ModuleNotFoundError as exc:
raise ModuleImportError(message=f"{exc!s} ({module_path})") from exc
except SyntaxError as exc:
Expand Down
29 changes: 9 additions & 20 deletions infrahub_sdk/ctl/cli_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,13 @@
parse_cli_vars,
)
from ..ctl.validate import app as validate_app
from ..exceptions import GraphQLError, InfrahubTransformNotFoundError
from ..exceptions import GraphQLError, ModuleImportError
from ..jinja2 import identify_faulty_jinja_code
from ..schema import (
InfrahubRepositoryConfig,
MainSchemaTypes,
SchemaRoot,
)
from ..transforms import get_transform_class_instance
from ..utils import get_branch, write_to_file
from ..yaml import SchemaFile
from .exporter import dump
Expand Down Expand Up @@ -322,32 +321,22 @@ def transform(
list_transforms(config=repository_config)
return

# Load transform config
try:
matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable
if not matched:
raise ValueError(f"{transform_name} does not exist")
except ValueError as exc:
console.print(f"[red]Unable to find requested transform: {transform_name}")
list_transforms(config=repository_config)
raise typer.Exit(1) from exc

transform_config = matched[0]
transform_config = repository_config.get_python_transform(name=transform_name)

# Get client
client = initialize_client()

# Get python transform class instance

relative_path = str(transform_config.file_path.parent) if transform_config.file_path.parent != Path() else None

try:
transform = get_transform_class_instance(
transform_config=transform_config,
branch=branch,
client=client,
)
except InfrahubTransformNotFoundError as exc:
console.print(f"Unable to load {transform_name} from python_transforms")
transform_class = transform_config.load_class(import_root=str(Path.cwd()), relative_path=relative_path)
except ModuleImportError as exc:
console.print(f"[red]{exc.message}")
raise typer.Exit(1) from exc

transform = transform_class(client=client, branch=branch)
# Get data
query_str = repository_config.get_query(name=transform.query).load_query()
data = asyncio.run(
Expand Down
17 changes: 10 additions & 7 deletions infrahub_sdk/ctl/generator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from pathlib import Path
from typing import Optional

import typer
from rich.console import Console

from ..ctl import config
from ..ctl.client import initialize_client
from ..ctl.repository import get_repository_config
from ..ctl.utils import execute_graphql_query, parse_cli_vars
from ..exceptions import ModuleImportError
from ..node import InfrahubNode
from ..schema import InfrahubRepositoryConfig

Expand All @@ -25,17 +27,18 @@
list_generators(repository_config=repository_config)
return

matched = [generator for generator in repository_config.generator_definitions if generator.name == generator_name] # pylint: disable=not-an-iterable
generator_config = repository_config.get_generator_definition(name=generator_name)

Check warning on line 30 in infrahub_sdk/ctl/generator.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/ctl/generator.py#L30

Added line #L30 was not covered by tests

console = Console()

if not matched:
console.print(f"[red]Unable to find requested generator: {generator_name}")
list_generators(repository_config=repository_config)
return
relative_path = str(generator_config.file_path.parent) if generator_config.file_path.parent != Path() else None

Check warning on line 34 in infrahub_sdk/ctl/generator.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/ctl/generator.py#L34

Added line #L34 was not covered by tests

try:
generator_class = generator_config.load_class(import_root=str(Path.cwd()), relative_path=relative_path)
except ModuleImportError as exc:
console.print(f"[red]{exc.message}")
raise typer.Exit(1) from exc

Check warning on line 40 in infrahub_sdk/ctl/generator.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/ctl/generator.py#L36-L40

Added lines #L36 - L40 were not covered by tests

generator_config = matched[0]
generator_class = generator_config.load_class()
variables_dict = parse_cli_vars(variables)

param_key = list(generator_config.parameters.keys())
Expand Down
3 changes: 2 additions & 1 deletion infrahub_sdk/ctl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Error,
GraphQLError,
NodeNotFoundError,
ResourceNotDefinedError,
SchemaNotFoundError,
ServerNotReachableError,
ServerNotResponsiveError,
Expand Down Expand Up @@ -59,7 +60,7 @@ def handle_exception(exc: Exception, console: Console, exit_code: int) -> NoRetu
if isinstance(exc, GraphQLError):
print_graphql_errors(console=console, errors=exc.errors)
raise typer.Exit(code=exit_code)
if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError)):
if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError, ResourceNotDefinedError)):
console.print(f"[red]Error: {exc!s}")
raise typer.Exit(code=exit_code)

Expand Down
8 changes: 8 additions & 0 deletions infrahub_sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def __str__(self) -> str:
"""


class ResourceNotDefinedError(Error):
"""Raised when trying to access a resource that hasn't been defined."""

def __init__(self, message: Optional[str] = None):
self.message = message or "The requested resource was not found"
super().__init__(self.message)


class InfrahubCheckNotFoundError(Error):
def __init__(self, name: str, message: Optional[str] = None):
self.message = message or f"The requested InfrahubCheck '{name}' was not found."
Expand Down
26 changes: 24 additions & 2 deletions infrahub_sdk/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@
from typing_extensions import TypeAlias

from ._importer import import_module
from .exceptions import InvalidResponseError, ModuleImportError, SchemaNotFoundError, ValidationError
from .exceptions import (
InvalidResponseError,
ModuleImportError,
ResourceNotDefinedError,
SchemaNotFoundError,
ValidationError,
)
from .generator import InfrahubGenerator
from .graphql import Mutation
from .transforms import InfrahubTransform
from .utils import duplicates

if TYPE_CHECKING:
Expand Down Expand Up @@ -120,6 +127,21 @@
file_path: Path = Field(..., description="The file within the repository with the transform code.")
class_name: str = Field(default="Transform", description="The name of the transform class to run.")

def load_class(
self, import_root: Optional[str] = None, relative_path: Optional[str] = None
) -> type[InfrahubTransform]:
module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path)

if self.class_name not in dir(module):
raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module")

transform_class = getattr(module, self.class_name)

if not issubclass(transform_class, InfrahubTransform):
raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Transform")

Check warning on line 141 in infrahub_sdk/schema.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/schema.py#L141

Added line #L141 was not covered by tests

return transform_class


class InfrahubRepositoryGraphQLConfig(InfrahubRepositoryConfigElement):
model_config = ConfigDict(extra="forbid")
Expand Down Expand Up @@ -189,7 +211,7 @@
for item in getattr(self, RESOURCE_MAP[resource_type]):
if getattr(item, resource_field) == resource_id:
return item
raise KeyError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}")
raise ResourceNotDefinedError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}")

def has_jinja2_transform(self, name: str) -> bool:
return self._has_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig)
Expand Down
26 changes: 5 additions & 21 deletions infrahub_sdk/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
import asyncio
import importlib
import os
import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Optional

from git import Repo

from . import InfrahubClient
from .exceptions import InfrahubTransformNotFoundError
from .exceptions import InfrahubTransformNotFoundError, UninitializedError

if TYPE_CHECKING:
from pathlib import Path

from . import InfrahubClient
from .schema import InfrahubPythonTransformConfig

INFRAHUB_TRANSFORM_VARIABLE_TO_IMPORT = "INFRAHUB_TRANSFORMS"
Expand Down Expand Up @@ -48,25 +47,10 @@

@property
def client(self) -> InfrahubClient:
if not self._client:
self._client = InfrahubClient(address=self.server_url)
if self._client:
return self._client

return self._client

@classmethod
async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubTransform:
"""Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically."""
warnings.warn(
f"{cls.__class__.__name__}.init has been deprecated and will be removed in the version after Infrahub SDK 1.0.0",
DeprecationWarning,
stacklevel=1,
)
if client:
kwargs["client"] = client

item = cls(*args, **kwargs)

return item
raise UninitializedError("The client has not been initialized")

Check warning on line 53 in infrahub_sdk/transforms.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/transforms.py#L53

Added line #L53 was not covered by tests

@property
def branch_name(self) -> str:
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/ctl/test_transform_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_transform_not_exist_in_infrahub_yml(tags_transform_dir: str) -> None:
transform_name = "not_existing_transform"
with change_directory(tags_transform_dir):
output = runner.invoke(app, ["transform", transform_name, "tag=red"])
assert f"Unable to find requested transform: {transform_name}" in output.stdout
assert f"Unable to find '{transform_name}'" in strip_color(output.stdout)
assert output.exit_code == 1

@staticmethod
Expand All @@ -77,7 +77,7 @@ def test_transform_python_file_not_defined(tags_transform_dir: str) -> None:
transform_name = "tags_transform"
with change_directory(tags_transform_dir):
output = runner.invoke(app, ["transform", transform_name, "tag=red"])
assert f"Unable to load {transform_name} from python_transforms" in output.stdout
assert "No module named 'tags_transform' (tags_transform.py)" in strip_color(output.stdout)
assert output.exit_code == 1

@staticmethod
Expand All @@ -97,7 +97,7 @@ def test_transform_python_class_not_defined(tags_transform_dir: str) -> None:
transform_name = "tags_transform"
with change_directory(tags_transform_dir):
output = runner.invoke(app, ["transform", transform_name, "tag=red"])
assert f"Unable to load {transform_name} from python_transforms" in output.stdout
assert "The specified class TagsTransform was not found within the module" in output.stdout
assert output.exit_code == 1

@staticmethod
Expand Down