Skip to content

Commit 3af81a9

Browse files
committed
Fix pytest plugin relative import
1 parent fc61aff commit 3af81a9

File tree

7 files changed

+53
-77
lines changed

7 files changed

+53
-77
lines changed

changelog/166.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix relative imports for the pytest plugin, note that the relative imports can't be at the top level of the repository alongside .infrahub.yml. They have to be located within a subfolder.

infrahub_sdk/_importer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@
1616
def import_module(
1717
module_path: Path, import_root: Optional[str] = None, relative_path: Optional[str] = None
1818
) -> ModuleType:
19+
"""Imports a python module.
20+
21+
Attributes:
22+
module_path (Path): Absolute path of the module to import.
23+
import_root (Optional[str]): Absolute string path to the folder.
24+
relative_path (Optional[str]): Relative string path between module_path and import_root.
25+
TODO Compute `relative_path` here instead of having it as a parameter?
26+
"""
27+
1928
import_root = import_root or str(module_path.parent)
2029

2130
file_on_disk = module_path
@@ -35,6 +44,8 @@ def import_module(
3544
module_name = relative_path.replace("/", ".") + f".{module_name}"
3645

3746
try:
47+
# We hold a mapping of imported modules. If a module is already loaded and does not have recent changes,
48+
# then we do not reload/import this module.
3849
if module_name in sys.modules:
3950
module = sys.modules[module_name]
4051
current_mtime = file_on_disk.stat().st_mtime

infrahub_sdk/checks.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,10 @@
1111
from git.repo import Repo
1212
from pydantic import BaseModel, Field
1313

14-
from .exceptions import InfrahubCheckNotFoundError, UninitializedError
14+
from .exceptions import UninitializedError
1515

1616
if TYPE_CHECKING:
17-
from pathlib import Path
18-
1917
from . import InfrahubClient
20-
from .schema.repository import InfrahubCheckDefinitionConfig
2118

2219
INFRAHUB_CHECK_VARIABLE_TO_IMPORT = "INFRAHUB_CHECKS"
2320

@@ -176,27 +173,3 @@ async def run(self, data: Optional[dict] = None) -> bool:
176173
self.log_info("Check succesfully completed")
177174

178175
return self.passed
179-
180-
181-
def get_check_class_instance(
182-
check_config: InfrahubCheckDefinitionConfig, search_path: Optional[Path] = None
183-
) -> InfrahubCheck:
184-
if check_config.file_path.is_absolute() or search_path is None:
185-
search_location = check_config.file_path
186-
else:
187-
search_location = search_path / check_config.file_path
188-
189-
try:
190-
spec = importlib.util.spec_from_file_location(check_config.class_name, search_location)
191-
module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
192-
spec.loader.exec_module(module) # type: ignore[union-attr]
193-
194-
# Get the specified class from the module
195-
check_class = getattr(module, check_config.class_name)
196-
197-
# Create an instance of the class
198-
check_instance = check_class()
199-
except (FileNotFoundError, AttributeError) as exc:
200-
raise InfrahubCheckNotFoundError(name=check_config.name) from exc
201-
202-
return check_instance

infrahub_sdk/pytest_plugin/items/check.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import ujson
77
from httpx import HTTPStatusError
88

9-
from ...checks import get_check_class_instance
9+
from ...utils import get_check_or_transform_class
1010
from ..exceptions import CheckDefinitionError, CheckResultError
1111
from ..models import InfrahubTestExpectedResult
1212
from .base import InfrahubItem
@@ -33,10 +33,11 @@ def __init__(
3333
self.check_instance: InfrahubCheck
3434

3535
def instantiate_check(self) -> None:
36-
self.check_instance = get_check_class_instance(
37-
check_config=self.resource_config, # type: ignore[arg-type]
36+
check_class = get_check_or_transform_class(
37+
config=self.resource_config, # type: ignore[arg-type]
3838
search_path=self.session.infrahub_config_path.parent, # type: ignore[attr-defined]
3939
)
40+
self.check_instance = check_class()
4041

4142
def run_check(self, variables: dict[str, Any]) -> Any:
4243
self.instantiate_check()

infrahub_sdk/pytest_plugin/items/python_transform.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import ujson
77
from httpx import HTTPStatusError
88

9-
from ...transforms import get_transform_class_instance
9+
from ...utils import get_check_or_transform_class
1010
from ..exceptions import OutputMatchError, PythonTransformDefinitionError
1111
from ..models import InfrahubTestExpectedResult
1212
from .base import InfrahubItem
@@ -33,10 +33,11 @@ def __init__(
3333
self.transform_instance: InfrahubTransform
3434

3535
def instantiate_transform(self) -> None:
36-
self.transform_instance = get_transform_class_instance(
37-
transform_config=self.resource_config, # type: ignore[arg-type]
36+
transform_class = get_check_or_transform_class(
37+
config=self.resource_config, # type: ignore[arg-type]
3838
search_path=self.session.infrahub_config_path.parent, # type: ignore[attr-defined]
3939
)
40+
self.transform_instance = transform_class(branch="", client=None)
4041

4142
def run_transform(self, variables: dict[str, Any]) -> Any:
4243
self.instantiate_transform()

infrahub_sdk/transforms.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,16 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import importlib
54
import os
65
from abc import abstractmethod
76
from typing import TYPE_CHECKING, Any, Optional
87

98
from git import Repo
109

11-
from .exceptions import InfrahubTransformNotFoundError, UninitializedError
10+
from .exceptions import UninitializedError
1211

1312
if TYPE_CHECKING:
14-
from pathlib import Path
15-
1613
from . import InfrahubClient
17-
from .schema.repository import InfrahubPythonTransformConfig
1814

1915
INFRAHUB_TRANSFORM_VARIABLE_TO_IMPORT = "INFRAHUB_TRANSFORMS"
2016

@@ -95,40 +91,3 @@ async def run(self, data: Optional[dict] = None) -> Any:
9591
return await self.transform(data=unpacked)
9692

9793
return self.transform(data=unpacked)
98-
99-
100-
def get_transform_class_instance(
101-
transform_config: InfrahubPythonTransformConfig,
102-
search_path: Optional[Path] = None,
103-
branch: str = "",
104-
client: Optional[InfrahubClient] = None,
105-
) -> InfrahubTransform:
106-
"""Gets an instance of the InfrahubTransform class.
107-
108-
Args:
109-
transform_config: A config object with information required to find and load the transform.
110-
search_path: The path in which to search for a python file containing the transform. The current directory is
111-
assumed if not speicifed.
112-
branch: Infrahub branch which will be targeted in graphql query used to acquire data for transformation.
113-
client: InfrahubClient used to interact with infrahub API.
114-
"""
115-
if transform_config.file_path.is_absolute() or search_path is None:
116-
search_location = transform_config.file_path
117-
else:
118-
search_location = search_path / transform_config.file_path
119-
120-
try:
121-
spec = importlib.util.spec_from_file_location(transform_config.class_name, search_location)
122-
module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
123-
spec.loader.exec_module(module) # type: ignore[union-attr]
124-
125-
# Get the specified class from the module
126-
transform_class = getattr(module, transform_config.class_name)
127-
128-
# Create an instance of the class
129-
transform_instance = transform_class(branch=branch, client=client)
130-
131-
except (FileNotFoundError, AttributeError) as exc:
132-
raise InfrahubTransformNotFoundError(name=transform_config.name) from exc
133-
134-
return transform_instance

infrahub_sdk/utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import hashlib
4+
import importlib
45
import json
56
from itertools import groupby
67
from pathlib import Path
@@ -16,7 +17,8 @@
1617
SelectionSetNode,
1718
)
1819

19-
from .exceptions import JsonDecodeError
20+
from .exceptions import InfrahubCheckNotFoundError, InfrahubTransformNotFoundError, JsonDecodeError
21+
from .schema.repository import InfrahubCheckDefinitionConfig, InfrahubPythonTransformConfig
2022

2123
if TYPE_CHECKING:
2224
from graphql import GraphQLResolveInfo
@@ -335,3 +337,31 @@ def write_to_file(path: Path, value: Any) -> bool:
335337
written = path.write_text(to_write)
336338

337339
return written is not None
340+
341+
342+
def get_check_or_transform_class(
343+
config: InfrahubCheckDefinitionConfig | InfrahubPythonTransformConfig, search_path: Optional[Path] = None
344+
) -> type:
345+
if config.file_path.is_absolute() or search_path is None:
346+
search_location = config.file_path
347+
else:
348+
search_location = search_path / config.file_path
349+
350+
try:
351+
spec = importlib.util.spec_from_file_location(config.class_name, search_location)
352+
module = importlib.util.module_from_spec(spec) # type: ignore[arg-type]
353+
354+
# Set base module for relative import. This is useful when pytest plugin runs through proposed change pipeline,
355+
# as pytest is invoked outside of the imported repository.
356+
# Note that using __package__ logs a `DeprecationWarning: __package__ != __spec__.parent`
357+
module.__package__ = str(search_location.parent.name)
358+
359+
spec.loader.exec_module(module) # type: ignore[union-attr]
360+
361+
# Get the specified class from the module
362+
return getattr(module, config.class_name)
363+
364+
except (FileNotFoundError, AttributeError) as exc:
365+
if isinstance(config, InfrahubPythonTransformConfig):
366+
raise InfrahubTransformNotFoundError(name=config.name) from exc
367+
raise InfrahubCheckNotFoundError(name=config.name) from exc

0 commit comments

Comments
 (0)