Skip to content

Commit 353abd9

Browse files
authored
Merge pull request #4833 from opsmill/dga-202241102-task-ongoing
Cont working on InfrahubTask
2 parents 946c71a + 7b5be6d commit 353abd9

File tree

9 files changed

+634
-151
lines changed

9 files changed

+634
-151
lines changed

backend/infrahub/core/task/task.py

Lines changed: 7 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,5 @@
1-
from collections import defaultdict
2-
from typing import TYPE_CHECKING, Any, Optional
1+
from typing import Any, Optional
32

4-
from prefect.client.orchestration import get_client
5-
from prefect.client.schemas.filters import (
6-
FlowRunFilter,
7-
FlowRunFilterTags,
8-
LogFilter,
9-
LogFilterFlowRunId,
10-
)
11-
from prefect.client.schemas.sorting import (
12-
FlowRunSort,
13-
)
143
from pydantic import ConfigDict, Field
154

165
from infrahub.core.constants import TaskConclusion
@@ -21,15 +10,9 @@
2110
from infrahub.core.timestamp import current_timestamp
2211
from infrahub.database import InfrahubDatabase
2312
from infrahub.utils import get_nested_dict
24-
from infrahub.workflows.constants import TAG_NAMESPACE, WorkflowTag
2513

2614
from .task_log import TaskLog
2715

28-
if TYPE_CHECKING:
29-
from prefect.client.schemas.objects import Log as PrefectLog
30-
31-
LOG_LEVEL_MAPPING = {10: "debug", 20: "info", 30: "warning", 40: "error", 50: "critical"}
32-
3316

3417
class Task(StandardNode):
3518
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -55,11 +38,13 @@ async def query(
5538
cls,
5639
db: InfrahubDatabase,
5740
fields: dict[str, Any],
58-
limit: int,
59-
offset: int,
60-
ids: list[str],
61-
related_nodes: list[str],
41+
limit: int | None = None,
42+
offset: int | None = None,
43+
ids: list[str] | None = None,
44+
related_nodes: list[str] | None = None,
6245
) -> dict[str, Any]:
46+
ids = ids or []
47+
related_nodes = related_nodes or []
6348
log_fields = get_nested_dict(nested_dict=fields, keys=["edges", "node", "logs", "edges", "node"])
6449
count = None
6550
if "count" in fields:
@@ -108,83 +93,3 @@ async def query(
10893
)
10994

11095
return {"count": count, "edges": nodes}
111-
112-
113-
class NewTask:
114-
@classmethod
115-
async def query(
116-
cls,
117-
fields: dict[str, Any],
118-
related_nodes: list[str],
119-
branch: str | None = None,
120-
limit: int | None = None,
121-
offset: int = 0,
122-
) -> dict[str, Any]:
123-
nodes: list[dict] = []
124-
count = None
125-
126-
log_fields = get_nested_dict(nested_dict=fields, keys=["edges", "node", "logs", "edges", "node"])
127-
logs_flow: dict[str, list[PrefectLog]] = defaultdict(list)
128-
129-
async with get_client(sync_client=False) as client:
130-
tags = [TAG_NAMESPACE]
131-
132-
if branch:
133-
tags.append(WorkflowTag.BRANCH.render(identifier=branch))
134-
135-
# We only support one related node for now, need to investigate HOW (and IF) we can support more
136-
if related_nodes:
137-
tags.append(WorkflowTag.RELATED_NODE.render(identifier=related_nodes[0]))
138-
139-
flow_run_filters = FlowRunFilter(
140-
tags=FlowRunFilterTags(all_=tags),
141-
)
142-
143-
flows = await client.read_flow_runs(
144-
flow_run_filter=flow_run_filters,
145-
limit=limit,
146-
offset=offset,
147-
sort=FlowRunSort.START_TIME_DESC,
148-
)
149-
150-
# For now count will just return the number of objects in the response
151-
# it won't work well with pagination but it doesn't look like Prefect provides a good option to count all flows
152-
if "count" in fields:
153-
count = len(flows)
154-
155-
if log_fields:
156-
flow_ids = [flow.id for flow in flows]
157-
all_logs = await client.read_logs(log_filter=LogFilter(flow_run_id=LogFilterFlowRunId(any_=flow_ids)))
158-
for log in all_logs:
159-
logs_flow[log.flow_run_id].append(log)
160-
161-
for flow in flows:
162-
logs = []
163-
if log_fields:
164-
logs = [
165-
{
166-
"node": {
167-
"message": log.message,
168-
"severity": LOG_LEVEL_MAPPING.get(log.level, "error"),
169-
"timestamp": log.timestamp.to_iso8601_string(),
170-
}
171-
}
172-
for log in logs_flow[flow.id]
173-
]
174-
175-
nodes.append(
176-
{
177-
"node": {
178-
"title": flow.name,
179-
"conclusion": flow.state_name,
180-
"related_node": "",
181-
"related_node_kind": "",
182-
"created_at": flow.created.to_iso8601_string(),
183-
"updated_at": flow.updated.to_iso8601_string(),
184-
"id": flow.id,
185-
"logs": {"edges": logs},
186-
}
187-
}
188-
)
189-
190-
return {"count": count or 0, "edges": nodes}

backend/infrahub/graphql/queries/task.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
from graphene import Field, Int, List, ObjectType, String
66
from infrahub_sdk.utils import extract_fields_first_node
7+
from prefect.client.schemas.objects import StateType
78

8-
from infrahub.core.task.task import NewTask as TaskNewNode
99
from infrahub.core.task.task import Task as TaskNode
1010
from infrahub.graphql.types.task import TaskNodes
11+
from infrahub.task_manager.task import PrefectTask
12+
from infrahub.workflows.constants import WorkflowTag
1113

1214
if TYPE_CHECKING:
1315
from graphql import GraphQLResolveInfo
@@ -35,15 +37,28 @@ async def resolve(
3537
info=info, branch=branch, limit=limit, offset=offset, ids=ids, related_nodes=related_nodes
3638
)
3739

40+
@staticmethod
41+
async def resolve_branch_status(
42+
root: dict, # pylint: disable=unused-argument
43+
info: GraphQLResolveInfo,
44+
branch: str,
45+
) -> dict[str, Any]:
46+
statuses: list[StateType] = [StateType.PENDING, StateType.RUNNING, StateType.CANCELLING, StateType.SCHEDULED]
47+
tags: list[str] = [WorkflowTag.DATABASE_CHANGE.render()]
48+
49+
return await Tasks.query(info=info, branch=branch, statuses=statuses, tags=tags)
50+
3851
@classmethod
3952
async def query(
4053
cls,
4154
info: GraphQLResolveInfo,
42-
limit: int,
43-
offset: int,
44-
related_nodes: list[str],
45-
ids: list[str],
46-
branch: str | None,
55+
related_nodes: list[str] | None = None,
56+
ids: list[str] | None = None,
57+
statuses: list[StateType] | None = None,
58+
tags: list[str] | None = None,
59+
branch: str | None = None,
60+
limit: int | None = None,
61+
offset: int | None = None,
4762
) -> dict[str, Any]:
4863
context: GraphqlContext = info.context
4964
fields = await extract_fields_first_node(info)
@@ -56,12 +71,20 @@ async def query(
5671
else:
5772
infrahub_tasks = {}
5873

59-
prefect_tasks = await TaskNewNode.query(
60-
fields=fields, branch=branch, related_nodes=related_nodes, limit=limit, offset=offset
74+
prefect_tasks = await PrefectTask.query(
75+
db=context.db,
76+
fields=fields,
77+
branch=branch,
78+
statuses=statuses,
79+
tags=tags,
80+
related_nodes=related_nodes,
81+
limit=limit,
82+
offset=offset,
6183
)
62-
84+
infrahub_count = infrahub_tasks.get("count", None)
85+
prefect_count = prefect_tasks.get("count", None)
6386
return {
64-
"count": infrahub_tasks.get("count", 0) + prefect_tasks.get("count", 0),
87+
"count": (infrahub_count or 0) + (prefect_count or 0),
6588
"edges": infrahub_tasks.get("edges", []) + prefect_tasks.get("edges", []),
6689
}
6790

@@ -75,3 +98,10 @@ async def query(
7598
branch=String(required=False),
7699
ids=List(String),
77100
)
101+
102+
TaskBranchStatus = Field(
103+
Tasks,
104+
resolver=Tasks.resolve_branch_status,
105+
branch=String(required=False),
106+
description="Return the list of all pending or running tasks that can modify the data, for a given branch",
107+
)

backend/infrahub/graphql/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
Relationship,
5353
)
5454
from .queries.diff.tree import DiffTreeQuery, DiffTreeSummaryQuery
55-
from .queries.task import Task
55+
from .queries.task import Task, TaskBranchStatus
5656

5757

5858
class InfrahubBaseQuery(ObjectType):
@@ -72,6 +72,7 @@ class InfrahubBaseQuery(ObjectType):
7272
InfrahubSearchAnywhere = InfrahubSearchAnywhere
7373

7474
InfrahubTask = Task
75+
InfrahubTaskBranchStatus = TaskBranchStatus
7576

7677
IPAddressGetNextAvailable = InfrahubIPAddressGetNextAvailable
7778
IPPrefixGetNextAvailable = InfrahubIPPrefixGetNextAvailable

backend/infrahub/graphql/types/task.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
11
from __future__ import annotations
22

3-
from graphene import Field, ObjectType, String
3+
from graphene import Enum, Field, Float, List, ObjectType, String
4+
from graphene.types.generic import GenericScalar
5+
from prefect.client.schemas.objects import StateType
46

57
from .task_log import TaskLogEdge
68

9+
TaskState = Enum.from_enum(StateType)
10+
711

812
class Task(ObjectType):
913
id = String(required=True)
1014
title = String(required=True)
1115
conclusion = String(required=True)
16+
state = TaskState(required=False)
17+
progress = Float(required=False)
18+
branch = String(required=False)
1219
created_at = String(required=True)
1320
updated_at = String(required=True)
21+
parameters = GenericScalar(required=False)
22+
tags = List(String, required=False)
23+
start_time = String(required=False)
1424

1525

1626
class TaskNode(Task):
17-
related_node = String(required=True)
18-
related_node_kind = String(required=True)
27+
related_node = String(required=False)
28+
related_node_kind = String(required=False)
1929
logs = Field(TaskLogEdge)
2030

2131

backend/infrahub/task_manager/__init__.py

Whitespace-only changes.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from infrahub.core.constants import TaskConclusion
2+
3+
LOG_LEVEL_MAPPING = {10: "debug", 20: "info", 30: "warning", 40: "error", 50: "critical"}
4+
5+
CONCLUSION_STATE_MAPPING = {
6+
"Scheduled": TaskConclusion.UNKNOWN,
7+
"Pending": TaskConclusion.UNKNOWN,
8+
"Running": TaskConclusion.UNKNOWN,
9+
"Completed": TaskConclusion.SUCCESS,
10+
"Failed": TaskConclusion.FAILURE,
11+
"Cancelled": TaskConclusion.FAILURE,
12+
"Crashed": TaskConclusion.FAILURE,
13+
"Paused": TaskConclusion.UNKNOWN,
14+
"Cancelling": TaskConclusion.FAILURE,
15+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from collections import defaultdict
2+
from typing import DefaultDict
3+
from uuid import UUID
4+
5+
from prefect.client.schemas.objects import Log as PrefectLog
6+
from pydantic import BaseModel, Field
7+
8+
from .constants import LOG_LEVEL_MAPPING
9+
10+
11+
class RelatedNodesInfo(BaseModel):
12+
id: dict[UUID, str] = Field(default_factory=dict)
13+
kind: dict[UUID, str | None] = Field(default_factory=dict)
14+
15+
def get_unique_related_node_ids(self) -> list[str]:
16+
return list(set(list(self.id.values())))
17+
18+
19+
class FlowLogs(BaseModel):
20+
logs: DefaultDict[UUID, list[PrefectLog]] = Field(default_factory=lambda: defaultdict(list))
21+
22+
def to_graphql(self, flow_id: UUID) -> list[dict]:
23+
return [
24+
{
25+
"node": {
26+
"message": log.message,
27+
"severity": LOG_LEVEL_MAPPING.get(log.level, "error"),
28+
"timestamp": log.timestamp.to_iso8601_string(),
29+
}
30+
}
31+
for log in self.logs[flow_id]
32+
]
33+
34+
35+
class FlowProgress(BaseModel):
36+
data: dict[UUID, float] = Field(default_factory=dict)

0 commit comments

Comments
 (0)