Skip to content

Commit beb14a2

Browse files
committed
Fix relative module imports
Fixes #81
1 parent 758f6ac commit beb14a2

File tree

10 files changed

+85
-56
lines changed

10 files changed

+85
-56
lines changed

changelog/+479a6128.removed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Removed previously deprecated InfrahubTransform.init() method

changelog/81.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CTL: Fix support for relative imports for transforms and generators

infrahub_sdk/_importer.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,26 @@
22

33
import importlib
44
import sys
5+
from pathlib import Path
56
from typing import TYPE_CHECKING, Optional
67

78
from .exceptions import ModuleImportError
89

910
if TYPE_CHECKING:
10-
from pathlib import Path
1111
from types import ModuleType
1212

13+
module_mtime_cache: dict[str, float] = {}
14+
1315

1416
def import_module(
1517
module_path: Path, import_root: Optional[str] = None, relative_path: Optional[str] = None
1618
) -> ModuleType:
1719
import_root = import_root or str(module_path.parent)
1820

21+
file_on_disk = module_path
22+
if import_root and relative_path:
23+
file_on_disk = Path(import_root, relative_path, module_path.name)
24+
1925
if import_root not in sys.path:
2026
sys.path.append(import_root)
2127

@@ -25,7 +31,21 @@ def import_module(
2531
module_name = relative_path.replace("/", ".") + f".{module_name}"
2632

2733
try:
28-
module = importlib.import_module(module_name)
34+
if module_name in sys.modules:
35+
module = sys.modules[module_name]
36+
current_mtime = file_on_disk.stat().st_mtime
37+
38+
if module_name in module_mtime_cache:
39+
last_mtime = module_mtime_cache[module_name]
40+
if current_mtime == last_mtime:
41+
return module
42+
43+
module_mtime_cache[module_name] = current_mtime
44+
module = importlib.reload(module)
45+
else:
46+
module = importlib.import_module(module_name)
47+
module_mtime_cache[module_name] = file_on_disk.stat().st_mtime
48+
2949
except ModuleNotFoundError as exc:
3050
raise ModuleImportError(message=f"{exc!s} ({module_path})") from exc
3151
except SyntaxError as exc:

infrahub_sdk/ctl/cli_commands.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,13 @@
3636
parse_cli_vars,
3737
)
3838
from ..ctl.validate import app as validate_app
39-
from ..exceptions import GraphQLError, InfrahubTransformNotFoundError
39+
from ..exceptions import GraphQLError, ModuleImportError
4040
from ..jinja2 import identify_faulty_jinja_code
4141
from ..schema import (
4242
InfrahubRepositoryConfig,
4343
MainSchemaTypes,
4444
SchemaRoot,
4545
)
46-
from ..transforms import get_transform_class_instance
4746
from ..utils import get_branch, write_to_file
4847
from ..yaml import SchemaFile
4948
from .exporter import dump
@@ -322,32 +321,22 @@ def transform(
322321
list_transforms(config=repository_config)
323322
return
324323

325-
# Load transform config
326-
try:
327-
matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable
328-
if not matched:
329-
raise ValueError(f"{transform_name} does not exist")
330-
except ValueError as exc:
331-
console.print(f"[red]Unable to find requested transform: {transform_name}")
332-
list_transforms(config=repository_config)
333-
raise typer.Exit(1) from exc
334-
335-
transform_config = matched[0]
324+
transform_config = repository_config.get_python_transform(name=transform_name)
336325

337326
# Get client
338327
client = initialize_client()
339328

340329
# Get python transform class instance
330+
331+
relative_path = str(transform_config.file_path.parent) if transform_config.file_path.parent != Path() else None
332+
341333
try:
342-
transform = get_transform_class_instance(
343-
transform_config=transform_config,
344-
branch=branch,
345-
client=client,
346-
)
347-
except InfrahubTransformNotFoundError as exc:
348-
console.print(f"Unable to load {transform_name} from python_transforms")
334+
transform_class = transform_config.load_class(import_root=str(Path.cwd()), relative_path=relative_path)
335+
except ModuleImportError as exc:
336+
console.print(f"[red]{exc.message}")
349337
raise typer.Exit(1) from exc
350338

339+
transform = transform_class(client=client, branch=branch)
351340
# Get data
352341
query_str = repository_config.get_query(name=transform.query).load_query()
353342
data = asyncio.run(

infrahub_sdk/ctl/generator.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from pathlib import Path
22
from typing import Optional
33

4+
import typer
45
from rich.console import Console
56

67
from ..ctl import config
78
from ..ctl.client import initialize_client
89
from ..ctl.repository import get_repository_config
910
from ..ctl.utils import execute_graphql_query, parse_cli_vars
11+
from ..exceptions import ModuleImportError
1012
from ..node import InfrahubNode
1113
from ..schema import InfrahubRepositoryConfig
1214

@@ -25,17 +27,18 @@ async def run(
2527
list_generators(repository_config=repository_config)
2628
return
2729

28-
matched = [generator for generator in repository_config.generator_definitions if generator.name == generator_name] # pylint: disable=not-an-iterable
30+
generator_config = repository_config.get_generator_definition(name=generator_name)
2931

3032
console = Console()
3133

32-
if not matched:
33-
console.print(f"[red]Unable to find requested generator: {generator_name}")
34-
list_generators(repository_config=repository_config)
35-
return
34+
relative_path = str(generator_config.file_path.parent) if generator_config.file_path.parent != Path() else None
35+
36+
try:
37+
generator_class = generator_config.load_class(import_root=str(Path.cwd()), relative_path=relative_path)
38+
except ModuleImportError as exc:
39+
console.print(f"[red]{exc.message}")
40+
raise typer.Exit(1) from exc
3641

37-
generator_config = matched[0]
38-
generator_class = generator_config.load_class()
3942
variables_dict = parse_cli_vars(variables)
4043

4144
param_key = list(generator_config.parameters.keys())

infrahub_sdk/ctl/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Error,
2121
GraphQLError,
2222
NodeNotFoundError,
23+
ResourceNotDefinedError,
2324
SchemaNotFoundError,
2425
ServerNotReachableError,
2526
ServerNotResponsiveError,
@@ -59,7 +60,7 @@ def handle_exception(exc: Exception, console: Console, exit_code: int) -> NoRetu
5960
if isinstance(exc, GraphQLError):
6061
print_graphql_errors(console=console, errors=exc.errors)
6162
raise typer.Exit(code=exit_code)
62-
if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError)):
63+
if isinstance(exc, (SchemaNotFoundError, NodeNotFoundError, ResourceNotDefinedError)):
6364
console.print(f"[red]Error: {exc!s}")
6465
raise typer.Exit(code=exit_code)
6566

infrahub_sdk/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@ def __str__(self) -> str:
8787
"""
8888

8989

90+
class ResourceNotDefinedError(Error):
91+
"""Raised when trying to access a resource that hasn't been defined."""
92+
93+
def __init__(self, message: Optional[str] = None):
94+
self.message = message or "The requested resource was not found"
95+
super().__init__(self.message)
96+
97+
9098
class InfrahubCheckNotFoundError(Error):
9199
def __init__(self, name: str, message: Optional[str] = None):
92100
self.message = message or f"The requested InfrahubCheck '{name}' was not found."

infrahub_sdk/schema.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,16 @@
1111
from typing_extensions import TypeAlias
1212

1313
from ._importer import import_module
14-
from .exceptions import InvalidResponseError, ModuleImportError, SchemaNotFoundError, ValidationError
14+
from .exceptions import (
15+
InvalidResponseError,
16+
ModuleImportError,
17+
ResourceNotDefinedError,
18+
SchemaNotFoundError,
19+
ValidationError,
20+
)
1521
from .generator import InfrahubGenerator
1622
from .graphql import Mutation
23+
from .transforms import InfrahubTransform
1724
from .utils import duplicates
1825

1926
if TYPE_CHECKING:
@@ -120,6 +127,21 @@ class InfrahubPythonTransformConfig(InfrahubRepositoryConfigElement):
120127
file_path: Path = Field(..., description="The file within the repository with the transform code.")
121128
class_name: str = Field(default="Transform", description="The name of the transform class to run.")
122129

130+
def load_class(
131+
self, import_root: Optional[str] = None, relative_path: Optional[str] = None
132+
) -> type[InfrahubTransform]:
133+
module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path)
134+
135+
if self.class_name not in dir(module):
136+
raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module")
137+
138+
transform_class = getattr(module, self.class_name)
139+
140+
if not issubclass(transform_class, InfrahubTransform):
141+
raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Transform")
142+
143+
return transform_class
144+
123145

124146
class InfrahubRepositoryGraphQLConfig(InfrahubRepositoryConfigElement):
125147
model_config = ConfigDict(extra="forbid")
@@ -189,7 +211,7 @@ def _get_resource(
189211
for item in getattr(self, RESOURCE_MAP[resource_type]):
190212
if getattr(item, resource_field) == resource_id:
191213
return item
192-
raise KeyError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}")
214+
raise ResourceNotDefinedError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}")
193215

194216
def has_jinja2_transform(self, name: str) -> bool:
195217
return self._has_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig)

infrahub_sdk/transforms.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,17 @@
33
import asyncio
44
import importlib
55
import os
6-
import warnings
76
from abc import abstractmethod
87
from typing import TYPE_CHECKING, Any, Optional
98

109
from git import Repo
1110

12-
from . import InfrahubClient
13-
from .exceptions import InfrahubTransformNotFoundError
11+
from .exceptions import InfrahubTransformNotFoundError, UninitializedError
1412

1513
if TYPE_CHECKING:
1614
from pathlib import Path
1715

16+
from . import InfrahubClient
1817
from .schema import InfrahubPythonTransformConfig
1918

2019
INFRAHUB_TRANSFORM_VARIABLE_TO_IMPORT = "INFRAHUB_TRANSFORMS"
@@ -48,25 +47,10 @@ def __init__(
4847

4948
@property
5049
def client(self) -> InfrahubClient:
51-
if not self._client:
52-
self._client = InfrahubClient(address=self.server_url)
50+
if self._client:
51+
return self._client
5352

54-
return self._client
55-
56-
@classmethod
57-
async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubTransform:
58-
"""Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically."""
59-
warnings.warn(
60-
f"{cls.__class__.__name__}.init has been deprecated and will be removed in the version after Infrahub SDK 1.0.0",
61-
DeprecationWarning,
62-
stacklevel=1,
63-
)
64-
if client:
65-
kwargs["client"] = client
66-
67-
item = cls(*args, **kwargs)
68-
69-
return item
53+
raise UninitializedError("The client has not been initialized")
7054

7155
@property
7256
def branch_name(self) -> str:

tests/unit/ctl/test_transform_app.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_transform_not_exist_in_infrahub_yml(tags_transform_dir: str) -> None:
6262
transform_name = "not_existing_transform"
6363
with change_directory(tags_transform_dir):
6464
output = runner.invoke(app, ["transform", transform_name, "tag=red"])
65-
assert f"Unable to find requested transform: {transform_name}" in output.stdout
65+
assert f"Unable to find '{transform_name}'" in strip_color(output.stdout)
6666
assert output.exit_code == 1
6767

6868
@staticmethod
@@ -77,7 +77,7 @@ def test_transform_python_file_not_defined(tags_transform_dir: str) -> None:
7777
transform_name = "tags_transform"
7878
with change_directory(tags_transform_dir):
7979
output = runner.invoke(app, ["transform", transform_name, "tag=red"])
80-
assert f"Unable to load {transform_name} from python_transforms" in output.stdout
80+
assert "No module named 'tags_transform' (tags_transform.py)" in strip_color(output.stdout)
8181
assert output.exit_code == 1
8282

8383
@staticmethod
@@ -97,7 +97,7 @@ def test_transform_python_class_not_defined(tags_transform_dir: str) -> None:
9797
transform_name = "tags_transform"
9898
with change_directory(tags_transform_dir):
9999
output = runner.invoke(app, ["transform", transform_name, "tag=red"])
100-
assert f"Unable to load {transform_name} from python_transforms" in output.stdout
100+
assert "The specified class TagsTransform was not found within the module" in output.stdout
101101
assert output.exit_code == 1
102102

103103
@staticmethod

0 commit comments

Comments
 (0)