Skip to content

Commit b04d4ac

Browse files
committed
Fix relative module imports
Fixes #81
1 parent 758f6ac commit b04d4ac

File tree

5 files changed

+37
-33
lines changed

5 files changed

+37
-33
lines changed

infrahub_sdk/_importer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@ def import_module(
2525
module_name = relative_path.replace("/", ".") + f".{module_name}"
2626

2727
try:
28-
module = importlib.import_module(module_name)
28+
if module_name in sys.modules:
29+
module = sys.modules[module_name]
30+
module = importlib.reload(module)
31+
else:
32+
module = importlib.import_module(module_name)
33+
2934
except ModuleNotFoundError as exc:
3035
raise ModuleImportError(message=f"{exc!s} ({module_path})") from exc
3136
except SyntaxError as exc:

infrahub_sdk/ctl/cli_commands.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,13 @@
3636
parse_cli_vars,
3737
)
3838
from ..ctl.validate import app as validate_app
39-
from ..exceptions import GraphQLError, InfrahubTransformNotFoundError
39+
from ..exceptions import GraphQLError, ModuleImportError
4040
from ..jinja2 import identify_faulty_jinja_code
4141
from ..schema import (
4242
InfrahubRepositoryConfig,
4343
MainSchemaTypes,
4444
SchemaRoot,
4545
)
46-
from ..transforms import get_transform_class_instance
4746
from ..utils import get_branch, write_to_file
4847
from ..yaml import SchemaFile
4948
from .exporter import dump
@@ -338,16 +337,16 @@ def transform(
338337
client = initialize_client()
339338

340339
# Get python transform class instance
340+
341+
relative_path = str(transform_config.file_path.parent) if transform_config.file_path.parent != Path() else None
342+
341343
try:
342-
transform = get_transform_class_instance(
343-
transform_config=transform_config,
344-
branch=branch,
345-
client=client,
346-
)
347-
except InfrahubTransformNotFoundError as exc:
348-
console.print(f"Unable to load {transform_name} from python_transforms")
344+
transform_class = transform_config.load_class(import_root=str(Path.cwd()), relative_path=relative_path)
345+
except ModuleImportError as exc:
346+
console.print(f"[red]{exc.message}")
349347
raise typer.Exit(1) from exc
350348

349+
transform = transform_class(client=client, branch=branch)
351350
# Get data
352351
query_str = repository_config.get_query(name=transform.query).load_query()
353352
data = asyncio.run(

infrahub_sdk/schema.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .exceptions import InvalidResponseError, ModuleImportError, SchemaNotFoundError, ValidationError
1515
from .generator import InfrahubGenerator
1616
from .graphql import Mutation
17+
from .transforms import InfrahubTransform
1718
from .utils import duplicates
1819

1920
if TYPE_CHECKING:
@@ -120,6 +121,21 @@ class InfrahubPythonTransformConfig(InfrahubRepositoryConfigElement):
120121
file_path: Path = Field(..., description="The file within the repository with the transform code.")
121122
class_name: str = Field(default="Transform", description="The name of the transform class to run.")
122123

124+
def load_class(
125+
self, import_root: Optional[str] = None, relative_path: Optional[str] = None
126+
) -> type[InfrahubTransform]:
127+
module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path)
128+
129+
if self.class_name not in dir(module):
130+
raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module")
131+
132+
transform_class = getattr(module, self.class_name)
133+
134+
if not issubclass(transform_class, InfrahubTransform):
135+
raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Transform")
136+
137+
return transform_class
138+
123139

124140
class InfrahubRepositoryGraphQLConfig(InfrahubRepositoryConfigElement):
125141
model_config = ConfigDict(extra="forbid")

infrahub_sdk/transforms.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,17 @@
33
import asyncio
44
import importlib
55
import os
6-
import warnings
76
from abc import abstractmethod
87
from typing import TYPE_CHECKING, Any, Optional
98

109
from git import Repo
1110

12-
from . import InfrahubClient
13-
from .exceptions import InfrahubTransformNotFoundError
11+
from .exceptions import InfrahubTransformNotFoundError, UninitializedError
1412

1513
if TYPE_CHECKING:
1614
from pathlib import Path
1715

16+
from . import InfrahubClient
1817
from .schema import InfrahubPythonTransformConfig
1918

2019
INFRAHUB_TRANSFORM_VARIABLE_TO_IMPORT = "INFRAHUB_TRANSFORMS"
@@ -48,25 +47,10 @@ def __init__(
4847

4948
@property
5049
def client(self) -> InfrahubClient:
51-
if not self._client:
52-
self._client = InfrahubClient(address=self.server_url)
50+
if self._client:
51+
return self._client
5352

54-
return self._client
55-
56-
@classmethod
57-
async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubTransform:
58-
"""Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically."""
59-
warnings.warn(
60-
f"{cls.__class__.__name__}.init has been deprecated and will be removed in the version after Infrahub SDK 1.0.0",
61-
DeprecationWarning,
62-
stacklevel=1,
63-
)
64-
if client:
65-
kwargs["client"] = client
66-
67-
item = cls(*args, **kwargs)
68-
69-
return item
53+
raise UninitializedError("The client has not been initialized")
7054

7155
@property
7256
def branch_name(self) -> str:

tests/unit/ctl/test_transform_app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_transform_python_file_not_defined(tags_transform_dir: str) -> None:
7777
transform_name = "tags_transform"
7878
with change_directory(tags_transform_dir):
7979
output = runner.invoke(app, ["transform", transform_name, "tag=red"])
80-
assert f"Unable to load {transform_name} from python_transforms" in output.stdout
80+
assert "No module named 'tags_transform' (tags_transform.py)" in strip_color(output.stdout)
8181
assert output.exit_code == 1
8282

8383
@staticmethod
@@ -97,7 +97,7 @@ def test_transform_python_class_not_defined(tags_transform_dir: str) -> None:
9797
transform_name = "tags_transform"
9898
with change_directory(tags_transform_dir):
9999
output = runner.invoke(app, ["transform", transform_name, "tag=red"])
100-
assert f"Unable to load {transform_name} from python_transforms" in output.stdout
100+
assert "The specified class TagsTransform was not found within the module" in output.stdout
101101
assert output.exit_code == 1
102102

103103
@staticmethod

0 commit comments

Comments
 (0)