
Commit 93c4695

Merge pull request #4823 from opsmill/dga-20241030-task-api
Integrate Prefect Flow data into existing InfrahubTask query
2 parents c2a8bba + 6809f8e

6 files changed (+349, -22 lines)

backend/infrahub/core/task/task.py

Lines changed: 98 additions & 1 deletion
@@ -1,5 +1,16 @@
-from typing import Any, Optional
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Optional
 
+from prefect.client.orchestration import get_client
+from prefect.client.schemas.filters import (
+    FlowRunFilter,
+    FlowRunFilterTags,
+    LogFilter,
+    LogFilterFlowRunId,
+)
+from prefect.client.schemas.sorting import (
+    FlowRunSort,
+)
 from pydantic import ConfigDict, Field
 
 from infrahub.core.constants import TaskConclusion
@@ -10,9 +21,15 @@
 from infrahub.core.timestamp import current_timestamp
 from infrahub.database import InfrahubDatabase
 from infrahub.utils import get_nested_dict
+from infrahub.workflows.constants import TAG_NAMESPACE, WorkflowTag
 
 from .task_log import TaskLog
 
+if TYPE_CHECKING:
+    from prefect.client.schemas.objects import Log as PrefectLog
+
+LOG_LEVEL_MAPPING = {10: "debug", 20: "info", 30: "warning", 40: "error", 50: "critical"}
+
 
 class Task(StandardNode):
     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -91,3 +108,83 @@ async def query(
         )
 
         return {"count": count, "edges": nodes}
+
+
+class NewTask:
+    @classmethod
+    async def query(
+        cls,
+        fields: dict[str, Any],
+        related_nodes: list[str],
+        branch: str | None = None,
+        limit: int | None = None,
+        offset: int = 0,
+    ) -> dict[str, Any]:
+        nodes: list[dict] = []
+        count = None
+
+        log_fields = get_nested_dict(nested_dict=fields, keys=["edges", "node", "logs", "edges", "node"])
+        logs_flow: dict[str, list[PrefectLog]] = defaultdict(list)
+
+        async with get_client(sync_client=False) as client:
+            tags = [TAG_NAMESPACE]
+
+            if branch:
+                tags.append(WorkflowTag.BRANCH.render(identifier=branch))
+
+            # We only support one related node for now, need to investigate HOW (and IF) we can support more
+            if related_nodes:
+                tags.append(WorkflowTag.RELATED_NODE.render(identifier=related_nodes[0]))
+
+            flow_run_filters = FlowRunFilter(
+                tags=FlowRunFilterTags(all_=tags),
+            )
+
+            flows = await client.read_flow_runs(
+                flow_run_filter=flow_run_filters,
+                limit=limit,
+                offset=offset,
+                sort=FlowRunSort.START_TIME_DESC,
+            )
+
+            # For now count will just return the number of objects in the response
+            # it won't work well with pagination but it doesn't look like Prefect provide a good option to count all flows
+            if "count" in fields:
+                count = len(flows)
+
+            if log_fields:
+                flow_ids = [flow.id for flow in flows]
+                all_logs = await client.read_logs(log_filter=LogFilter(flow_run_id=LogFilterFlowRunId(any_=flow_ids)))
+                for log in all_logs:
+                    logs_flow[log.flow_run_id].append(log)
+
+            for flow in flows:
+                logs = []
+                if log_fields:
+                    logs = [
+                        {
+                            "node": {
+                                "message": log.message,
+                                "severity": LOG_LEVEL_MAPPING.get(log.level, "error"),
+                                "timestamp": log.timestamp.to_iso8601_string(),
+                            }
+                        }
+                        for log in logs_flow[flow.id]
+                    ]
+
+                nodes.append(
+                    {
+                        "node": {
+                            "title": flow.name,
+                            "conclusion": flow.state_name,
+                            "related_node": "",
+                            "related_node_kind": "",
+                            "created_at": flow.created.to_iso8601_string(),
+                            "updated_at": flow.updated.to_iso8601_string(),
+                            "id": flow.id,
+                            "logs": {"edges": logs},
+                        }
+                    }
+                )
+
+        return {"count": count or 0, "edges": nodes}

backend/infrahub/graphql/queries/task.py

Lines changed: 33 additions & 9 deletions
@@ -1,12 +1,13 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 
 from graphene import Field, Int, List, ObjectType, String
 from infrahub_sdk.utils import extract_fields_first_node
 
-from infrahub.core.task import Task as TaskNode
-from infrahub.graphql.types import TaskNodes
+from infrahub.core.task.task import NewTask as TaskNewNode
+from infrahub.core.task.task import Task as TaskNode
+from infrahub.graphql.types.task import TaskNodes
 
 if TYPE_CHECKING:
     from graphql import GraphQLResolveInfo
@@ -24,30 +25,53 @@ async def resolve(
         info: GraphQLResolveInfo,
         limit: int = 10,
         offset: int = 0,
-        ids: Optional[list] = None,
-        related_node__ids: Optional[list] = None,
+        ids: list | None = None,
+        branch: str | None = None,
+        related_node__ids: list | None = None,
     ) -> dict[str, Any]:
         related_nodes = related_node__ids or []
         ids = ids or []
-        return await Tasks.query(info=info, limit=limit, offset=offset, ids=ids, related_nodes=related_nodes)
+        return await Tasks.query(
+            info=info, branch=branch, limit=limit, offset=offset, ids=ids, related_nodes=related_nodes
+        )
 
     @classmethod
     async def query(
-        cls, info: GraphQLResolveInfo, limit: int, offset: int, related_nodes: list[str], ids: list[str]
+        cls,
+        info: GraphQLResolveInfo,
+        limit: int,
+        offset: int,
+        related_nodes: list[str],
+        ids: list[str],
+        branch: str | None,
     ) -> dict[str, Any]:
         context: GraphqlContext = info.context
         fields = await extract_fields_first_node(info)
 
-        return await TaskNode.query(
-            db=context.db, fields=fields, limit=limit, offset=offset, ids=ids, related_nodes=related_nodes
+        # During the migration, query both Prefect and Infrahub to get the list of tasks
+        if not branch:
+            infrahub_tasks = await TaskNode.query(
+                db=context.db, fields=fields, limit=limit, offset=offset, ids=ids, related_nodes=related_nodes
+            )
+        else:
+            infrahub_tasks = {}
+
+        prefect_tasks = await TaskNewNode.query(
+            fields=fields, branch=branch, related_nodes=related_nodes, limit=limit, offset=offset
         )
 
+        return {
+            "count": infrahub_tasks.get("count", 0) + prefect_tasks.get("count", 0),
+            "edges": infrahub_tasks.get("edges", []) + prefect_tasks.get("edges", []),
+        }
+
 
 Task = Field(
     Tasks,
     resolver=Tasks.resolve,
     limit=Int(required=False),
     offset=Int(required=False),
     related_node__ids=List(String),
+    branch=String(required=False),
     ids=List(String),
 )
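
Note: with the new branch argument set, the resolver skips the database-backed Task.query and returns Prefect results only; without it, both sources are queried and their results concatenated. A rough sketch of that final merge step in isolation, using fabricated placeholder results rather than real query output:

from typing import Any

# Fabricated placeholders standing in for the two query results.
infrahub_tasks: dict[str, Any] = {"count": 2, "edges": [{"node": {"title": "db-task-1"}}, {"node": {"title": "db-task-2"}}]}
prefect_tasks: dict[str, Any] = {"count": 1, "edges": [{"node": {"title": "flow-run-1"}}]}

# Same combination performed at the end of Tasks.query.
merged = {
    "count": infrahub_tasks.get("count", 0) + prefect_tasks.get("count", 0),
    "edges": infrahub_tasks.get("edges", []) + prefect_tasks.get("edges", []),
}

assert merged["count"] == 3 and len(merged["edges"]) == 3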

backend/infrahub/graphql/schema.py

Lines changed: 1 addition & 1 deletion
@@ -50,9 +50,9 @@
     InfrahubSearchAnywhere,
     InfrahubStatus,
     Relationship,
-    Task,
 )
 from .queries.diff.tree import DiffTreeQuery, DiffTreeSummaryQuery
+from .queries.task import Task
 
 
 class InfrahubBaseQuery(ObjectType):

backend/infrahub/workflows/constants.py

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ class WorkflowTag(InfrahubStringEnum):
     BRANCH = "branch/{identifier}"
     WORKFLOWTYPE = "workflow-type/{identifier}"
     DATABASE_CHANGE = "database-change"
+    RELATED_NODE = "node/{identifier}"
 
     def render(self, identifier: str | None = None) -> str:
         if identifier is None:
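
Note: RELATED_NODE follows the same template pattern as the existing tags. Assuming render() substitutes the identifier into the template, as the existing entries suggest, the new tag would render like this; the node ID is a made-up example value:

from infrahub.workflows.constants import WorkflowTag

# Made-up identifier, for illustration only.
print(WorkflowTag.RELATED_NODE.render(identifier="17dd42a7-d547-1afd"))  # expected: "node/17dd42a7-d547-1afd"
print(WorkflowTag.BRANCH.render(identifier="main"))  # expected: "branch/main"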

backend/infrahub/workflows/models.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 from infrahub import __version__
 from infrahub.core.constants import BranchSupportType
 
-from .constants import WorkflowTag, WorkflowType
+from .constants import TAG_NAMESPACE, WorkflowTag, WorkflowType
 
 TASK_RESULT_STORAGE_NAME = "infrahub-storage"
 
@@ -61,7 +61,7 @@ def to_deployment(self) -> dict[str, Any]:
         return payload
 
     def get_tags(self) -> list[str]:
-        tags: list[str] = [WorkflowTag.WORKFLOWTYPE.render(identifier=self.type.value)]
+        tags: list[str] = [TAG_NAMESPACE, WorkflowTag.WORKFLOWTYPE.render(identifier=self.type.value)]
         tags += [tag.render() for tag in self.tags]
        return tags
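
Note: prepending TAG_NAMESPACE in get_tags() is what lets NewTask.query find Infrahub-owned flow runs by filtering on the namespace tag. A small illustrative sketch of how the tags line up; TAG_NAMESPACE's actual value is not shown in this commit, so a placeholder string stands in for it, and the other tag values are assumed examples:

# Placeholder for the real constant defined in infrahub/workflows/constants.py.
TAG_NAMESPACE = "example-namespace"

# What a deployment's tags might look like after this change (assumed example values).
deployment_tags = [TAG_NAMESPACE, "workflow-type/core", "branch/main"]

# What NewTask.query filters on when a branch is supplied.
query_tags = [TAG_NAMESPACE, "branch/main"]

# The flow run matches because every tag in the filter is present on the deployment.
assert set(query_tags) <= set(deployment_tags)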
