Skip to content

Commit 22e5c66

Browse files
authored
Merge pull request #6348 from opsmill/pog-refactor-computed-attribute-filter-query
Refactor computed attribute update_computed_attribute_value_jinja2
2 parents 1494d53 + 658af4e commit 22e5c66

File tree

3 files changed

+116
-54
lines changed

3 files changed

+116
-54
lines changed

backend/infrahub/computed_attribute/models.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22

33
from collections import defaultdict
44
from dataclasses import dataclass, field
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Any
66

7+
from infrahub_sdk.graphql import Query
78
from prefect.events.schemas.automations import Automation # noqa: TC002
89
from pydantic import BaseModel, ConfigDict, Field, computed_field
910
from typing_extensions import Self
1011

1112
from infrahub.core import registry
13+
from infrahub.core.constants import RelationshipCardinality
14+
from infrahub.core.schema import AttributeSchema, NodeSchema # noqa: TC001
1215
from infrahub.core.schema.schema_branch_computed import ( # noqa: TC001
1316
ComputedAttributeTarget,
1417
ComputedAttributeTriggerNode,
@@ -309,3 +312,80 @@ def from_object(
309312
)
310313

311314
return definition
315+
316+
317+
class ComputedAttrJinja2GraphQLResponse(BaseModel):
318+
node_id: str
319+
computed_attribute_value: str | None
320+
variables: dict[str, Any] = Field(default_factory=dict)
321+
322+
323+
class ComputedAttrJinja2GraphQL(BaseModel):
324+
node_schema: NodeSchema = Field(..., description="The node kind where the computed attribute is defined")
325+
attribute_schema: AttributeSchema = Field(..., description="The computed attribute")
326+
variables: list[str] = Field(..., description="The list of variable names used within the computed attribute")
327+
328+
def render_graphql_query(self, query_filter: str, filter_id: str) -> str:
329+
query_fields = self.query_fields
330+
query_fields["id"] = None
331+
query_fields[self.attribute_schema.name] = {"value": None}
332+
query = Query(
333+
name="ComputedAttributeFilter",
334+
query={
335+
self.node_schema.kind: {
336+
"@filters": {query_filter: filter_id},
337+
"edges": {"node": query_fields},
338+
}
339+
},
340+
)
341+
342+
return query.render()
343+
344+
@property
345+
def query_fields(self) -> dict[str, Any]:
346+
output: dict[str, Any] = {}
347+
for variable in self.variables:
348+
field_name, remainder = variable.split("__", maxsplit=1)
349+
if field_name in self.node_schema.attribute_names:
350+
output[field_name] = {remainder: None}
351+
elif field_name in self.node_schema.relationship_names:
352+
related_attribute, related_value = remainder.split("__", maxsplit=1)
353+
relationship = self.node_schema.get_relationship(name=field_name)
354+
if relationship.cardinality == RelationshipCardinality.ONE:
355+
if field_name not in output:
356+
output[field_name] = {"node": {}}
357+
output[field_name]["node"][related_attribute] = {related_value: None}
358+
return output
359+
360+
def parse_response(self, response: dict[str, Any]) -> list[ComputedAttrJinja2GraphQLResponse]:
361+
rendered_response: list[ComputedAttrJinja2GraphQLResponse] = []
362+
if kind_payload := response.get(self.node_schema.kind):
363+
edges = kind_payload.get("edges", [])
364+
for node in edges:
365+
if node_response := self.to_node_response(node_dict=node):
366+
rendered_response.append(node_response)
367+
return rendered_response
368+
369+
def to_node_response(self, node_dict: dict[str, Any]) -> ComputedAttrJinja2GraphQLResponse | None:
370+
if node := node_dict.get("node"):
371+
node_id = node.get("id")
372+
else:
373+
return None
374+
375+
computed_attribute = node.get(self.attribute_schema.name, {}).get("value")
376+
response = ComputedAttrJinja2GraphQLResponse(node_id=node_id, computed_attribute_value=computed_attribute)
377+
for variable in self.variables:
378+
field_name, remainder = variable.split("__", maxsplit=1)
379+
response.variables[variable] = None
380+
if field_content := node.get(field_name):
381+
if field_name in self.node_schema.attribute_names:
382+
response.variables[variable] = field_content.get(remainder)
383+
elif field_name in self.node_schema.relationship_names:
384+
relationship = self.node_schema.get_relationship(name=field_name)
385+
if relationship.cardinality == RelationshipCardinality.ONE:
386+
related_attribute, related_value = remainder.split("__", maxsplit=1)
387+
node_content = field_content.get("node") or {}
388+
related_attribute_content = node_content.get(related_attribute) or {}
389+
response.variables[variable] = related_attribute_content.get(related_value)
390+
391+
return response

backend/infrahub/computed_attribute/tasks.py

Lines changed: 34 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22

33
from typing import TYPE_CHECKING
44

5-
from infrahub_sdk.protocols import (
6-
CoreNode, # noqa: TC002
7-
CoreTransformPython,
8-
)
5+
from infrahub_sdk.protocols import CoreTransformPython
96
from infrahub_sdk.template import Jinja2Template
107
from prefect import flow
118
from prefect.client.orchestration import get_client
@@ -28,9 +25,7 @@
2825
from infrahub.workflows.utils import add_tags, wait_for_schema_to_converge
2926

3027
from .gather import gather_trigger_computed_attribute_jinja2, gather_trigger_computed_attribute_python
31-
from .models import (
32-
PythonTransformTarget,
33-
)
28+
from .models import ComputedAttrJinja2GraphQL, ComputedAttrJinja2GraphQLResponse, PythonTransformTarget
3429

3530
if TYPE_CHECKING:
3631
from infrahub.core.schema.computed_attribute import ComputedAttribute
@@ -167,49 +162,33 @@ async def trigger_update_python_computed_attributes(
167162
flow_run_name="Update value for computed attribute {attribute_name}",
168163
)
169164
async def update_computed_attribute_value_jinja2(
170-
branch_name: str, obj: CoreNode, attribute_name: str, template_value: str, service: InfrahubServices
165+
branch_name: str,
166+
obj: ComputedAttrJinja2GraphQLResponse,
167+
node_kind: str,
168+
attribute_name: str,
169+
template: Jinja2Template,
170+
service: InfrahubServices,
171171
) -> None:
172172
log = get_run_logger()
173173

174-
await add_tags(branches=[branch_name], nodes=[obj.id], db_change=True)
175-
176-
jinja_template = Jinja2Template(template=template_value)
177-
variables = {}
178-
for variable in jinja_template.get_variables():
179-
components = variable.split("__")
180-
if len(components) == 2:
181-
property_name = components[0]
182-
property_value = components[1]
183-
attribute_property = getattr(obj, property_name)
184-
variables[variable] = getattr(attribute_property, property_value)
185-
elif len(components) == 3:
186-
relationship_name = components[0]
187-
property_name = components[1]
188-
property_value = components[2]
189-
relationship = getattr(obj, relationship_name)
190-
try:
191-
attribute_property = getattr(relationship.peer, property_name)
192-
variables[variable] = getattr(attribute_property, property_value)
193-
except ValueError:
194-
variables[variable] = ""
195-
196-
value = await jinja_template.render(variables=variables)
197-
existing_value = getattr(obj, attribute_name).value
198-
if value == existing_value:
174+
await add_tags(branches=[branch_name], nodes=[obj.node_id], db_change=True)
175+
176+
value = await template.render(variables=obj.variables)
177+
if value == obj.computed_attribute_value:
199178
log.debug(f"Ignoring to update {obj} with existing value on {attribute_name}={value}")
200179
return
201180

202181
await service.client.execute_graphql(
203182
query=UPDATE_ATTRIBUTE,
204183
variables={
205-
"id": obj.id,
206-
"kind": obj.get_kind(),
184+
"id": obj.node_id,
185+
"kind": node_kind,
207186
"attribute": attribute_name,
208187
"value": value,
209188
},
210189
branch_name=branch_name,
211190
)
212-
log.info(f"Updating computed attribute {obj.get_kind()}.{attribute_name}='{value}' ({obj.id})")
191+
log.info(f"Updating computed attribute {node_kind}.{attribute_name}='{value}' ({obj.node_id})")
213192

214193

215194
@flow(
@@ -235,41 +214,43 @@ async def process_jinja2(
235214
branch_name if branch_name in registry.get_altered_schema_branches() else registry.default_branch
236215
)
237216
schema_branch = registry.schema.get_schema_branch(name=target_branch_schema)
238-
await service.client.schema.all(branch=branch_name, refresh=True, schema_hash=schema_branch.get_hash())
239-
217+
node_schema = schema_branch.get_node(name=computed_attribute_kind, duplicate=False)
240218
computed_macros = [
241219
attrib
242220
for attrib in schema_branch.computed_attributes.get_impacted_jinja2_targets(kind=node_kind, updates=updates)
243221
if attrib.kind == computed_attribute_kind and attrib.attribute.name == computed_attribute_name
244222
]
245223
for computed_macro in computed_macros:
246-
found: list[CoreNode] = []
224+
found: list[ComputedAttrJinja2GraphQLResponse] = []
225+
template_string = "n/a"
226+
if computed_macro.attribute.computed_attribute and computed_macro.attribute.computed_attribute.jinja2_template:
227+
template_string = computed_macro.attribute.computed_attribute.jinja2_template
228+
229+
jinja_template = Jinja2Template(template=template_string)
230+
variables = jinja_template.get_variables()
231+
232+
attribute_graphql = ComputedAttrJinja2GraphQL(
233+
node_schema=node_schema, attribute_schema=computed_macro.attribute, variables=variables
234+
)
235+
247236
for id_filter in computed_macro.node_filters:
248-
filters = {id_filter: object_id}
249-
nodes: list[CoreNode] = await service.client.filters(
250-
kind=computed_macro.kind,
251-
branch=branch_name,
252-
prefetch_relationships=True,
253-
populate_store=True,
254-
**filters,
255-
)
256-
found.extend(nodes)
237+
query = attribute_graphql.render_graphql_query(query_filter=id_filter, filter_id=object_id)
238+
response = await service.client.execute_graphql(query=query, branch_name=branch_name)
239+
output = attribute_graphql.parse_response(response=response)
240+
found.extend(output)
257241

258242
if not found:
259243
log.debug("No nodes found that requires updates")
260244

261-
template_string = "n/a"
262-
if computed_macro.attribute.computed_attribute and computed_macro.attribute.computed_attribute.jinja2_template:
263-
template_string = computed_macro.attribute.computed_attribute.jinja2_template
264-
265245
batch = await service.client.create_batch()
266246
for node in found:
267247
batch.add(
268248
task=update_computed_attribute_value_jinja2,
269249
branch_name=branch_name,
270250
obj=node,
251+
node_kind=node_schema.kind,
271252
attribute_name=computed_macro.attribute.name,
272-
template_value=template_string,
253+
template=jinja_template,
273254
service=service,
274255
)
275256

changelog/6351.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve performance of computed attributes when updating a large number of objects at once. Replaced client.filter call in Jinja2 based computed attributes.

0 commit comments

Comments
 (0)