Skip to content

Commit 3f7f947

Browse files
committed
Add ability to use convert_query_response with Python Transforms
Fixes #281
1 parent 23d9f1a commit 3f7f947

File tree

17 files changed

+439
-101
lines changed

17 files changed

+439
-101
lines changed

changelog/281.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added ability to convert the query response to InfrahubNode objects when using Python Transforms in the same way you can with Generators.

infrahub_sdk/client.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,6 @@ def _build_ip_prefix_allocation_query(
271271
input_data={"data": input_data},
272272
)
273273

274-
def _clone_config(self, branch: str | None = None) -> Config:
275-
config = copy.deepcopy(self.config)
276-
config.default_branch = branch or config.default_branch
277-
return config
278-
279274

280275
class InfrahubClient(BaseClient):
281276
"""GraphQL Client to interact with Infrahub."""
@@ -854,7 +849,7 @@ async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]:
854849

855850
def clone(self, branch: str | None = None) -> InfrahubClient:
856851
"""Return a cloned version of the client using the same configuration"""
857-
return InfrahubClient(config=self._clone_config(branch=branch))
852+
return InfrahubClient(config=self.config.clone(branch=branch))
858853

859854
async def execute_graphql(
860855
self,
@@ -1598,7 +1593,7 @@ def delete(self, kind: str | type[SchemaTypeSync], id: str, branch: str | None =
15981593

15991594
def clone(self, branch: str | None = None) -> InfrahubClientSync:
16001595
"""Return a cloned version of the client using the same configuration"""
1601-
return InfrahubClientSync(config=self._clone_config(branch=branch))
1596+
return InfrahubClientSync(config=self.config.clone(branch=branch))
16021597

16031598
def execute_graphql(
16041599
self,

infrahub_sdk/config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from copy import deepcopy
34
from typing import Any
45

56
from pydantic import Field, field_validator, model_validator
@@ -158,3 +159,19 @@ def set_custom_recorder(cls, values: dict[str, Any]) -> dict[str, Any]:
158159
elif values.get("recorder") == RecorderType.JSON and "custom_recorder" not in values:
159160
values["custom_recorder"] = JSONRecorder()
160161
return values
162+
163+
def clone(self, branch: str | None = None) -> Config:
164+
config: dict[str, Any] = {
165+
"default_branch": branch or self.default_branch,
166+
"recorder": self.recorder,
167+
"custom_recorder": self.custom_recorder,
168+
"requester": self.requester,
169+
"sync_requester": self.sync_requester,
170+
"log": self.log,
171+
}
172+
covered_keys = list(config.keys())
173+
for field in self.model_fields.keys():
174+
if field not in covered_keys:
175+
config[field] = deepcopy(getattr(self, field))
176+
177+
return Config(**config)

infrahub_sdk/ctl/cli_commands.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from ..ctl.validate import app as validate_app
4343
from ..exceptions import GraphQLError, ModuleImportError
44+
from ..node import InfrahubNode
4445
from ..protocols_generator.generator import CodeGenerator
4546
from ..schema import MainSchemaTypesAll, SchemaRoot
4647
from ..template import Jinja2Template
@@ -330,7 +331,12 @@ def transform(
330331
console.print(f"[red]{exc.message}")
331332
raise typer.Exit(1) from exc
332333

333-
transform = transform_class(client=client, branch=branch)
334+
transform = transform_class(
335+
client=client,
336+
branch=branch,
337+
infrahub_node=InfrahubNode,
338+
convert_query_response=transform_config.convert_query_response,
339+
)
334340
# Get data
335341
query_str = repository_config.get_query(name=transform.query).load_query()
336342
data = asyncio.run(

infrahub_sdk/ctl/generator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def run(
6262
generator = generator_class(
6363
query=generator_config.query,
6464
client=client,
65-
branch=branch,
65+
branch=branch or "",
6666
params=variables_dict,
6767
convert_query_response=generator_config.convert_query_response,
6868
infrahub_node=InfrahubNode,
@@ -91,7 +91,7 @@ async def run(
9191
generator = generator_class(
9292
query=generator_config.query,
9393
client=client,
94-
branch=branch,
94+
branch=branch or "",
9595
params=params,
9696
convert_query_response=generator_config.convert_query_response,
9797
infrahub_node=InfrahubNode,

infrahub_sdk/generator.py

Lines changed: 12 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
from __future__ import annotations
22

33
import logging
4-
import os
54
from abc import abstractmethod
65
from typing import TYPE_CHECKING
76

8-
from infrahub_sdk.repository import GitRepoManager
9-
107
from .exceptions import UninitializedError
8+
from .operation import InfrahubOperation
119

1210
if TYPE_CHECKING:
1311
from .client import InfrahubClient
1412
from .context import RequestContext
1513
from .node import InfrahubNode
16-
from .store import NodeStore
1714

1815

19-
class InfrahubGenerator:
16+
class InfrahubGenerator(InfrahubOperation):
2017
"""Infrahub Generator class"""
2118

2219
def __init__(
2320
self,
2421
query: str,
2522
client: InfrahubClient,
2623
infrahub_node: type[InfrahubNode],
27-
branch: str | None = None,
24+
branch: str = "",
2825
root_directory: str = "",
2926
generator_instance: str = "",
3027
params: dict | None = None,
@@ -33,35 +30,21 @@ def __init__(
3330
request_context: RequestContext | None = None,
3431
) -> None:
3532
self.query = query
36-
self.branch = branch
37-
self.git: GitRepoManager | None = None
33+
34+
super().__init__(
35+
client=client,
36+
infrahub_node=infrahub_node,
37+
convert_query_response=convert_query_response,
38+
branch=branch,
39+
root_directory=root_directory,
40+
)
41+
3842
self.params = params or {}
39-
self.root_directory = root_directory or os.getcwd()
4043
self.generator_instance = generator_instance
41-
self._init_client = client.clone(branch=self.branch_name)
4244
self._client: InfrahubClient | None = None
43-
self._nodes: list[InfrahubNode] = []
44-
self._related_nodes: list[InfrahubNode] = []
45-
self.infrahub_node = infrahub_node
46-
self.convert_query_response = convert_query_response
4745
self.logger = logger if logger else logging.getLogger("infrahub.tasks")
4846
self.request_context = request_context
4947

50-
@property
51-
def store(self) -> NodeStore:
52-
"""The store will be populated with nodes based on the query during the collection of data if activated"""
53-
return self._init_client.store
54-
55-
@property
56-
def nodes(self) -> list[InfrahubNode]:
57-
"""Returns nodes collected and parsed during the data collection process if this feature is enables"""
58-
return self._nodes
59-
60-
@property
61-
def related_nodes(self) -> list[InfrahubNode]:
62-
"""Returns nodes collected and parsed during the data collection process if this feature is enables"""
63-
return self._related_nodes
64-
6548
@property
6649
def subscribers(self) -> list[str] | None:
6750
if self.generator_instance:
@@ -78,20 +61,6 @@ def client(self) -> InfrahubClient:
7861
def client(self, value: InfrahubClient) -> None:
7962
self._client = value
8063

81-
@property
82-
def branch_name(self) -> str:
83-
"""Return the name of the current git branch."""
84-
85-
if self.branch:
86-
return self.branch
87-
88-
if not self.git:
89-
self.git = GitRepoManager(self.root_directory)
90-
91-
self.branch = str(self.git.active_branch)
92-
93-
return self.branch
94-
9564
async def collect_data(self) -> dict:
9665
"""Query the result of the GraphQL Query defined in self.query and return the result"""
9766

@@ -117,27 +86,6 @@ async def run(self, identifier: str, data: dict | None = None) -> None:
11786
) as self.client:
11887
await self.generate(data=unpacked)
11988

120-
async def process_nodes(self, data: dict) -> None:
121-
if not self.convert_query_response:
122-
return
123-
124-
await self._init_client.schema.all(branch=self.branch_name)
125-
126-
for kind in data:
127-
if kind in self._init_client.schema.cache[self.branch_name].nodes.keys():
128-
for result in data[kind].get("edges", []):
129-
node = await self.infrahub_node.from_graphql(
130-
client=self._init_client, branch=self.branch_name, data=result
131-
)
132-
self._nodes.append(node)
133-
await node._process_relationships(
134-
node_data=result, branch=self.branch_name, related_nodes=self._related_nodes
135-
)
136-
137-
for node in self._nodes + self._related_nodes:
138-
if node.id:
139-
self._init_client.store.set(node=node)
140-
14189
@abstractmethod
14290
async def generate(self, data: dict) -> None:
14391
"""Code to run the generator

infrahub_sdk/operation.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from typing import TYPE_CHECKING
5+
6+
from .repository import GitRepoManager
7+
8+
if TYPE_CHECKING:
9+
from . import InfrahubClient
10+
from .node import InfrahubNode
11+
from .store import NodeStore
12+
13+
14+
class InfrahubOperation:
15+
def __init__(
16+
self,
17+
client: InfrahubClient,
18+
infrahub_node: type[InfrahubNode],
19+
convert_query_response: bool,
20+
branch: str,
21+
root_directory: str,
22+
):
23+
self.branch = branch
24+
self.convert_query_response = convert_query_response
25+
self.root_directory = root_directory or os.getcwd()
26+
self.infrahub_node = infrahub_node
27+
self._nodes: list[InfrahubNode] = []
28+
self._related_nodes: list[InfrahubNode] = []
29+
self._init_client = client.clone(branch=self.branch_name)
30+
self.git: GitRepoManager | None = None
31+
32+
@property
33+
def branch_name(self) -> str:
34+
"""Return the name of the current git branch."""
35+
36+
if self.branch:
37+
return self.branch
38+
39+
if not hasattr(self, "git") or not self.git:
40+
self.git = GitRepoManager(self.root_directory)
41+
42+
self.branch = str(self.git.active_branch)
43+
44+
return self.branch
45+
46+
@property
47+
def store(self) -> NodeStore:
48+
"""The store will be populated with nodes based on the query during the collection of data if activated"""
49+
return self._init_client.store
50+
51+
@property
52+
def nodes(self) -> list[InfrahubNode]:
53+
"""Returns nodes collected and parsed during the data collection process if this feature is enables"""
54+
return self._nodes
55+
56+
@property
57+
def related_nodes(self) -> list[InfrahubNode]:
58+
"""Returns nodes collected and parsed during the data collection process if this feature is enables"""
59+
return self._related_nodes
60+
61+
async def process_nodes(self, data: dict) -> None:
62+
if not self.convert_query_response:
63+
return
64+
65+
await self._init_client.schema.all(branch=self.branch_name)
66+
67+
for kind in data:
68+
if kind in self._init_client.schema.cache[self.branch_name].nodes.keys():
69+
for result in data[kind].get("edges", []):
70+
node = await self.infrahub_node.from_graphql(
71+
client=self._init_client, branch=self.branch_name, data=result
72+
)
73+
self._nodes.append(node)
74+
await node._process_relationships(
75+
node_data=result, branch=self.branch_name, related_nodes=self._related_nodes
76+
)
77+
78+
for node in self._nodes + self._related_nodes:
79+
if node.id:
80+
self._init_client.store.set(node=node)

infrahub_sdk/protocols.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ class CoreMenu(CoreNode):
154154
children: RelationshipManager
155155

156156

157+
class CoreObjectComponentTemplate(CoreNode):
158+
template_name: String
159+
160+
157161
class CoreObjectTemplate(CoreNode):
158162
template_name: String
159163

@@ -205,6 +209,7 @@ class CoreWebhook(CoreNode):
205209
name: String
206210
event_type: Enum
207211
branch_scope: Dropdown
212+
node_kind: StringOptional
208213
description: StringOptional
209214
url: URL
210215
validate_certificates: BooleanOptional
@@ -479,6 +484,7 @@ class CoreTransformJinja2(CoreTransformation):
479484
class CoreTransformPython(CoreTransformation):
480485
file_path: String
481486
class_name: String
487+
convert_query_response: BooleanOptional
482488

483489

484490
class CoreUserValidator(CoreValidator):
@@ -625,6 +631,10 @@ class CoreMenuSync(CoreNodeSync):
625631
children: RelationshipManagerSync
626632

627633

634+
class CoreObjectComponentTemplateSync(CoreNodeSync):
635+
template_name: String
636+
637+
628638
class CoreObjectTemplateSync(CoreNodeSync):
629639
template_name: String
630640

@@ -676,6 +686,7 @@ class CoreWebhookSync(CoreNodeSync):
676686
name: String
677687
event_type: Enum
678688
branch_scope: Dropdown
689+
node_kind: StringOptional
679690
description: StringOptional
680691
url: URL
681692
validate_certificates: BooleanOptional
@@ -950,6 +961,7 @@ class CoreTransformJinja2Sync(CoreTransformationSync):
950961
class CoreTransformPythonSync(CoreTransformationSync):
951962
file_path: String
952963
class_name: String
964+
convert_query_response: BooleanOptional
953965

954966

955967
class CoreUserValidatorSync(CoreValidatorSync):

infrahub_sdk/schema/repository.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ class InfrahubPythonTransformConfig(InfrahubRepositoryConfigElement):
117117
name: str = Field(..., description="The name of the Transform")
118118
file_path: Path = Field(..., description="The file within the repository with the transform code.")
119119
class_name: str = Field(default="Transform", description="The name of the transform class to run.")
120+
convert_query_response: bool = Field(
121+
default=False,
122+
description="Decide if the transform should convert the result of the GraphQL query to SDK InfrahubNode objects.",
123+
)
120124

121125
def load_class(self, import_root: str | None = None, relative_path: str | None = None) -> type[InfrahubTransform]:
122126
module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path)

0 commit comments

Comments
 (0)