Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion backend/infrahub/computed_attribute/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING

from prefect.events.schemas.automations import Automation
from prefect.events.schemas.automations import Automation # noqa: TCH002
from pydantic import BaseModel, Field
from typing_extensions import Self

from .constants import AUTOMATION_NAME_PREFIX

if TYPE_CHECKING:
from infrahub.core.schema.schema_branch_computed import PythonDefinition


class ComputedAttributeAutomations(BaseModel):
data: dict[str, dict[str, Automation]] = Field(default_factory=lambda: defaultdict(dict))
Expand Down Expand Up @@ -37,3 +44,14 @@ def has(self, identifier: str, scope: str) -> bool:
if identifier in self.data and scope in self.data[identifier]:
return True
return False


@dataclass
class PythonTransformComputedAttribute:
name: str
repository_id: str
repository_name: str
repository_kind: str
query_name: str
query_models: list[str]
computed_attribute: PythonDefinition
116 changes: 115 additions & 1 deletion backend/infrahub/computed_attribute/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
from infrahub.git.repository import get_initialized_repo
from infrahub.services import services
from infrahub.support.macro import MacroDefinition
from infrahub.tasks.registry import refresh_branches
from infrahub.workflows.catalogue import PROCESS_COMPUTED_MACRO, UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM
from infrahub.workflows.utils import add_branch_tag

from .constants import AUTOMATION_NAME, AUTOMATION_NAME_PREFIX
from .models import ComputedAttributeAutomations
from .models import ComputedAttributeAutomations, PythonTransformComputedAttribute

if TYPE_CHECKING:
from infrahub.core.schema.computed_attribute import ComputedAttribute
Expand All @@ -48,6 +49,8 @@ async def process_transform(
branch_name: str,
node_kind: str,
object_id: str,
computed_attribute_name: str, # pylint: disable=unused-argument
computed_attribute_kind: str, # pylint: disable=unused-argument
updated_fields: list[str] | None = None, # pylint: disable=unused-argument
) -> None:
"""Request to the creation of git branches in available repositories."""
Expand Down Expand Up @@ -262,3 +265,114 @@ async def computed_attribute_setup() -> None:
else:
await client.create_automation(automation=automation)
log.info(f"{computed_attribute.key_name} Created")


@flow(
name="computed-attribute-setup-python",
flow_run_name="Setup computed attributes for Python transforms in task-manager",
)
async def computed_attribute_setup_python() -> None:
service = services.service
async with service.database.start_session() as db:
await refresh_branches(db=db)

schema_branch = registry.schema.get_schema_branch(name=registry.default_branch)
log = get_run_logger()

transform_attributes = schema_branch.computed_attributes.python_attributes_by_transform

transform_names = list(transform_attributes.keys())

transforms = await service.client.filters(
kind="CoreTransformPython",
branch=registry.default_branch,
prefetch_relationships=True,
populate_store=True,
name__values=transform_names,
)

found_transforms_names = [transform.name.value for transform in transforms]
for transform_name in transform_names:
if transform_name not in found_transforms_names:
log.warning(
msg=f"The transform {transform_name} is assigned to a computed attribute but the transform could not be found in the database."
)

computed_attributes: list[PythonTransformComputedAttribute] = []
for transform in transforms:
for attribute in transform_attributes[transform.name.value]:
computed_attributes.append(
PythonTransformComputedAttribute(
name=transform.name.value,
repository_id=transform.repository.peer.id,
repository_name=transform.repository.peer.name.value,
repository_kind=transform.repository.peer.typename,
query_name=transform.query.peer.name.value,
query_models=transform.query.peer.models.value,
computed_attribute=attribute,
)
)

async with get_client(sync_client=False) as client:
deployments = {
item.name: item
for item in await client.read_deployments(
deployment_filter=DeploymentFilter(
name=DeploymentFilterName(any_=[UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM.name])
)
)
}
if UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM.name not in deployments:
raise ValueError("Unable to find the deployment for UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM")

deployment_id_python = deployments[UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM.name].id

automations = await client.read_automations()
existing_computed_attr_automations = ComputedAttributeAutomations.from_prefect(automations=automations)

for computed_attribute in computed_attributes:
log.info(f"processing {computed_attribute.computed_attribute.key_name}")
scope = "default"

automation = AutomationCore(
name=AUTOMATION_NAME.format(
prefix=AUTOMATION_NAME_PREFIX,
identifier=computed_attribute.computed_attribute.key_name,
scope=scope,
),
description=f"Process value of the computed attribute for {computed_attribute.computed_attribute.key_name} [{scope}]",
enabled=True,
trigger=EventTrigger(
posture=Posture.Reactive,
expect={"infrahub.node.*"},
within=timedelta(0),
match=ResourceSpecification({"infrahub.node.kind": [computed_attribute.computed_attribute.kind]}),
threshold=1,
),
actions=[
RunDeployment(
source="selected",
deployment_id=deployment_id_python,
parameters={
"branch_name": "{{ event.resource['infrahub.branch.name'] }}",
"node_kind": "{{ event.resource['infrahub.node.kind'] }}",
"object_id": "{{ event.resource['infrahub.node.id'] }}",
"computed_attribute_name": computed_attribute.computed_attribute.attribute.name,
"computed_attribute_kind": computed_attribute.computed_attribute.kind,
},
job_variables={},
)
],
)

if existing_computed_attr_automations.has(
identifier=computed_attribute.computed_attribute.key_name, scope=scope
):
existing = existing_computed_attr_automations.get(
identifier=computed_attribute.computed_attribute.key_name, scope=scope
)
await client.update_automation(automation_id=existing.id, automation=automation)
log.info(f"{computed_attribute.computed_attribute.key_name} Updated")
else:
await client.create_automation(automation=automation)
log.info(f"{computed_attribute.computed_attribute.key_name} Created")
8 changes: 7 additions & 1 deletion backend/infrahub/core/schema/schema_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from infrahub.visuals import select_color

from .constants import INTERNAL_SCHEMA_NODE_KINDS, SchemaNamespace
from .schema_branch_computed import ComputedAttributes

log = get_logger()

Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(self, cache: dict, name: str | None = None, data: dict[str, dict[st
self.nodes: dict[str, str] = {}
self.generics: dict[str, str] = {}
self.profiles: dict[str, str] = {}
self.computed_attributes = ComputedAttributes()
self._computed_jinja2_attribute_map: dict[str, RegisteredNodeComputedAttribute] = {}

if data:
Expand Down Expand Up @@ -949,6 +951,7 @@ def validate_kinds(self) -> None:
) from None

def validate_computed_attributes(self) -> None:
self.computed_attributes = ComputedAttributes()
self._computed_jinja2_attribute_map = {}
for name in self.nodes.keys():
node_schema = self.get_node(name=name, duplicate=False)
Expand Down Expand Up @@ -1016,11 +1019,14 @@ def _validate_computed_attribute(self, node: NodeSchema, attribute: AttributeSch

self._register_computed_attribute_target(node=node, attribute=attribute, schema_path=schema_path)

if attribute.computed_attribute.kind == ComputedAttributeKind.TRANSFORM_PYTHON and not attribute.optional:
elif attribute.computed_attribute.kind == ComputedAttributeKind.TRANSFORM_PYTHON and not attribute.optional:
raise ValueError(
f"{node.kind}: Attribute {attribute.name!r} is a computed transform, it can't be mandatory"
)

elif attribute.computed_attribute.kind == ComputedAttributeKind.TRANSFORM_PYTHON:
self.computed_attributes.add_python_attribute(node=node, attribute=attribute)

def get_impacted_macros(self, kind: str, updates: list[str] | None = None) -> list[ComputedAttributeTarget]:
if mapping := self._computed_jinja2_attribute_map.get(kind):
return mapping.get_targets(updates=updates)
Expand Down
45 changes: 45 additions & 0 deletions backend/infrahub/core/schema/schema_branch_computed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from dataclasses import dataclass

from infrahub.core.schema import (
AttributeSchema,
NodeSchema,
)


@dataclass
class PythonDefinition:
kind: str
attribute: AttributeSchema

@property
def key_name(self) -> str:
return f"{self.kind}_{self.attribute.name}"


class ComputedAttributes:
def __init__(self) -> None:
self._computed_python_transform_attribute_map: dict[str, list[AttributeSchema]] = {}

def add_python_attribute(self, node: NodeSchema, attribute: AttributeSchema) -> None:
if node.kind not in self._computed_python_transform_attribute_map:
self._computed_python_transform_attribute_map[node.kind] = []
self._computed_python_transform_attribute_map[node.kind].append(attribute)

def get_kinds_python_attributes(self) -> list[str]:
"""Return kinds that have Python attributes defined"""
return list(self._computed_python_transform_attribute_map.keys())

@property
def python_attributes_by_transform(self) -> dict[str, list[PythonDefinition]]:
computed_attributes: dict[str, list[PythonDefinition]] = {}
for kind, attributes in self._computed_python_transform_attribute_map.items():
for attribute in attributes:
if attribute.computed_attribute and attribute.computed_attribute.transform:
if attribute.computed_attribute.transform not in computed_attributes:
computed_attributes[attribute.computed_attribute.transform] = []

computed_attributes[attribute.computed_attribute.transform].append(
PythonDefinition(kind=kind, attribute=attribute)
)

return computed_attributes
35 changes: 0 additions & 35 deletions backend/infrahub/graphql/mutations/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from infrahub.exceptions import ValidationError
from infrahub.log import get_log_data, get_logger
from infrahub.worker import WORKER_IDENTITY
from infrahub.workflows.catalogue import UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM

from .node_getter.by_default_filter import MutationNodeGetterByDefaultFilter
from .node_getter.by_hfid import MutationNodeGetterByHfid
Expand Down Expand Up @@ -101,8 +100,6 @@ async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: InputObjectTyp
# Reset the time of the query to guarantee that all resolvers executed after this point will account for the changes
context.at = Timestamp()

# Get relevant macros based on the current change
# macros = schema_branch.get_impacted_macros(kind=obj.get_kind(), updates=updated_fields)
if config.SETTINGS.broker.enable and context.background:
log_data = get_log_data()
request_id = log_data.get("request_id", "")
Expand All @@ -120,38 +117,6 @@ async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: InputObjectTyp

context.background.add_task(context.service.event.send, event)

# Temporary workaround until there's a proper deployment for Transforms during schema load / repo sync
schema_branch = registry.schema.get_schema_branch(name=context.branch.name)
node_schema = schema_branch.get(obj._schema.kind)
if isinstance(node_schema, NodeSchema):
has_computed_attributes = [
attribute for attribute in node_schema.attributes if attribute.computed_attribute
]
if has_computed_attributes:
updated_fields = list(data.keys())

await context.service.workflow.submit_workflow(
workflow=UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM,
parameters={
"branch_name": context.branch.name,
"node_kind": obj.get_kind(),
"object_id": obj.get_id(),
"updated_fields": updated_fields,
},
)

# # Add event
# if macros:
# await context.service.workflow.submit_workflow(
# workflow=PROCESS_COMPUTED_MACRO,
# parameters={
# "branch_name": context.branch.name,
# "node_kind": obj.get_kind(),
# "object_id": obj.get_id(),
# "updated_fields": updated_fields,
# },
# )

return mutation

@classmethod
Expand Down
15 changes: 10 additions & 5 deletions backend/infrahub/schema/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from prefect.events.schemas.automations import EventTrigger, Posture
from prefect.logging import get_run_logger

from infrahub.workflows.catalogue import COMPUTED_ATTRIBUTE_SETUP
from infrahub.workflows.catalogue import COMPUTED_ATTRIBUTE_SETUP, COMPUTED_ATTRIBUTE_SETUP_PYTHON

from .constants import AUTOMATION_NAME

Expand All @@ -25,9 +25,7 @@ async def schema_updated_setup() -> None:
for item in await client.read_deployments(
deployment_filter=DeploymentFilter(
name=DeploymentFilterName(
any_=[
COMPUTED_ATTRIBUTE_SETUP.name,
]
any_=[COMPUTED_ATTRIBUTE_SETUP.name, COMPUTED_ATTRIBUTE_SETUP_PYTHON.name]
)
)
)
Expand All @@ -36,6 +34,7 @@ async def schema_updated_setup() -> None:
raise ValueError("Unable to find the deployment for PROCESS_COMPUTED_MACRO")

deployment_id_computed_attribute_setup = deployments[COMPUTED_ATTRIBUTE_SETUP.name].id
deployment_id_computed_attribute_setup_python = deployments[COMPUTED_ATTRIBUTE_SETUP_PYTHON.name].id

schema_update_automation = await client.find_automation(id_or_name=AUTOMATION_NAME)

Expand All @@ -55,7 +54,13 @@ async def schema_updated_setup() -> None:
deployment_id=deployment_id_computed_attribute_setup,
parameters={},
job_variables={},
)
),
RunDeployment(
source="selected",
deployment_id=deployment_id_computed_attribute_setup_python,
parameters={},
job_variables={},
),
],
)

Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def client(self) -> InfrahubClient:

return self._client

def set_client(self, client: InfrahubClient) -> None:
def set_client(self, client: InfrahubClient | None) -> None:
self._client = client

@property
Expand Down
9 changes: 9 additions & 0 deletions backend/infrahub/workflows/catalogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,14 @@
function="computed_attribute_setup",
)

COMPUTED_ATTRIBUTE_SETUP_PYTHON = WorkflowDefinition(
name="computed-attribute-setup-python",
type=WorkflowType.INTERNAL,
module="infrahub.computed_attribute.tasks",
function="computed_attribute_setup_python",
)


UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM = WorkflowDefinition(
name="process_computed_attribute_transform",
type=WorkflowType.INTERNAL,
Expand Down Expand Up @@ -289,6 +297,7 @@
GIT_REPOSITORY_ADD_READ_ONLY,
PROCESS_COMPUTED_MACRO,
COMPUTED_ATTRIBUTE_SETUP,
COMPUTED_ATTRIBUTE_SETUP_PYTHON,
UPDATE_COMPUTED_ATTRIBUTE_TRANSFORM,
REQUEST_PROPOSED_CHANGE_DATA_INTEGRITY,
SCHEMA_UPDATED_SETUP,
Expand Down
Loading
Loading