|
11 | 11 | from typing_extensions import TypeAlias |
12 | 12 |
|
13 | 13 | from ._importer import import_module |
14 | | -from .exceptions import InvalidResponseError, ModuleImportError, SchemaNotFoundError, ValidationError |
| 14 | +from .exceptions import ( |
| 15 | + InvalidResponseError, |
| 16 | + ModuleImportError, |
| 17 | + ResourceNotDefinedError, |
| 18 | + SchemaNotFoundError, |
| 19 | + ValidationError, |
| 20 | +) |
15 | 21 | from .generator import InfrahubGenerator |
16 | 22 | from .graphql import Mutation |
| 23 | +from .transforms import InfrahubTransform |
17 | 24 | from .utils import duplicates |
18 | 25 |
|
19 | 26 | if TYPE_CHECKING: |
@@ -120,6 +127,21 @@ class InfrahubPythonTransformConfig(InfrahubRepositoryConfigElement): |
120 | 127 | file_path: Path = Field(..., description="The file within the repository with the transform code.") |
121 | 128 | class_name: str = Field(default="Transform", description="The name of the transform class to run.") |
122 | 129 |
|
| 130 | + def load_class( |
| 131 | + self, import_root: Optional[str] = None, relative_path: Optional[str] = None |
| 132 | + ) -> type[InfrahubTransform]: |
| 133 | + module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path) |
| 134 | + |
| 135 | + if self.class_name not in dir(module): |
| 136 | + raise ModuleImportError(message=f"The specified class {self.class_name} was not found within the module") |
| 137 | + |
| 138 | + transform_class = getattr(module, self.class_name) |
| 139 | + |
| 140 | + if not issubclass(transform_class, InfrahubTransform): |
| 141 | + raise ModuleImportError(message=f"The specified class {self.class_name} is not an Infrahub Transform") |
| 142 | + |
| 143 | + return transform_class |
| 144 | + |
123 | 145 |
|
124 | 146 | class InfrahubRepositoryGraphQLConfig(InfrahubRepositoryConfigElement): |
125 | 147 | model_config = ConfigDict(extra="forbid") |
@@ -189,7 +211,7 @@ def _get_resource( |
189 | 211 | for item in getattr(self, RESOURCE_MAP[resource_type]): |
190 | 212 | if getattr(item, resource_field) == resource_id: |
191 | 213 | return item |
192 | | - raise KeyError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}") |
| 214 | + raise ResourceNotDefinedError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}") |
193 | 215 |
|
194 | 216 | def has_jinja2_transform(self, name: str) -> bool: |
195 | 217 | return self._has_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig) |
|
0 commit comments